Machine-Learning-Collection/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "ec1aae37",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
"2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
"2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
]
}
],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"import datasets \n",
"import pytorch_lightning as pl\n",
"\n",
"from datasets import load_dataset, load_metric\n",
"\n",
"from transformers import (\n",
" AutoModel,\n",
" AutoModelForSeq2SeqLM,\n",
" AutoTokenizer,\n",
" DataCollatorForSeq2Seq,\n",
" Seq2SeqTrainingArguments,\n",
" Seq2SeqTrainer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5fd7cb0c",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
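{
"cell_type": "code",
"execution_count": null,
"id": "a1f2b3c4",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check of the tokenizer (illustrative cell, not part of the\n",
"# original training flow). T5 checkpoints were pre-trained with task\n",
"# prefixes such as \"summarize: \"; here we just confirm that text\n",
"# round-trips through input ids.\n",
"sample = \"summarize: The quick brown fox jumped over the lazy dog.\"\n",
"encoded = tokenizer(sample, return_tensors=\"pt\")\n",
"print(encoded.input_ids.shape)\n",
"print(tokenizer.decode(encoded.input_ids[0], skip_special_tokens=True))"
]
},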
{
"cell_type": "code",
"execution_count": 4,
"id": "04530b1e",
"metadata": {},
"outputs": [],
"source": [
"# Define the LightningDataModule\n",
"class MyDataModule(pl.LightningDataModule):\n",
" def __init__(self, batch_size):\n",
" super().__init__()\n",
" self.batch_size = batch_size\n",
" \n",
" def prepare_data(self):\n",
" # Download and preprocess the data\n",
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
" \n",
" def setup(self, stage=None):\n",
" # Load and preprocess the data\n",
" train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
" val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
" self.train_ds = train_data.map(\n",
" self.preprocess_function, \n",
" batched=True, \n",
" batch_size=self.batch_size, \n",
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
" )\n",
"\n",
" self.val_ds = val_data.map(\n",
" self.preprocess_function, \n",
" batched=True, \n",
" batch_size=self.batch_size,\n",
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
" )\n",
"\n",
" def preprocess_function(self, batch):\n",
" inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
" outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
" batch[\"input_ids\"] = inputs.input_ids\n",
" batch[\"attention_mask\"] = inputs.attention_mask\n",
" batch[\"labels\"] = outputs.input_ids.copy()\n",
" return batch\n",
"\n",
" def train_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
"\n",
" def val_dataloader(self):\n",
" return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
]
},
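{
"cell_type": "code",
"execution_count": null,
"id": "b2e3c4d5",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check of the datamodule (illustrative; assumes the\n",
"# cnn_dailymail download above succeeds): confirm one batch collates\n",
"# into tensors of the expected shapes before handing it to the Trainer.\n",
"dm_check = MyDataModule(batch_size=2)\n",
"dm_check.prepare_data()\n",
"dm_check.setup()\n",
"check_batch = next(iter(dm_check.train_dataloader()))\n",
"print({k: v.shape for k, v in check_batch.items()})\n",
"# expected: input_ids / attention_mask -> (2, 512), labels -> (2, 128)"
]
},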
{
"cell_type": "code",
"execution_count": 7,
"id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
" def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
" super().__init__()\n",
" self.model_name = model_name\n",
" self.learning_rate = learning_rate\n",
" self.weight_decay = weight_decay\n",
" self.batch_size = batch_size\n",
" \n",
" # Load the pre-trained model and tokenizer\n",
" self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
"\n",
" # Load the ROUGE metric\n",
" self.metric = load_metric(\"rouge\")\n",
"\n",
" def forward(self, input_ids, attention_mask, labels=None):\n",
" output = self.model(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" labels=labels,\n",
" )\n",
" return output.loss, output.logits\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" input_ids = batch[\"input_ids\"]\n",
" attention_mask = batch[\"attention_mask\"]\n",
" labels = batch[\"labels\"]\n",
" loss, logits = self(input_ids, attention_mask, labels)\n",
" self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
" return {'loss': loss, 'logits': logits}\n",
" \n",
" def validation_step(self, batch, batch_idx):\n",
" input_ids = batch[\"input_ids\"]\n",
" attention_mask = batch[\"attention_mask\"]\n",
" labels = batch[\"labels\"]\n",
" loss, logits = self(input_ids, attention_mask, labels)\n",
" self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
" return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
" \n",
" def validation_epoch_end(self, outputs):\n",
" decoded_preds = []\n",
" decoded_labels = []\n",
" for output in outputs:\n",
" logits = output['logits']\n",
" labels = output['labels']\n",
" decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
" decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
" \n",
" scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
" \n",
" self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
" self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
" self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
" \n",
" def configure_optimizers(self):\n",
" optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
" return optimizer\n"
]
},
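{
"cell_type": "code",
"execution_count": null,
"id": "c3d4e5f6",
"metadata": {},
"outputs": [],
"source": [
"# Minimal smoke test of the module (illustrative; the underscore-prefixed\n",
"# names are throwaway): one random dummy batch through forward() should\n",
"# yield a scalar loss and logits of shape (batch, seq_len, vocab_size).\n",
"_module = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=2)\n",
"_ids = torch.randint(0, _module.tokenizer.vocab_size, (2, 16))\n",
"_mask = torch.ones_like(_ids)\n",
"with torch.no_grad():\n",
"    _loss, _logits = _module(_ids, _mask, labels=_ids)\n",
"print(_loss.item(), _logits.shape)"
]
},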
{
"cell_type": "code",
"execution_count": 8,
"id": "dd63c628",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"\n",
" 0%| | 0/1795 [00:00<?, ?ba/s]\u001b[A\n",
" 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
" 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
" 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
" 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
" 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
" 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
" 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
" 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
" 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
" 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
" 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
" 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
" 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
" 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
" 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
" 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
" 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
" 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
" 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
" 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
" 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
" 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
" 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
" 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
" 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
" 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
" 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
" 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
" 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
" 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
" 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
" 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
" 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
" 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
" 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
" 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
" 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
" 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
" 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
" 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
" 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
" 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
" 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
" 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
" 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
" 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
" 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
" 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
" 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
" 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
" 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
" 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
" 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
" 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
" 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
" 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
" 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
" 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
" 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
" 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
" 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
" 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
" 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
" 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
" 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
" 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
" 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
" 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
" 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
" 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
" 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
" 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
" 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
" 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
" 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
" 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
" 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
" 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
" 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
" 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
" 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
" 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
" 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
" 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
" 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
" 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
" 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
" 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
" 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
" 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
" 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
" 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
" 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
" 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
" 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
" 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
" 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
" 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
" 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
" 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
" 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
" 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
" 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
" 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
" 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
" 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
" 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
" 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
" 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
" 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
" 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
" 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
" 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
" 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
" 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
" 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
" 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
" 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
" 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
" 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
" 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
" 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
" 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
" 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
" 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
" 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
" 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
" 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
" 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
" 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
" 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
" 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
" 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
" 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
" 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
" 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
" 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
" 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
" 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
" 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
" 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
"\n",
" 0%| | 0/84 [00:00<?, ?ba/s]\u001b[A\n",
" 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
" 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
" 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
" 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
" 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
" 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
"\n",
" | Name | Type | Params\n",
"-----------------------------------------------------\n",
"0 | model | T5ForConditionalGeneration | 60.5 M\n",
"-----------------------------------------------------\n",
"60.5 M Trainable params\n",
"0 Non-trainable params\n",
"60.5 M Total params\n",
"242.026 Total estimated model params size (MB)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute 'size'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
]
}
],
"source": [
"torch.set_float32_matmul_precision(\"medium\")\n",
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
"dm = MyDataModule(batch_size=16)\n",
"trainer.fit(model, datamodule=dm)"
]
},
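{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative inference sketch (assumes the fit above completed and the\n",
"# cached cnn_dailymail split is available; the generation settings are\n",
"# assumptions, not tuned values): summarize one validation article.\n",
"val_sample = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:1]\")[0]\n",
"model = model.to(\"cpu\")  # device placement after fit can vary; pin to CPU for this check\n",
"inputs = tokenizer(val_sample[\"article\"], return_tensors=\"pt\", truncation=True, max_length=512)\n",
"summary_ids = model.model.generate(\n",
"    inputs.input_ids,\n",
"    attention_mask=inputs.attention_mask,\n",
"    max_length=128,\n",
"    num_beams=4,\n",
")\n",
"print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))"
]
}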
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}