diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/Untitled-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/Untitled-checkpoint.ipynb
deleted file mode 100644
index 363fcab..0000000
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/Untitled-checkpoint.ipynb
+++ /dev/null
@@ -1,6 +0,0 @@
-{
- "cells": [],
- "metadata": {},
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/cnndaily_t5_lightning_customdataloading-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/cnndaily_t5_lightning_customdataloading-checkpoint.ipynb
deleted file mode 100644
index a3216e9..0000000
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/cnndaily_t5_lightning_customdataloading-checkpoint.ipynb
+++ /dev/null
@@ -1,317 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f54ecf0b",
- "metadata": {},
- "outputs": [],
- "source": [
- "\"\"\"\n",
- "# HuggingFace Tutorial Series\n",
- "- 1. What is Huggingface?\n",
- "- 2. Common tasks we can do with HuggingFace & explain the tasks briefly, like what is question answering etc\n",
- "- 3. Using the HuggingFace Pipeline (High level feature)\n",
- "- 4. How the pipeline works at a lower level\n",
- "- 5. HuggingFace Datasets\n",
- "- 6. HuggingFace Tokenizer\n",
- "- 7. HuggingFace Evaluate\n",
- "- 8. HuggingFace Trainer\n",
- "- 9. Putting it together to finetune a news article summarizer\n",
- "- 10. Making it more general and robust with Lightning and custom data loading\n",
- "\"\"\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ec1aae37",
- "metadata": {},
- "outputs": [],
- "source": [
- "import warnings\n",
- "warnings.simplefilter(\"ignore\")\n",
- "\n",
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "import datasets \n",
- "import pytorch_lightning as pl\n",
- "from datasets import load_dataset, load_metric\n",
- "\n",
- "from transformers import (\n",
- " AutoModel,\n",
- " AutoModelForSeq2SeqLM,\n",
- " AutoTokenizer,\n",
- " DataCollatorForSeq2Seq,\n",
- " Seq2SeqTrainingArguments,\n",
- " Seq2SeqTrainer,\n",
- ")\n",
- "\n",
- "import torch\n",
- "import pandas as pd\n",
- "from torch.utils.data import Dataset\n",
- "import pytorch_lightning as pl\n",
- "\n",
- "torch.set_float32_matmul_precision(\"medium\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5fd7cb0c",
- "metadata": {},
- "outputs": [],
- "source": [
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "418cb03a",
- "metadata": {},
- "outputs": [],
- "source": [
- "class cnn_dailymail(Dataset):\n",
- " def __init__(self, csv_file, tokenizer, max_length=512):\n",
- " self.data = pd.read_csv(csv_file)\n",
- " self.tokenizer = tokenizer\n",
- " self.max_length = max_length\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.data)\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " article = self.data.loc[idx, 'article']\n",
- " highlights = self.data.loc[idx, 'highlights']\n",
- "\n",
- " inputs = self.tokenizer(\n",
- " article,\n",
- " truncation=True,\n",
- " padding='max_length',\n",
- " max_length=self.max_length,\n",
- " return_tensors='pt'\n",
- " )\n",
- " targets = self.tokenizer(\n",
- " highlights,\n",
- " truncation=True,\n",
- " padding='max_length',\n",
- " max_length=self.max_length,\n",
- " return_tensors='pt'\n",
- " )\n",
- "\n",
- " return {\n",
- " 'input_ids': inputs['input_ids'].squeeze(),\n",
- " 'attention_mask': inputs['attention_mask'].squeeze(),\n",
- " 'labels': targets['input_ids'].squeeze()\n",
- " }"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "aaa62755",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyDataModule(pl.LightningDataModule):\n",
- " def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):\n",
- " super().__init__()\n",
- " self.train_csv = train_csv\n",
- " self.val_csv = val_csv\n",
- " self.test_csv = test_csv\n",
- " self.tokenizer = tokenizer\n",
- " self.batch_size = batch_size\n",
- " self.max_length = max_length\n",
- "\n",
- " def setup(self, stage=None):\n",
- " if stage in ('fit', None):\n",
- " self.train_dataset = cnn_dailymail(self.train_csv, self.tokenizer, self.max_length)\n",
- " self.val_dataset = cnn_dailymail(self.val_csv, self.tokenizer, self.max_length)\n",
- " if stage in ('test', None):\n",
- " self.test_dataset = cnn_dailymail(self.test_csv, self.tokenizer, self.max_length)\n",
- "\n",
- " def train_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n",
- "\n",
- " def test_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fbb699e1",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyLightningModule(pl.LightningModule):\n",
- " def __init__(self, model_name, learning_rate, weight_decay):\n",
- " super().__init__()\n",
- " self.model_name = model_name\n",
- " self.learning_rate = learning_rate\n",
- " self.weight_decay = weight_decay\n",
- " \n",
- " # Load the pre-trained model and tokenizer\n",
- " self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
- " \n",
- " # Load the ROUGE metric\n",
- " self.metric = load_metric(\"rouge\")\n",
- "\n",
- " def forward(self, input_ids, attention_mask, labels=None):\n",
- " output = self.model(\n",
- " input_ids=input_ids,\n",
- " attention_mask=attention_mask,\n",
- " labels=labels,\n",
- " )\n",
- " return output.loss, output.logits\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
- " return {'loss': loss, 'logits': logits}\n",
- " \n",
- " def validation_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
- " \n",
- " # Save logits and labels as instance attributes\n",
- " if not hasattr(self, \"logits\"):\n",
- " self.logits = logits\n",
- " else:\n",
- " self.logits = torch.cat((self.logits, logits), dim=0)\n",
- " \n",
- " if not hasattr(self, \"labels\"):\n",
- " self.labels = labels\n",
- " else:\n",
- " self.labels = torch.cat((self.labels, labels), dim=0)\n",
- " \n",
- " return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
- " \n",
- " def on_validation_epoch_end(self):\n",
- " # Convert logits to predicted token IDs\n",
- " pred_token_ids = self.logits.argmax(dim=-1)\n",
- "\n",
- " # Decode predictions and labels using the saved instance attributes\n",
- " decoded_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)\n",
- " decoded_labels = tokenizer.batch_decode(self.labels, skip_special_tokens=True)\n",
- "\n",
- " # Compute ROUGE scores\n",
- " scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- "\n",
- " self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
- " self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
- " self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
- "\n",
- " # Clear logits and labels instance attributes for the next validation epoch\n",
- " del self.logits\n",
- " del self.labels\n",
- " \n",
- " def configure_optimizers(self):\n",
- " optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
- " return optimizer\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "dd63c628",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "# File paths\n",
- "train_csv = \"train.csv\"\n",
- "val_csv = \"validation.csv\"\n",
- "test_csv = \"test.csv\"\n",
- "\n",
- "# Create the data module\n",
- "dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
- "dm.setup()\n",
- "\n",
- "model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-4, weight_decay=1e-5)\n",
- "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
- "trainer.fit(model, datamodule=dm)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b5d3d684",
- "metadata": {},
- "outputs": [],
- "source": [
- "http://localhost:18888/notebooks/cnndaily_t5_lightning_customdataloading.ipynb"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a0494596",
- "metadata": {},
- "source": [
- "### next steps:\n",
- "* if article is > 512, because now we are truncating maybe it causes issues if the article is much longer?\n",
- "\n",
- "#### what we've done:\n",
- "* Change the data loading so it's more general, meaning on the fly loading from disk\n",
- "* add torch.compile\n",
- "* 1. Clean up the code, make it into scripts instead of notebook -> Train for an epoch (add multi-gpu training?)\n",
- "* add tensorboard visualization\n",
- "* not use pretrained weights but from scratch to ensure that training setup works and actually improving\n",
- "* 2. Create an inference step, send in news article -> get summary, check that it works\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80a2efab",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0f9b71ab",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb
deleted file mode 100644
index e3220a5..0000000
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/finetune_t5_lightning-checkpoint.ipynb
+++ /dev/null
@@ -1,463 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "ec1aae37",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
- "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
- ]
- }
- ],
- "source": [
- "import warnings\n",
- "warnings.simplefilter(\"ignore\")\n",
- "\n",
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "import datasets \n",
- "import pytorch_lightning as pl\n",
- "\n",
- "from datasets import load_dataset, load_metric\n",
- "\n",
- "from transformers import (\n",
- " AutoModel,\n",
- " AutoModelForSeq2SeqLM,\n",
- " AutoTokenizer,\n",
- " DataCollatorForSeq2Seq,\n",
- " Seq2SeqTrainingArguments,\n",
- " Seq2SeqTrainer,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "5fd7cb0c",
- "metadata": {},
- "outputs": [],
- "source": [
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "04530b1e",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define the LightningDataModule\n",
- "class MyDataModule(pl.LightningDataModule):\n",
- " def __init__(self, batch_size):\n",
- " super().__init__()\n",
- " self.batch_size = batch_size\n",
- " \n",
- " def prepare_data(self):\n",
- " # Download and preprocess the data\n",
- " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
- " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
- " \n",
- " def setup(self, stage=None):\n",
- " # Load and preprocess the data\n",
- " train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
- " val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
- "\n",
- " self.train_ds = train_data.map(\n",
- " self.preprocess_function, \n",
- " batched=True, \n",
- " batch_size=self.batch_size, \n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- " )\n",
- "\n",
- " self.val_ds = val_data.map(\n",
- " self.preprocess_function, \n",
- " batched=True, \n",
- " batch_size=self.batch_size,\n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- " )\n",
- "\n",
- " def preprocess_function(self, batch):\n",
- " inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
- " outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
- " batch[\"input_ids\"] = inputs.input_ids\n",
- " batch[\"attention_mask\"] = inputs.attention_mask\n",
- " batch[\"labels\"] = outputs.input_ids.copy()\n",
- " return batch\n",
- "\n",
- " def train_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "fbb699e1",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyLightningModule(pl.LightningModule):\n",
- " def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
- " super().__init__()\n",
- " self.model_name = model_name\n",
- " self.learning_rate = learning_rate\n",
- " self.weight_decay = weight_decay\n",
- " self.batch_size = batch_size\n",
- " \n",
- " # Load the pre-trained model and tokenizer\n",
- " self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
- "\n",
- " # Load the ROUGE metric\n",
- " self.metric = load_metric(\"rouge\")\n",
- "\n",
- " def forward(self, input_ids, attention_mask, labels=None):\n",
- " output = self.model(\n",
- " input_ids=input_ids,\n",
- " attention_mask=attention_mask,\n",
- " labels=labels,\n",
- " )\n",
- " return output.loss, output.logits\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
- " return {'loss': loss, 'logits': logits}\n",
- " \n",
- " def validation_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
- " return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
- " \n",
- " def validation_epoch_end(self, outputs):\n",
- " decoded_preds = []\n",
- " decoded_labels = []\n",
- " for output in outputs:\n",
- " logits = output['logits']\n",
- " labels = output['labels']\n",
- " decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
- " decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
- " \n",
- " scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- " \n",
- " self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
- " self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
- " self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
- " \n",
- " def configure_optimizers(self):\n",
- " optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
- " return optimizer\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "dd63c628",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "GPU available: True (cuda), used: True\n",
- "TPU available: False, using: 0 TPU cores\n",
- "IPU available: False, using: 0 IPUs\n",
- "HPU available: False, using: 0 HPUs\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "\n",
- " 0%| | 0/1795 [00:00, ?ba/s]\u001b[A\n",
- " 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
- " 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
- " 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
- " 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
- " 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
- " 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
- " 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
- " 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
- " 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
- " 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
- " 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
- " 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
- " 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
- " 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
- " 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
- " 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
- " 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
- " 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
- " 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
- " 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
- " 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
- " 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
- " 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
- " 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
- " 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
- " 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
- " 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
- " 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
- " 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
- " 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
- " 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
- " 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
- " 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
- " 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
- " 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
- " 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
- " 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
- " 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
- " 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
- " 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
- " 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
- " 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
- " 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
- " 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
- " 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
- " 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
- " 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
- " 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
- " 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
- " 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
- " 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
- " 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
- " 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
- " 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
- " 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
- " 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
- " 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
- " 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
- " 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
- " 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
- " 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
- " 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
- " 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
- " 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
- " 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
- " 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
- " 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
- " 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
- " 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
- " 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
- " 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
- " 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
- " 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
- " 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
- " 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
- " 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
- " 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
- " 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
- " 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
- " 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
- " 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
- " 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
- " 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
- " 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
- " 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
- " 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
- " 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
- " 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
- " 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
- " 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
- " 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
- " 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
- " 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
- " 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
- " 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
- " 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
- " 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
- " 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
- " 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
- " 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
- " 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
- " 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
- " 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
- " 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
- " 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
- " 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
- " 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
- " 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
- " 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
- " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
- " 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
- " 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
- " 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
- " 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
- " 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
- " 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
- " 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
- " 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
- " 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
- " 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
- " 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
- " 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
- " 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
- " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
- " 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
- " 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
- " 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
- " 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
- " 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
- " 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
- " 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
- " 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
- " 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
- " 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
- " 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
- " 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
- " 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
- " 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
- " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
- " 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
- " 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
- " 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
- " 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
- " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
- " 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
- " 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
- " 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
- " 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
- " 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
- " 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
- " 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
- "\n",
- " 0%| | 0/84 [00:00, ?ba/s]\u001b[A\n",
- " 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
- " 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
- " 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
- " 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
- " 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
- " 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
- "100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
- "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
- "\n",
- " | Name | Type | Params\n",
- "-----------------------------------------------------\n",
- "0 | model | T5ForConditionalGeneration | 60.5 M\n",
- "-----------------------------------------------------\n",
- "60.5 M Trainable params\n",
- "0 Non-trainable params\n",
- "60.5 M Total params\n",
- "242.026 Total estimated model params size (MB)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00, ?it/s]"
- ]
- },
- {
- "ename": "AttributeError",
- "evalue": "'list' object has no attribute 'size'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
- "Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
- "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
- ]
- }
- ],
- "source": [
- "torch.set_float32_matmul_precision(\"medium\")\n",
- "model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
- "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
- "dm = MyDataModule(batch_size=16)\n",
- "trainer.fit(model, datamodule=dm)"
- ]
- },
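- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0d3b1f2a",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Hedged debugging sketch (not part of the original run): the AttributeError above\n",
- "# ('list' object has no attribute 'size') suggests the dataloader is yielding Python\n",
- "# lists instead of tensors. Inspecting one batch from the assumed MyDataModule makes\n",
- "# that easy to confirm before calling trainer.fit again.\n",
- "dm.setup(\"fit\")\n",
- "batch = next(iter(dm.train_dataloader()))\n",
- "for key, value in batch.items():\n",
- "    print(key, type(value))\n",
- "\n",
- "# If a field comes back as a list of tensors, one option is a collate_fn that stacks it:\n",
- "# def collate_fn(features):\n",
- "#     return {k: torch.stack([f[k] for f in features]) for k in features[0]}"
- ]
- },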
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1395d5d2",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80a2efab",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/finetuning_t5_small_cnndaily-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/finetuning_t5_small_cnndaily-checkpoint.ipynb
deleted file mode 100644
index 8cfe998..0000000
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/finetuning_t5_small_cnndaily-checkpoint.ipynb
+++ /dev/null
@@ -1,3585 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "bd8e3b95",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from jupyterthemes.stylefx import set_nb_theme\n",
- "set_nb_theme('chesterish')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "8c2a24cb",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "f45eb6b0",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "2023-02-21 15:40:52.888700: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
- "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2023-02-21 15:40:53.473104: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 15:40:53.473149: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 15:40:53.473154: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
- ]
- }
- ],
- "source": [
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "import datasets \n",
- "\n",
- "from datasets import load_dataset, load_metric\n",
- "\n",
- "from transformers import (\n",
- " AutoModel,\n",
- " AutoModelForMaskedLM,\n",
- " AutoModelForSeq2SeqLM,\n",
- " AutoModelForTokenClassification,\n",
- " AutoTokenizer,\n",
- " DataCollatorForSeq2Seq,\n",
- " pipeline,\n",
- " Seq2SeqTrainingArguments,\n",
- " Seq2SeqTrainer,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "7fc4eb40",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
- "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
- "- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.\n",
- "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
- "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
- " warnings.warn(\n"
- ]
- }
- ],
- "source": [
- "# Load the pre-trained model and tokenizer\n",
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
- "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
- ]
- },
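- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b7e21c9d",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Optional sketch: the FutureWarning above can be silenced by setting model_max_length\n",
- "# explicitly when loading the tokenizer. The value 512 is an assumption that simply\n",
- "# matches the max_length used during preprocessing below.\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=512)"
- ]
- },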
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "363045f5",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1122/1122 [02:06<00:00, 8.88ba/s]\n",
- "Loading cached processed dataset at /home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de/cache-2d3b7edd75fb1188.arrow\n"
- ]
- }
- ],
- "source": [
- "def preprocess_function(batch):\n",
- " inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
- " outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
- " batch[\"input_ids\"] = inputs.input_ids\n",
- " batch[\"attention_mask\"] = inputs.attention_mask\n",
- " batch[\"labels\"] = outputs.input_ids.copy()\n",
- " return batch\n",
- "\n",
- "# Load the dataset\n",
- "train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
- "val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
- "\n",
- "train_ds = train_data.map(\n",
- " preprocess_function, \n",
- " batched=True, \n",
- " batch_size=256, \n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- ")\n",
- "\n",
- "val_ds = val_data.map(\n",
- " preprocess_function, \n",
- " batched=True, \n",
- " batch_size=256, \n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "6faa8c86",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/tmp/ipykernel_478601/1088570042.py:23: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
- " metric = load_metric(\"rouge\")\n",
- "max_steps is given, it will override any value given in num_train_epochs\n",
- "Using cuda_amp half precision backend\n",
- "The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: id, article, highlights. If id, article, highlights are not expected by `T5ForConditionalGeneration.forward`, you can safely ignore this message.\n",
- "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
- " warnings.warn(\n",
- "***** Running training *****\n",
- " Num examples = 0\n",
- " Num Epochs = 1\n",
- " Instantaneous batch size per device = 16\n",
- " Total train batch size (w. parallel, distributed & accumulation) = 16\n",
- " Gradient Accumulation steps = 1\n",
- " Total optimization steps = 5000\n",
- " Number of trainable parameters = 60506624\n"
- ]
- },
- {
- "ename": "IndexError",
- "evalue": "Invalid key: 90427 is out of bounds for size 0",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[6], line 47\u001b[0m\n\u001b[1;32m 36\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 37\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 38\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 43\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics,\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# Start the training\u001b[39;00m\n\u001b[0;32m---> 47\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m 1536\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m 1538\u001b[0m )\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/trainer.py:1761\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1758\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_load_rng_state(resume_from_checkpoint)\n\u001b[1;32m 1760\u001b[0m step \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1761\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, inputs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(epoch_iterator):\n\u001b[1;32m 1762\u001b[0m \n\u001b[1;32m 1763\u001b[0m \u001b[38;5;66;03m# Skip past any already trained steps if resuming training\u001b[39;00m\n\u001b[1;32m 1764\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m steps_trained_in_current_epoch \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1765\u001b[0m steps_trained_in_current_epoch \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 56\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 56\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/arrow_dataset.py:2601\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 2599\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key): \u001b[38;5;66;03m# noqa: F811\u001b[39;00m\n\u001b[1;32m 2600\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).\"\"\"\u001b[39;00m\n\u001b[0;32m-> 2601\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2602\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2603\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/arrow_dataset.py:2585\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, **kwargs)\u001b[0m\n\u001b[1;32m 2583\u001b[0m format_kwargs \u001b[38;5;241m=\u001b[39m format_kwargs \u001b[38;5;28;01mif\u001b[39;00m format_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m {}\n\u001b[1;32m 2584\u001b[0m formatter \u001b[38;5;241m=\u001b[39m get_formatter(format_type, features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mformat_kwargs)\n\u001b[0;32m-> 2585\u001b[0m pa_subtable \u001b[38;5;241m=\u001b[39m \u001b[43mquery_table\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 2586\u001b[0m formatted_output \u001b[38;5;241m=\u001b[39m format_table(\n\u001b[1;32m 2587\u001b[0m pa_subtable, key, formatter\u001b[38;5;241m=\u001b[39mformatter, format_columns\u001b[38;5;241m=\u001b[39mformat_columns, output_all_columns\u001b[38;5;241m=\u001b[39moutput_all_columns\n\u001b[1;32m 2588\u001b[0m )\n\u001b[1;32m 2589\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m formatted_output\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/formatting/formatting.py:588\u001b[0m, in \u001b[0;36mquery_table\u001b[0;34m(table, key, indices)\u001b[0m\n\u001b[1;32m 586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 587\u001b[0m size \u001b[38;5;241m=\u001b[39m indices\u001b[38;5;241m.\u001b[39mnum_rows \u001b[38;5;28;01mif\u001b[39;00m indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m table\u001b[38;5;241m.\u001b[39mnum_rows\n\u001b[0;32m--> 588\u001b[0m \u001b[43m_check_valid_index_key\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[38;5;66;03m# Query the main table\u001b[39;00m\n\u001b[1;32m 590\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/formatting/formatting.py:531\u001b[0m, in \u001b[0;36m_check_valid_index_key\u001b[0;34m(key, size)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (key \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m key \u001b[38;5;241m+\u001b[39m size \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (key \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m size):\n\u001b[0;32m--> 531\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid key: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is out of bounds for size \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msize\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mslice\u001b[39m):\n",
- "\u001b[0;31mIndexError\u001b[0m: Invalid key: 90427 is out of bounds for size 0"
- ]
- }
- ],
- "source": [
- "class MyLightningModule(pl.LightningModule):\n",
- " def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
- " super().__init__()\n",
- " self.model_name = model_name\n",
- " self.learning_rate = learning_rate\n",
- " self.weight_decay = weight_decay\n",
- " self.batch_size = batch_size\n",
- " self.num_training_steps = num_training_steps\n",
- " \n",
- " # Load the pre-trained model and tokenizer\n",
- " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
- " self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
- "\n",
- " def forward(self, input_ids, attention_mask, labels=None):\n",
- " output = self.model(\n",
- " input_ids=input_ids,\n",
- " attention_mask=attention_mask,\n",
- " labels=labels,\n",
- " )\n",
- " return output.loss, output.logits\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " \n",
- " loss\n",
- "\n",
- "# Define the data collator\n",
- "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
- "\n",
- "# Initialize the trainer arguments\n",
- "training_args = Seq2SeqTrainingArguments(\n",
- " output_dir=\"./results\",\n",
- " learning_rate=1e-5,\n",
- " per_device_train_batch_size=16,\n",
- " per_device_eval_batch_size=16,\n",
- " max_steps=5000,\n",
- " weight_decay=1e-4,\n",
- " push_to_hub=False,\n",
- " evaluation_strategy = \"steps\",\n",
- " eval_steps = 50,\n",
- " generation_max_length=128,\n",
- " predict_with_generate=True,\n",
- " logging_steps=100,\n",
- " gradient_accumulation_steps=1,\n",
- " fp16=True,\n",
- ")\n",
- "\n",
- "# Load the ROUGE metric\n",
- "metric = load_metric(\"rouge\")\n",
- "\n",
- "# Define the evaluation function\n",
- "def compute_metrics(pred):\n",
- " labels = pred.label_ids\n",
- " preds = pred.predictions\n",
- " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
- " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
- " scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- " return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
- "\n",
- "\n",
- "# Initialize the trainer\n",
- "trainer = Seq2SeqTrainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=train_data,\n",
- " eval_dataset=val_data,\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")\n",
- "\n",
- "# Start the training\n",
- "trainer.train()"
- ]
- },
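- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c4f8a7e1",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Sketch of the replacement suggested by the FutureWarning above: datasets.load_metric\n",
- "# is deprecated in favour of the evaluate library. Assumes `evaluate` and `rouge_score`\n",
- "# are installed; this is only an illustration, not part of the training run.\n",
- "import evaluate\n",
- "\n",
- "rouge = evaluate.load(\"rouge\")\n",
- "print(rouge.compute(predictions=[\"the cat sat on the mat\"], references=[\"a cat sat on the mat\"]))"
- ]
- },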
- {
- "cell_type": "markdown",
- "id": "1b0f9a76",
- "metadata": {},
- "source": [
- "# Steps:\n",
- "1. Rewrite code to be more general\n",
- "\n",
- "a) Data loading should be from disk rather than their load_dataset, and should be on the fly\n",
- "\n",
- "b) Rewrite to Lightning code, Trainer etc using Lightning, compute metric fine that we use huggingface"
- ]
- },
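- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d2a9c5b3",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Rough sketch for steps 1a/1b above, not a finished implementation: a LightningDataModule\n",
- "# wrapping disk-backed datasets. The csv paths and the cnn_dailymail Dataset class are\n",
- "# assumptions about how the data gets exported to disk; adjust both to the real layout.\n",
- "from torch.utils.data import DataLoader\n",
- "import pytorch_lightning as pl\n",
- "\n",
- "class MyDataModule(pl.LightningDataModule):\n",
- "    def __init__(self, batch_size=16):\n",
- "        super().__init__()\n",
- "        self.batch_size = batch_size\n",
- "\n",
- "    def setup(self, stage=None):\n",
- "        # assumption: train.csv / validation.csv were written to disk beforehand\n",
- "        self.train_ds = cnn_dailymail(\"train.csv\", tokenizer)\n",
- "        self.val_ds = cnn_dailymail(\"validation.csv\", tokenizer)\n",
- "\n",
- "    def train_dataloader(self):\n",
- "        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
- "\n",
- "    def val_dataloader(self):\n",
- "        return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=4)"
- ]
- },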
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ff03c8bb",
- "metadata": {},
- "outputs": [],
- "source": [
- "!nvidia-smi"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "aafc4b27",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb b/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb
deleted file mode 100644
index c821b42..0000000
--- a/ML/Pytorch/huggingface/.ipynb_checkpoints/learning-checkpoint.ipynb
+++ /dev/null
@@ -1,644 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "7d5e92c6",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
- ]
- }
- ],
- "source": [
- "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
- "from transformers import pipeline\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
- "model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
- "\n",
- "pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
- "example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
- "\n",
- "ner_entity_results = pipe(example)\n",
- "print(ner_entity_results)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "id": "bf67ee76",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Turtle s Original Cara mel Chocolate P eca n luster\n"
- ]
- }
- ],
- "source": [
- "ner_entity_results = pipe(example)\n",
- "\n",
- "# Initialize the entity words list with an empty string\n",
- "entity_words = [\"\"]\n",
- "\n",
- "# Loop through each dictionary in the list and extract the entity word\n",
- "for result in ner_entity_results:\n",
- " if result[\"entity\"] == \"B-FOOD\":\n",
- " entity_words.append(result[\"word\"])\n",
- " elif result[\"entity\"] == \"I-FOOD\":\n",
- " entity_words[-1] += \" \" + result[\"word\"]\n",
- "\n",
- "# Remove any remaining ## symbols and extra spaces\n",
- "entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
- "\n",
- "# Join the entity words into a single string\n",
- "output = \" \".join(entity_words)\n",
- "\n",
- "print(output)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fc8e5ea0",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "print(torch.cuda.is_available())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d8a1e039",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import pipeline\n",
- "import numpy as np"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6ad73024",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "classifier = pipeline(\"zero-shot-classification\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "04f7e02c",
- "metadata": {},
- "outputs": [],
- "source": [
- "classifier(\n",
- " \"This is a course about the Transformers library\",\n",
- " candidate_labels=[\"machine learning\", \"gym\", \"food\"],\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6fb246c2",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "from transformers import pipeline\n",
- "generator = pipeline(task=\"text-generation\", model=\"bigscience/bloom-1b7\", device=0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c4e174f0",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer\n",
- "import torch\n",
- "\n",
- "# Define input text and pre-trained model checkpoint\n",
- "text = \"My name is wolfgang and I live in berlin\"\n",
- "checkpoint = \"Jean-Baptiste/roberta-large-ner-english\"\n",
- "\n",
- "# Instantiate tokenizer and encode input text\n",
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
- "inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
- "\n",
- "# Instantiate model and generate output\n",
- "model = AutoModel.from_pretrained(checkpoint)\n",
- "outputs = model(**inputs)\n",
- "print(outputs[0].shape)\n",
- "\n",
- "# Instantiate token classification model and generate predictions\n",
- "model = AutoModelForTokenClassification.from_pretrained(checkpoint)\n",
- "outputs = model(**inputs)\n",
- "predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
- "print(predictions)\n",
- "print(model.config.id2label)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8212bbaa",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
- "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
- "\n",
- "# prepare input\n",
- "text = \"Replace me by any text you'd like.\"\n",
- "encoded_input = tokenizer(text, return_tensors='pt')\n",
- "\n",
- "# forward pass\n",
- "output = model(**encoded_input)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "314cba41",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
- "\n",
- "# Load the pre-trained tokenizer and model\n",
- "tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
- "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
- "\n",
- "# Define the input sentence with a masked token\n",
- "text = \"I want to a new car tomorrow.\"\n",
- "\n",
- "# Tokenize the input sentence, replacing the masked token with a special [MASK] token\n",
- "encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')\n",
- "\n",
- "print(output.logits.shape)\n",
- "print(encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id))\n",
- "\n",
- "# Extract the predicted probabilities for the masked token\n",
- "predicted_probabilities = output.logits[0, encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id)]\n",
- "predicted_probabilities = torch.nn.functional.softmax(predicted_probabilities, dim=-1)\n",
- "\n",
- "# Get the top-k most probable predictions for the masked token\n",
- "k = 5\n",
- "top_k = torch.topk(predicted_probabilities, k)\n",
- "for i in range(k):\n",
- " token = tokenizer.convert_ids_to_tokens(top_k.indices[i].item())\n",
- " score = top_k.values[i].item()\n",
- " print(f\"Prediction {i+1}: '{token}' with probability {score:.5f}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6187e77e",
- "metadata": {},
- "outputs": [],
- "source": [
- "%%time\n",
- "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
- "\n",
- "sequences = [\n",
- " \"Using a Transformer network is simple\",\n",
- " \"The quick brown fox jumps over the lazy dog\",\n",
- " \"To be or not to be, that is the question\"\n",
- "]\n",
- "\n",
- "# Tokenize the input sequences and convert them to padded and truncated integer token IDs\n",
- "inputs = tokenizer(\n",
- " sequences,\n",
- " padding=True,\n",
- " truncation=True,\n",
- " return_tensors=\"pt\"\n",
- ")\n",
- "\n",
- "# Print the resulting input IDs and attention masks\n",
- "print(inputs['input_ids'])\n",
- "print(inputs['attention_mask'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fc259c5a",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "markdown",
- "id": "43466db6",
- "metadata": {},
- "source": [
- "Huggingface:\n",
- "\n",
- "1. Understanding how to use the Pipeline (probably most useful) for various tasks, easy to use, and the different subtasks it can do like translation, QA, zero shot, sentiment analysis, token classification, etc. \n",
- "2. Understood how pipeline works in more detail by using AutoModel for various tasks as well as AutoTokenizer\n",
- "3. Load dataset\n",
- "4. How to finetune\n",
- "5. How to evaluate\n",
- "6. "
- ]
- },
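- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e5b3d8f7",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Minimal illustration of point 1 above: the high-level pipeline API with a default\n",
- "# checkpoint. Which model gets downloaded is chosen by HuggingFace, so treat the exact\n",
- "# scores as an example rather than a guarantee.\n",
- "from transformers import pipeline\n",
- "\n",
- "sentiment = pipeline(\"sentiment-analysis\")\n",
- "print(sentiment(\"HuggingFace pipelines make quick experiments easy.\"))"
- ]
- },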
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97c474f2",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3ed5d8c2",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification\n",
- "\n",
- "# Same as before\n",
- "checkpoint = \"bert-base-uncased\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
- "sequences = [\n",
- " \"I've been waiting for a HuggingFace course my whole life.\",\n",
- " \"This course is amazing!\",\n",
- "]\n",
- "batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
- "\n",
- "# This is new\n",
- "batch[\"labels\"] = torch.tensor([1, 1])\n",
- "\n",
- "optimizer = AdamW(model.parameters())\n",
- "loss = model(**batch).loss\n",
- "loss.backward()\n",
- "optimizer.step()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c598624f",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "raw_datasets = load_dataset(\"glue\", \"mrpc\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cd296227",
- "metadata": {},
- "outputs": [],
- "source": [
- "raw_train_dataset = raw_datasets[\"train\"]\n",
- "raw_train_dataset[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e462947a",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
- "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
- "\n",
- "checkpoint = \"bert-base-uncased\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
- "\n",
- "def tokenize_function(example):\n",
- " return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
- "\n",
- "\n",
- "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
- "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
- "\n",
- "\n",
- "from transformers import TrainingArguments\n",
- "training_args = TrainingArguments(\"test-trainer\")\n",
- "\n",
- "from transformers import AutoModelForSequenceClassification\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
- "\n",
- "import numpy as np\n",
- "import evaluate\n",
- "\n",
- "def compute_metrics(eval_preds):\n",
- " metric = evaluate.load(\"glue\", \"mrpc\")\n",
- " logits, labels = eval_preds\n",
- " predictions = np.argmax(logits, axis=-1)\n",
- " return metric.compute(predictions=predictions, references=labels)\n",
- "\n",
- "training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
- "\n",
- "trainer = Trainer(\n",
- " model,\n",
- " training_args,\n",
- " train_dataset=tokenized_datasets[\"train\"],\n",
- " eval_dataset=tokenized_datasets[\"validation\"],\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0e2795dc",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import TrainingArguments\n",
- "training_args = TrainingArguments(\"test-trainer\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3af29cd5",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoModelForSequenceClassification\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "817f644e",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import evaluate"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "42819a6c",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "def compute_metrics(eval_preds):\n",
- " metric = evaluate.load(\"glue\", \"mrpc\")\n",
- " logits, labels = eval_preds\n",
- " predictions = np.argmax(logits, axis=-1)\n",
- " return metric.compute(predictions=predictions, references=labels)\n",
- "\n",
- "training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
- "\n",
- "trainer = Trainer(\n",
- " model,\n",
- " training_args,\n",
- " train_dataset=tokenized_datasets[\"train\"],\n",
- " eval_dataset=tokenized_datasets[\"validation\"],\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "eb5986b0",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
- "from datasets import load_dataset\n",
- "batch_size=32\n",
- "\n",
- "# Define the generator function to preprocess the data in batches\n",
- "def preprocess_generator(examples):\n",
- " for i in range(0, len(examples[\"article\"]), batch_size):\n",
- " batch = examples[\"article\"][i:i+batch_size]\n",
- " targets = examples[\"highlights\"][i:i+batch_size]\n",
- " model_inputs = tokenizer(batch, max_length=512, padding=\"max_length\", truncation=True)\n",
- " with tokenizer.as_target_tokenizer():\n",
- " model_targets = tokenizer(targets, max_length=128, padding=\"max_length\", truncation=True)\n",
- " model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
- " yield model_inputs\n",
- "\n",
- "def preprocess_function(examples):\n",
- " articles = [ex for ex in examples[\"article\"]]\n",
- " summaries = [ex for ex in examples[\"highlights\"]]\n",
- "\n",
- " model_inputs = tokenizer(articles, max_length=512, padding=\"max_length\", truncation=True)\n",
- " with tokenizer.as_target_tokenizer():\n",
- " model_targets = tokenizer(summaries, max_length=128, padding=\"max_length\", truncation=True)\n",
- " \n",
- " model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
- " return model_inputs\n",
- " \n",
- "# Load the dataset\n",
- "raw_datasets = load_dataset(\"cnn_dailymail\", \"3.0.0\")\n",
- "preprocessed_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4)\n",
- "\n",
- "# Load the pre-trained model and tokenizer\n",
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
- "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
- "\n",
- "# Define the data collator\n",
- "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
- "\n",
- "# Initialize the trainer arguments\n",
- "training_args = Seq2SeqTrainingArguments(\n",
- " output_dir=\"./results\",\n",
- " evaluation_strategy = \"epoch\",\n",
- " learning_rate=2e-5,\n",
- " per_device_train_batch_size=batch_size,\n",
- " max_steps=1000,\n",
- " weight_decay=0.01,\n",
- " push_to_hub=False,\n",
- ")\n",
- "\n",
- "# Initialize the trainer\n",
- "trainer = Seq2SeqTrainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=train_ds,\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- ")\n",
- "\n",
- "# Start the training\n",
- "trainer.train()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7d62583e",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_metric"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d310a7b3",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "preprocessed_datasets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "99d422cc",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "# Load the pre-trained model and tokenizer\n",
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
- "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
- "\n",
- "# Define the data collator\n",
- "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
- "\n",
- "# Initialize the trainer arguments\n",
- "training_args = Seq2SeqTrainingArguments(\n",
- " output_dir=\"./results\",\n",
- " learning_rate=2e-5,\n",
- " per_device_train_batch_size=batch_size,\n",
- " max_steps=5000,\n",
- " weight_decay=0.01,\n",
- " push_to_hub=False,\n",
- "    evaluation_strategy=\"steps\",\n",
- "    eval_steps=50,\n",
- "    predict_with_generate=True,  # so compute_metrics receives generated token ids instead of raw logits\n",
- ")\n",
- "\n",
- "# Load the ROUGE metric\n",
- "metric = load_metric(\"rouge\")\n",
- "\n",
- "# Define the evaluation function\n",
- "def compute_metrics(pred):\n",
- " labels = pred.label_ids\n",
- " preds = pred.predictions\n",
- " \n",
- " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
- " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
- " \n",
- " scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- " \n",
- " return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
- "\n",
- "\n",
- "# Initialize the trainer\n",
- "trainer = Seq2SeqTrainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=preprocessed_datasets[\"train\"],\n",
- " eval_dataset=preprocessed_datasets[\"validation\"],\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")\n",
- "\n",
- "# Start the training\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a5e97b57",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install nltk\n",
- "!pip install rouge_score"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "558c3e66",
- "metadata": {},
- "source": [
- "# Goal:\n",
- "\n",
- "1. Implement the full training pipeline (data loading for the CNN/DailyMail dataset, model training, evaluation, etc.) using HF.\n",
- "* Right now: stuck on on-the-fly dataset loading; we don't want to cache the tokenized dataset because it would take a lot of disk space (see the streaming sketch further down).\n",
- "\n",
- "2. Once step 1 works, go deeper on every step: download the dataset and load it as a custom dataset rather than through the simple HuggingFace API, to make it more general. Compare loading it as a custom HF dataset vs. a PyTorch Dataset class together with Lightning: speed difference? convenience? Also use the Lightning Trainer and see how to integrate it, then compare the pure HF approach with the Lightning + HF model approach and decide which we like the most."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "624d49ca",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
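The Goal cell above mentions being stuck on on-the-fly data loading, because caching a tokenized copy of cnn_dailymail takes a lot of disk space. A minimal sketch of one way to do this with the datasets streaming mode follows; it is not taken from the deleted notebooks, it simply reuses the cnn_dailymail column names and the 512/128 max lengths used above.

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

def tokenize(batch):
    # Same shapes as the notebooks: 512-token articles, 128-token highlights.
    model_inputs = tokenizer(batch["article"], max_length=512, padding="max_length", truncation=True)
    with tokenizer.as_target_tokenizer():
        targets = tokenizer(batch["highlights"], max_length=128, padding="max_length", truncation=True)
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs

# streaming=True returns an IterableDataset: examples are downloaded and
# tokenized lazily as they are consumed, so no tokenized cache is written to disk.
stream = load_dataset("cnn_dailymail", "3.0.0", split="train", streaming=True)
stream = stream.map(tokenize, batched=True)

example = next(iter(stream))  # the raw text columns are still present alongside the token ids
print(len(example["input_ids"]), len(example["labels"]))  # 512 128

The streamed dataset can then be fed to a DataLoader with a collator that stacks the lists into tensors; the trade-off is that only buffer-based shuffling is available instead of a full random shuffle of the split.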
diff --git a/ML/Pytorch/huggingface/cnndaily_t5_lightning_customdataloading.ipynb b/ML/Pytorch/huggingface/cnndaily_t5_lightning_customdataloading.ipynb
deleted file mode 100644
index a3216e9..0000000
--- a/ML/Pytorch/huggingface/cnndaily_t5_lightning_customdataloading.ipynb
+++ /dev/null
@@ -1,317 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f54ecf0b",
- "metadata": {},
- "outputs": [],
- "source": [
- "\"\"\"\n",
- "# HuggingFace Tutorial Series\n",
- "- 1. What is Huggingface?\n",
- "- 2. Common tasks we can do with HuggingFace & explain the tasks briefly, like what is question answering etc\n",
- "- 3. Using the HuggingFace Pipeline (High level feature)\n",
- "- 4. How the pipeline works at a lower level\n",
- "- 5. HuggingFace Datasets\n",
- "- 6. HuggingFace Tokenizer\n",
- "- 7. HuggingFace Evaluate\n",
- "- 8. HuggingFace Trainer\n",
- "- 9. Putting it together to finetune a news article summarizer\n",
- "- 10. Making it more general and robust with Lightning and custom data loading\n",
- "\"\"\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "ec1aae37",
- "metadata": {},
- "outputs": [],
- "source": [
- "import warnings\n",
- "warnings.simplefilter(\"ignore\")\n",
- "\n",
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "import datasets \n",
- "import pytorch_lightning as pl\n",
- "from datasets import load_dataset, load_metric\n",
- "\n",
- "from transformers import (\n",
- " AutoModel,\n",
- " AutoModelForSeq2SeqLM,\n",
- " AutoTokenizer,\n",
- " DataCollatorForSeq2Seq,\n",
- " Seq2SeqTrainingArguments,\n",
- " Seq2SeqTrainer,\n",
- ")\n",
- "\n",
- "import torch\n",
- "import pandas as pd\n",
- "from torch.utils.data import Dataset\n",
- "import pytorch_lightning as pl\n",
- "\n",
- "torch.set_float32_matmul_precision(\"medium\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5fd7cb0c",
- "metadata": {},
- "outputs": [],
- "source": [
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "418cb03a",
- "metadata": {},
- "outputs": [],
- "source": [
- "class cnn_dailymail(Dataset):\n",
- " def __init__(self, csv_file, tokenizer, max_length=512):\n",
- " self.data = pd.read_csv(csv_file)\n",
- " self.tokenizer = tokenizer\n",
- " self.max_length = max_length\n",
- "\n",
- " def __len__(self):\n",
- " return len(self.data)\n",
- "\n",
- " def __getitem__(self, idx):\n",
- " article = self.data.loc[idx, 'article']\n",
- " highlights = self.data.loc[idx, 'highlights']\n",
- "\n",
- " inputs = self.tokenizer(\n",
- " article,\n",
- " truncation=True,\n",
- " padding='max_length',\n",
- " max_length=self.max_length,\n",
- " return_tensors='pt'\n",
- " )\n",
- " targets = self.tokenizer(\n",
- " highlights,\n",
- " truncation=True,\n",
- " padding='max_length',\n",
- " max_length=self.max_length,\n",
- " return_tensors='pt'\n",
- " )\n",
- "\n",
- " return {\n",
- " 'input_ids': inputs['input_ids'].squeeze(),\n",
- " 'attention_mask': inputs['attention_mask'].squeeze(),\n",
- " 'labels': targets['input_ids'].squeeze()\n",
- " }"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "aaa62755",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyDataModule(pl.LightningDataModule):\n",
- " def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):\n",
- " super().__init__()\n",
- " self.train_csv = train_csv\n",
- " self.val_csv = val_csv\n",
- " self.test_csv = test_csv\n",
- " self.tokenizer = tokenizer\n",
- " self.batch_size = batch_size\n",
- " self.max_length = max_length\n",
- "\n",
- " def setup(self, stage=None):\n",
- " if stage in ('fit', None):\n",
- " self.train_dataset = cnn_dailymail(self.train_csv, self.tokenizer, self.max_length)\n",
- " self.val_dataset = cnn_dailymail(self.val_csv, self.tokenizer, self.max_length)\n",
- " if stage in ('test', None):\n",
- " self.test_dataset = cnn_dailymail(self.test_csv, self.tokenizer, self.max_length)\n",
- "\n",
- " def train_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n",
- "\n",
- " def test_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fbb699e1",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyLightningModule(pl.LightningModule):\n",
- " def __init__(self, model_name, learning_rate, weight_decay):\n",
- " super().__init__()\n",
- " self.model_name = model_name\n",
- " self.learning_rate = learning_rate\n",
- " self.weight_decay = weight_decay\n",
- " \n",
- " # Load the pre-trained model and tokenizer\n",
- " self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
- " \n",
- " # Load the ROUGE metric\n",
- " self.metric = load_metric(\"rouge\")\n",
- "\n",
- " def forward(self, input_ids, attention_mask, labels=None):\n",
- " output = self.model(\n",
- " input_ids=input_ids,\n",
- " attention_mask=attention_mask,\n",
- " labels=labels,\n",
- " )\n",
- " return output.loss, output.logits\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
- " return {'loss': loss, 'logits': logits}\n",
- " \n",
- " def validation_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
- " \n",
- " # Save logits and labels as instance attributes\n",
- " if not hasattr(self, \"logits\"):\n",
- " self.logits = logits\n",
- " else:\n",
- " self.logits = torch.cat((self.logits, logits), dim=0)\n",
- " \n",
- " if not hasattr(self, \"labels\"):\n",
- " self.labels = labels\n",
- " else:\n",
- " self.labels = torch.cat((self.labels, labels), dim=0)\n",
- " \n",
- " return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
- " \n",
- " def on_validation_epoch_end(self):\n",
- " # Convert logits to predicted token IDs\n",
- " pred_token_ids = self.logits.argmax(dim=-1)\n",
- "\n",
- " # Decode predictions and labels using the saved instance attributes\n",
- " decoded_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)\n",
- " decoded_labels = tokenizer.batch_decode(self.labels, skip_special_tokens=True)\n",
- "\n",
- " # Compute ROUGE scores\n",
- " scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- "\n",
- " self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
- " self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
- " self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
- "\n",
- " # Clear logits and labels instance attributes for the next validation epoch\n",
- " del self.logits\n",
- " del self.labels\n",
- " \n",
- " def configure_optimizers(self):\n",
- " optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
- " return optimizer\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "dd63c628",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "# File paths\n",
- "train_csv = \"train.csv\"\n",
- "val_csv = \"validation.csv\"\n",
- "test_csv = \"test.csv\"\n",
- "\n",
- "# Create the data module\n",
- "dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
- "dm.setup()\n",
- "\n",
- "model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-4, weight_decay=1e-5)\n",
- "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
- "trainer.fit(model, datamodule=dm)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b5d3d684",
- "metadata": {},
- "outputs": [],
- "source": [
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a0494596",
- "metadata": {},
- "source": [
- "### next steps:\n",
- "* Articles longer than 512 tokens are currently truncated; check whether this causes issues when articles are much longer\n",
- "\n",
- "#### what we've done:\n",
- "* Change the data loading so it's more general, i.e. on-the-fly loading from disk\n",
- "* Add torch.compile\n",
- "* 1. Clean up the code, make it into scripts instead of a notebook -> train for an epoch (add multi-GPU training?)\n",
- "* Add TensorBoard visualization\n",
- "* Train from scratch (not from pretrained weights) to make sure the training setup works and the loss actually improves\n",
- "* 2. Create an inference step: send in a news article -> get a summary, check that it works (see the inference sketch below)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80a2efab",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0f9b71ab",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
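The "next steps" cell above lists an inference step: news article in, summary out. A minimal sketch of what that could look like is shown below; it is not taken from the deleted notebooks and assumes a T5-style checkpoint (the plain t5-small weights here, or the path to a fine-tuned checkpoint) together with T5's "summarize: " task prefix.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"  # or a path to a fine-tuned checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()

def summarize(article: str, max_new_tokens: int = 128) -> str:
    # T5 was pretrained with task prefixes, so prepend "summarize: " to the article.
    inputs = tokenizer("summarize: " + article, max_length=512, truncation=True, return_tensors="pt")
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=4)
    return tokenizer.decode(generated[0], skip_special_tokens=True)

print(summarize("The city council voted on Tuesday to approve the new transit budget, which ..."))

With the untrained-on-summarization t5-small weights the output will be rough; the point of the step is only to verify that the article -> summary path works end to end.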
diff --git a/ML/Pytorch/huggingface/dataset.py b/ML/Pytorch/huggingface/dataset.py
deleted file mode 100644
index 8255e21..0000000
--- a/ML/Pytorch/huggingface/dataset.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import pandas as pd
-import pytorch_lightning as pl
-from torch.utils.data import Dataset
-import torch
-
-
-class cnn_dailymail(Dataset):
- def __init__(self, csv_file, tokenizer, max_length=512):
- self.data = pd.read_csv(csv_file)
-
-        # optionally subsample the training data (the commented-out code below keeps 5%) and reset the indices
- #if csv_file == "train.csv":
- # self.data = self.data.sample(frac=0.05, random_state=42).reset_index(drop=True)
-
- self.tokenizer = tokenizer
- self.max_length = max_length
-
- def __len__(self):
- return len(self.data)
-
- def __getitem__(self, idx):
- article = self.data.loc[idx, "article"]
- highlights = self.data.loc[idx, "highlights"]
-
- inputs = self.tokenizer(
- article,
- truncation=True,
- padding="max_length",
- max_length=self.max_length,
- return_tensors="pt",
- )
- targets = self.tokenizer(
- highlights,
- truncation=True,
- padding="max_length",
- max_length=self.max_length,
- return_tensors="pt",
- )
-
- return {
- "input_ids": inputs["input_ids"].squeeze(),
- "attention_mask": inputs["attention_mask"].squeeze(),
- "labels": targets["input_ids"].squeeze(),
- }
-
-
-class MyDataModule(pl.LightningDataModule):
- def __init__(
- self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512
- ):
- super().__init__()
- self.train_csv = train_csv
- self.val_csv = val_csv
- self.test_csv = test_csv
- self.tokenizer = tokenizer
- self.batch_size = batch_size
- self.max_length = max_length
-
- def setup(self, stage=None):
- if stage in ("fit", None):
- self.train_dataset = cnn_dailymail(
- self.train_csv, self.tokenizer, self.max_length
- )
- self.val_dataset = cnn_dailymail(
- self.val_csv, self.tokenizer, self.max_length
- )
- if stage in ("test", None):
- self.test_dataset = cnn_dailymail(
- self.test_csv, self.tokenizer, self.max_length
- )
-
- def train_dataloader(self):
- return torch.utils.data.DataLoader(
- self.train_dataset,
- batch_size=self.batch_size,
- pin_memory=True,
- shuffle=True,
- num_workers=6,
- )
-
- def val_dataloader(self):
- return torch.utils.data.DataLoader(
- self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
- )
-
- def test_dataloader(self):
- return torch.utils.data.DataLoader(
- self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1
- )
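One detail worth noting about the cnn_dailymail Dataset above (and the notebook version of it): the targets are padded to max_length with the tokenizer's pad token, and those pad positions are then counted in the model's cross-entropy loss. Hugging Face seq2seq models ignore label positions set to -100, so a small sketch of masking the padding, not part of the original dataset.py, could look like this.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

def mask_label_padding(label_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Replace pad positions with -100 so the cross-entropy loss ignores them.
    labels = label_ids.clone()
    labels[labels == pad_token_id] = -100
    return labels

target = tokenizer("A short summary.", padding="max_length", max_length=16, return_tensors="pt")
print(mask_label_padding(target["input_ids"][0], tokenizer.pad_token_id))

In __getitem__ this would be applied to targets["input_ids"].squeeze() before returning the labels; the DataCollatorForSeq2Seq used in the Trainer notebooks defaults to label_pad_token_id=-100 for the same reason, although it only applies it to positions it pads itself.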
diff --git a/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb b/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb
deleted file mode 100644
index da1dc98..0000000
--- a/ML/Pytorch/huggingface/finetune_t5_lightning.ipynb
+++ /dev/null
@@ -1,470 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "ec1aae37",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
- "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
- "2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
- ]
- }
- ],
- "source": [
- "import warnings\n",
- "warnings.simplefilter(\"ignore\")\n",
- "\n",
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
- "\n",
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "import datasets \n",
- "import pytorch_lightning as pl\n",
- "\n",
- "from datasets import load_dataset, load_metric\n",
- "\n",
- "from transformers import (\n",
- " AutoModel,\n",
- " AutoModelForSeq2SeqLM,\n",
- " AutoTokenizer,\n",
- " DataCollatorForSeq2Seq,\n",
- " Seq2SeqTrainingArguments,\n",
- " Seq2SeqTrainer,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "5fd7cb0c",
- "metadata": {},
- "outputs": [],
- "source": [
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "04530b1e",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Define the LightningDataModule\n",
- "class MyDataModule(pl.LightningDataModule):\n",
- " def __init__(self, batch_size):\n",
- " super().__init__()\n",
- " self.batch_size = batch_size\n",
- " \n",
- " def prepare_data(self):\n",
- " # Download and preprocess the data\n",
- " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
- " load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
- " \n",
- " def setup(self, stage=None):\n",
- " # Load and preprocess the data\n",
- " train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
- " val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
- "\n",
- " self.train_ds = train_data.map(\n",
- " self.preprocess_function, \n",
- " batched=True, \n",
- " batch_size=self.batch_size, \n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- " )\n",
- "\n",
- " self.val_ds = val_data.map(\n",
- " self.preprocess_function, \n",
- " batched=True, \n",
- " batch_size=self.batch_size,\n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- " )\n",
- "\n",
- " def preprocess_function(self, batch):\n",
- " inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
- " outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
- " batch[\"input_ids\"] = inputs.input_ids\n",
- " batch[\"attention_mask\"] = inputs.attention_mask\n",
- " batch[\"labels\"] = outputs.input_ids.copy()\n",
- " return batch\n",
- "\n",
- " def train_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
- "\n",
- " def val_dataloader(self):\n",
- " return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "fbb699e1",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyLightningModule(pl.LightningModule):\n",
- " def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
- " super().__init__()\n",
- " self.model_name = model_name\n",
- " self.learning_rate = learning_rate\n",
- " self.weight_decay = weight_decay\n",
- " self.batch_size = batch_size\n",
- " \n",
- " # Load the pre-trained model and tokenizer\n",
- " self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
- "\n",
- " # Load the ROUGE metric\n",
- " self.metric = load_metric(\"rouge\")\n",
- "\n",
- " def forward(self, input_ids, attention_mask, labels=None):\n",
- " output = self.model(\n",
- " input_ids=input_ids,\n",
- " attention_mask=attention_mask,\n",
- " labels=labels,\n",
- " )\n",
- " return output.loss, output.logits\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
- " return {'loss': loss, 'logits': logits}\n",
- " \n",
- " def validation_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " loss, logits = self(input_ids, attention_mask, labels)\n",
- " self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
- " return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
- " \n",
- " def validation_epoch_end(self, outputs):\n",
- " decoded_preds = []\n",
- " decoded_labels = []\n",
- " for output in outputs:\n",
- " logits = output['logits']\n",
- " labels = output['labels']\n",
- "            # the tokenizer is a notebook-level global here; decode greedy predictions via argmax over the vocab\n",
- "            decoded_preds += tokenizer.batch_decode(logits.argmax(dim=-1), skip_special_tokens=True)\n",
- "            decoded_labels += tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
- " \n",
- " scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- " \n",
- " self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
- " self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
- " self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
- " \n",
- " def configure_optimizers(self):\n",
- " optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
- " return optimizer\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "dd63c628",
- "metadata": {
- "scrolled": false
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "GPU available: True (cuda), used: True\n",
- "TPU available: False, using: 0 TPU cores\n",
- "IPU available: False, using: 0 IPUs\n",
- "HPU available: False, using: 0 HPUs\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
- "\n",
- "  [Dataset.map progress output truncated: train split, 1009/1795 batches at ~105-115 ba/s]\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "  [Dataset.map progress output truncated: train split 1795/1795 batches, validation split 84/84 batches, ~105-115 ba/s]\n",
- "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
- "\n",
- " | Name | Type | Params\n",
- "-----------------------------------------------------\n",
- "0 | model | T5ForConditionalGeneration | 60.5 M\n",
- "-----------------------------------------------------\n",
- "60.5 M Trainable params\n",
- "0 Non-trainable params\n",
- "60.5 M Total params\n",
- "242.026 Total estimated model params size (MB)\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Sanity Checking DataLoader 0: 0%| | 0/2 [00:00, ?it/s]"
- ]
- },
- {
- "ename": "AttributeError",
- "evalue": "'list' object has no attribute 'size'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
- "Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
- "Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
- "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
- "\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
- ]
- }
- ],
- "source": [
- "torch.set_float32_matmul_precision(\"medium\")\n",
- "model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
- "trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
- "dm = MyDataModule(batch_size=16)\n",
- "trainer.fit(model, datamodule=dm)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "aa7b1ab0",
- "metadata": {},
- "source": [
- "### Recap of what we did:\n",
- "* Finetuned T5-Small on DailyCNN (summarize news articles) using HF Trainer and data loading\n",
- "* Converted to Lightning code \n",
- "\n",
- "### To do next:\n",
- "* Make it work with the evaluation somethings wrong now, don't think it's a big issue\n",
- "* Clean up the code a bit\n",
- "* Compare it with HF, add predict function, modify data loading so it's from scratch / more general way of doing it."
- ]
- },
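- {
- "cell_type": "markdown",
- "id": "1c2d3e4f",
- "metadata": {},
- "source": [
- "A possible fix for the `AttributeError` above (an untested sketch): the error means `input_ids` reached T5 as a Python list instead of a `torch.Tensor`. If the batches come from an HF dataset tokenized with `.map()`, setting its format to torch should fix it; for a custom `Dataset` that returns lists, a collate function that stacks them into tensors does the same. `tokenized_ds` below is a placeholder name."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "2d3e4f5a",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "\n",
- "# If the data comes from an HF dataset tokenized with .map(), make it yield torch\n",
- "# tensors instead of lists (tokenized_ds is a placeholder for that dataset):\n",
- "# tokenized_ds.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
- "\n",
- "# For a custom map-style Dataset that returns per-sample lists, a collate_fn that\n",
- "# stacks them into tensors has the same effect:\n",
- "def collate_to_tensors(features):\n",
- "    return {key: torch.tensor([f[key] for f in features]) for key in features[0]}"
- ]
- },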
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "80a2efab",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb b/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb
deleted file mode 100644
index 09bebc9..0000000
--- a/ML/Pytorch/huggingface/finetuning_t5_small_cnndaily.ipynb
+++ /dev/null
@@ -1,237 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5372055b",
- "metadata": {},
- "outputs": [],
- "source": [
- "from jupyterthemes.stylefx import set_nb_theme\n",
- "set_nb_theme('chesterish')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "11214a4a",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "f45eb6b0",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "import datasets \n",
- "\n",
- "from datasets import load_dataset, load_metric\n",
- "\n",
- "from transformers import (\n",
- " AutoModel,\n",
- " AutoModelForMaskedLM,\n",
- " AutoModelForSeq2SeqLM,\n",
- " AutoModelForTokenClassification,\n",
- " AutoTokenizer,\n",
- " DataCollatorForSeq2Seq,\n",
- " pipeline,\n",
- " Seq2SeqTrainingArguments,\n",
- " Seq2SeqTrainer,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "b2d26af4",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Load the pre-trained model and tokenizer\n",
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
- "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "363045f5",
- "metadata": {},
- "outputs": [],
- "source": [
- "def preprocess_function(batch):\n",
- " inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
- " outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
- " batch[\"input_ids\"] = inputs.input_ids\n",
- " batch[\"attention_mask\"] = inputs.attention_mask\n",
- " batch[\"labels\"] = outputs.input_ids.copy()\n",
- " return batch\n",
- "\n",
- "# Load the dataset\n",
- "train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
- "val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
- "\n",
- "train_ds = train_data.map(\n",
- " preprocess_function, \n",
- " batched=True, \n",
- " batch_size=256, \n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- ")\n",
- "\n",
- "val_ds = val_data.map(\n",
- " preprocess_function, \n",
- " batched=True, \n",
- " batch_size=256, \n",
- " remove_columns=[\"article\", \"highlights\", \"id\"]\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0d58818f",
- "metadata": {},
- "outputs": [],
- "source": [
- "class MyLightningModule(pl.LightningModule):\n",
- " def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
- " super().__init__()\n",
- " self.model_name = model_name\n",
- " self.learning_rate = learning_rate\n",
- " self.weight_decay = weight_decay\n",
- " self.batch_size = batch_size\n",
- " self.num_training_steps = num_training_steps\n",
- " \n",
- " # Load the pre-trained model and tokenizer\n",
- " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
- " self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
- "\n",
- " def forward(self, input_ids, attention_mask, labels=None):\n",
- " output = self.model(\n",
- " input_ids=input_ids,\n",
- " attention_mask=attention_mask,\n",
- " labels=labels,\n",
- " )\n",
- " return output.loss, output.logits\n",
- " \n",
- " def training_step(self, batch, batch_idx):\n",
- " input_ids = batch[\"input_ids\"]\n",
- " attention_mask = batch[\"attention_mask\"]\n",
- " labels = batch[\"labels\"]\n",
- " \n",
- " loss\n",
- "\n",
- "# Define the data collator\n",
- "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
- "\n",
- "# Initialize the trainer arguments\n",
- "training_args = Seq2SeqTrainingArguments(\n",
- " output_dir=\"./results\",\n",
- " learning_rate=1e-5,\n",
- " per_device_train_batch_size=16,\n",
- " per_device_eval_batch_size=16,\n",
- " max_steps=5000,\n",
- " weight_decay=1e-4,\n",
- " push_to_hub=False,\n",
- " evaluation_strategy = \"steps\",\n",
- " eval_steps = 50,\n",
- " generation_max_length=128,\n",
- " predict_with_generate=True,\n",
- " logging_steps=100,\n",
- " gradient_accumulation_steps=1,\n",
- " fp16=True,\n",
- ")\n",
- "\n",
- "# Load the ROUGE metric\n",
- "metric = load_metric(\"rouge\")\n",
- "\n",
- "# Define the evaluation function\n",
- "def compute_metrics(pred):\n",
- " labels = pred.label_ids\n",
- " preds = pred.predictions\n",
- " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
- " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
- " scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- " return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
- "\n",
- "\n",
- "# Initialize the trainer\n",
- "trainer = Seq2SeqTrainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=train_data,\n",
- " eval_dataset=val_data,\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")\n",
- "\n",
- "# Start the training\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5148159b",
- "metadata": {},
- "source": [
- "# Steps:\n",
- "1. Rewrite code to be more general\n",
- "\n",
- "a) Data loading should be from disk rather than their load_dataset, and should be on the fly\n",
- "\n",
- "b) Rewrite to Lightning code, Trainer etc using Lightning, compute metric fine that we use huggingface"
- ]
- },
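- {
- "cell_type": "markdown",
- "id": "3e4f5a6b",
- "metadata": {},
- "source": [
- "Rough sketch for step 1a (an assumption about the workflow, not code that was run here): export the CNN/DailyMail splits once to `train.csv` / `validation.csv` / `test.csv`, so the Lightning code can read them from disk and tokenize on the fly."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4f5a6b7c",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "\n",
- "# One-time export of the CNN/DailyMail splits to CSV files on disk, so later\n",
- "# training runs can load from these files instead of calling load_dataset.\n",
- "for split in [\"train\", \"validation\", \"test\"]:\n",
- "    load_dataset(\"cnn_dailymail\", \"3.0.0\", split=split).to_csv(f\"{split}.csv\")"
- ]
- },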
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "95e33e40",
- "metadata": {},
- "outputs": [],
- "source": [
- "!nvidia-smi"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "4c0348c2",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/learning.ipynb b/ML/Pytorch/huggingface/learning.ipynb
deleted file mode 100644
index c821b42..0000000
--- a/ML/Pytorch/huggingface/learning.ipynb
+++ /dev/null
@@ -1,644 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "7d5e92c6",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
- ]
- }
- ],
- "source": [
- "from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
- "from transformers import pipeline\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
- "model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
- "\n",
- "pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
- "example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
- "\n",
- "ner_entity_results = pipe(example)\n",
- "print(ner_entity_results)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "id": "bf67ee76",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Turtle s Original Cara mel Chocolate P eca n luster\n"
- ]
- }
- ],
- "source": [
- "ner_entity_results = pipe(example)\n",
- "\n",
- "# Initialize the entity words list with an empty string\n",
- "entity_words = [\"\"]\n",
- "\n",
- "# Loop through each dictionary in the list and extract the entity word\n",
- "for result in ner_entity_results:\n",
- " if result[\"entity\"] == \"B-FOOD\":\n",
- " entity_words.append(result[\"word\"])\n",
- " elif result[\"entity\"] == \"I-FOOD\":\n",
- " entity_words[-1] += \" \" + result[\"word\"]\n",
- "\n",
- "# Remove any remaining ## symbols and extra spaces\n",
- "entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
- "\n",
- "# Join the entity words into a single string\n",
- "output = \" \".join(entity_words)\n",
- "\n",
- "print(output)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fc8e5ea0",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "print(torch.cuda.is_available())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d8a1e039",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import pipeline\n",
- "import numpy as np"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6ad73024",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "classifier = pipeline(\"zero-shot-classification\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "04f7e02c",
- "metadata": {},
- "outputs": [],
- "source": [
- "classifier(\n",
- " \"This is a course about the Transformers library\",\n",
- " candidate_labels=[\"machine learning\", \"gym\", \"food\"],\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6fb246c2",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "from transformers import pipeline\n",
- "generator = pipeline(task=\"text-generation\", model=\"bigscience/bloom-1b7\", device=0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c4e174f0",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer\n",
- "import torch\n",
- "\n",
- "# Define input text and pre-trained model checkpoint\n",
- "text = \"My name is wolfgang and I live in berlin\"\n",
- "checkpoint = \"Jean-Baptiste/roberta-large-ner-english\"\n",
- "\n",
- "# Instantiate tokenizer and encode input text\n",
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
- "inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
- "\n",
- "# Instantiate model and generate output\n",
- "model = AutoModel.from_pretrained(checkpoint)\n",
- "outputs = model(**inputs)\n",
- "print(outputs[0].shape)\n",
- "\n",
- "# Instantiate token classification model and generate predictions\n",
- "model = AutoModelForTokenClassification.from_pretrained(checkpoint)\n",
- "outputs = model(**inputs)\n",
- "predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
- "print(predictions)\n",
- "print(model.config.id2label)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "8212bbaa",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
- "\n",
- "tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
- "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
- "\n",
- "# prepare input\n",
- "text = \"Replace me by any text you'd like.\"\n",
- "encoded_input = tokenizer(text, return_tensors='pt')\n",
- "\n",
- "# forward pass\n",
- "output = model(**encoded_input)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "314cba41",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
- "\n",
- "# Load the pre-trained tokenizer and model\n",
- "tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
- "model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
- "\n",
- "# Define the input sentence with a masked token\n",
- "text = \"I want to a new car tomorrow.\"\n",
- "\n",
- "# Tokenize the input sentence, replacing the masked token with a special [MASK] token\n",
- "encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')\n",
- "\n",
- "print(output.logits.shape)\n",
- "print(encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id))\n",
- "\n",
- "# Extract the predicted probabilities for the masked token\n",
- "predicted_probabilities = output.logits[0, encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id)]\n",
- "predicted_probabilities = torch.nn.functional.softmax(predicted_probabilities, dim=-1)\n",
- "\n",
- "# Get the top-k most probable predictions for the masked token\n",
- "k = 5\n",
- "top_k = torch.topk(predicted_probabilities, k)\n",
- "for i in range(k):\n",
- " token = tokenizer.convert_ids_to_tokens(top_k.indices[i].item())\n",
- " score = top_k.values[i].item()\n",
- " print(f\"Prediction {i+1}: '{token}' with probability {score:.5f}\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6187e77e",
- "metadata": {},
- "outputs": [],
- "source": [
- "%%time\n",
- "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
- "\n",
- "sequences = [\n",
- " \"Using a Transformer network is simple\",\n",
- " \"The quick brown fox jumps over the lazy dog\",\n",
- " \"To be or not to be, that is the question\"\n",
- "]\n",
- "\n",
- "# Tokenize the input sequences and convert them to padded and truncated integer token IDs\n",
- "inputs = tokenizer(\n",
- " sequences,\n",
- " padding=True,\n",
- " truncation=True,\n",
- " return_tensors=\"pt\"\n",
- ")\n",
- "\n",
- "# Print the resulting input IDs and attention masks\n",
- "print(inputs['input_ids'])\n",
- "print(inputs['attention_mask'])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "fc259c5a",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "markdown",
- "id": "43466db6",
- "metadata": {},
- "source": [
- "Huggingface:\n",
- "\n",
- "1. Understanding how to use the Pipeline (probably most useful) for various tasks, easy to use, and the different subtasks it can do like translation, QA, zero shot, sentiment analysis, token classification, etc. \n",
- "2. Understood how pipeline works in more detail by using AutoModel for various tasks as well as AutoTokenizer\n",
- "3. Load dataset\n",
- "4. How to finetune\n",
- "5. How to evaluate\n",
- "6. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "97c474f2",
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3ed5d8c2",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification\n",
- "\n",
- "# Same as before\n",
- "checkpoint = \"bert-base-uncased\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
- "sequences = [\n",
- " \"I've been waiting for a HuggingFace course my whole life.\",\n",
- " \"This course is amazing!\",\n",
- "]\n",
- "batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
- "\n",
- "# This is new\n",
- "batch[\"labels\"] = torch.tensor([1, 1])\n",
- "\n",
- "optimizer = AdamW(model.parameters())\n",
- "loss = model(**batch).loss\n",
- "loss.backward()\n",
- "optimizer.step()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "c598624f",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "raw_datasets = load_dataset(\"glue\", \"mrpc\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "cd296227",
- "metadata": {},
- "outputs": [],
- "source": [
- "raw_train_dataset = raw_datasets[\"train\"]\n",
- "raw_train_dataset[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "e462947a",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "from transformers import AutoTokenizer, DataCollatorWithPadding\n",
- "raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
- "\n",
- "checkpoint = \"bert-base-uncased\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
- "\n",
- "def tokenize_function(example):\n",
- " return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
- "\n",
- "\n",
- "tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
- "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
- "\n",
- "\n",
- "from transformers import TrainingArguments\n",
- "training_args = TrainingArguments(\"test-trainer\")\n",
- "\n",
- "from transformers import AutoModelForSequenceClassification\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
- "\n",
- "import numpy as np\n",
- "import evaluate\n",
- "\n",
- "def compute_metrics(eval_preds):\n",
- " metric = evaluate.load(\"glue\", \"mrpc\")\n",
- " logits, labels = eval_preds\n",
- " predictions = np.argmax(logits, axis=-1)\n",
- " return metric.compute(predictions=predictions, references=labels)\n",
- "\n",
- "training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
- "\n",
- "trainer = Trainer(\n",
- " model,\n",
- " training_args,\n",
- " train_dataset=tokenized_datasets[\"train\"],\n",
- " eval_dataset=tokenized_datasets[\"validation\"],\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "0e2795dc",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import TrainingArguments\n",
- "training_args = TrainingArguments(\"test-trainer\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "3af29cd5",
- "metadata": {},
- "outputs": [],
- "source": [
- "from transformers import AutoModelForSequenceClassification\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "817f644e",
- "metadata": {},
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import evaluate"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "42819a6c",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "def compute_metrics(eval_preds):\n",
- " metric = evaluate.load(\"glue\", \"mrpc\")\n",
- " logits, labels = eval_preds\n",
- " predictions = np.argmax(logits, axis=-1)\n",
- " return metric.compute(predictions=predictions, references=labels)\n",
- "\n",
- "training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
- "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
- "\n",
- "trainer = Trainer(\n",
- " model,\n",
- " training_args,\n",
- " train_dataset=tokenized_datasets[\"train\"],\n",
- " eval_dataset=tokenized_datasets[\"validation\"],\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "eb5986b0",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
- "from datasets import load_dataset\n",
- "batch_size=32\n",
- "\n",
- "# Define the generator function to preprocess the data in batches\n",
- "def preprocess_generator(examples):\n",
- " for i in range(0, len(examples[\"article\"]), batch_size):\n",
- " batch = examples[\"article\"][i:i+batch_size]\n",
- " targets = examples[\"highlights\"][i:i+batch_size]\n",
- " model_inputs = tokenizer(batch, max_length=512, padding=\"max_length\", truncation=True)\n",
- " with tokenizer.as_target_tokenizer():\n",
- " model_targets = tokenizer(targets, max_length=128, padding=\"max_length\", truncation=True)\n",
- " model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
- " yield model_inputs\n",
- "\n",
- "def preprocess_function(examples):\n",
- " articles = [ex for ex in examples[\"article\"]]\n",
- " summaries = [ex for ex in examples[\"highlights\"]]\n",
- "\n",
- " model_inputs = tokenizer(articles, max_length=512, padding=\"max_length\", truncation=True)\n",
- " with tokenizer.as_target_tokenizer():\n",
- " model_targets = tokenizer(summaries, max_length=128, padding=\"max_length\", truncation=True)\n",
- " \n",
- " model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
- " return model_inputs\n",
- " \n",
- "# Load the dataset\n",
- "raw_datasets = load_dataset(\"cnn_dailymail\", \"3.0.0\")\n",
- "preprocessed_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4)\n",
- "\n",
- "# Load the pre-trained model and tokenizer\n",
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
- "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
- "\n",
- "# Define the data collator\n",
- "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
- "\n",
- "# Initialize the trainer arguments\n",
- "training_args = Seq2SeqTrainingArguments(\n",
- " output_dir=\"./results\",\n",
- " evaluation_strategy = \"epoch\",\n",
- " learning_rate=2e-5,\n",
- " per_device_train_batch_size=batch_size,\n",
- " max_steps=1000,\n",
- " weight_decay=0.01,\n",
- " push_to_hub=False,\n",
- ")\n",
- "\n",
- "# Initialize the trainer\n",
- "trainer = Seq2SeqTrainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=train_ds,\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- ")\n",
- "\n",
- "# Start the training\n",
- "trainer.train()\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "7d62583e",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_metric"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "d310a7b3",
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "preprocessed_datasets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "99d422cc",
- "metadata": {
- "scrolled": false
- },
- "outputs": [],
- "source": [
- "# Load the pre-trained model and tokenizer\n",
- "model_name = \"t5-small\"\n",
- "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
- "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
- "\n",
- "# Define the data collator\n",
- "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
- "\n",
- "# Initialize the trainer arguments\n",
- "training_args = Seq2SeqTrainingArguments(\n",
- " output_dir=\"./results\",\n",
- " learning_rate=2e-5,\n",
- " per_device_train_batch_size=batch_size,\n",
- " max_steps=5000,\n",
- " weight_decay=0.01,\n",
- " push_to_hub=False,\n",
- " evaluation_strategy = \"steps\",\n",
- " eval_steps = 50,\n",
- ")\n",
- "\n",
- "# Load the ROUGE metric\n",
- "metric = load_metric(\"rouge\")\n",
- "\n",
- "# Define the evaluation function\n",
- "def compute_metrics(pred):\n",
- " labels = pred.label_ids\n",
- " preds = pred.predictions\n",
- " \n",
- " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
- " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
- " \n",
- " scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
- " \n",
- " return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
- "\n",
- "\n",
- "# Initialize the trainer\n",
- "trainer = Seq2SeqTrainer(\n",
- " model=model,\n",
- " args=training_args,\n",
- " train_dataset=preprocessed_datasets[\"train\"],\n",
- " eval_dataset=preprocessed_datasets[\"validation\"],\n",
- " data_collator=data_collator,\n",
- " tokenizer=tokenizer,\n",
- " compute_metrics=compute_metrics,\n",
- ")\n",
- "\n",
- "# Start the training\n",
- "trainer.train()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "a5e97b57",
- "metadata": {},
- "outputs": [],
- "source": [
- "!pip install nltk\n",
- "!pip install rouge_score"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "558c3e66",
- "metadata": {},
- "source": [
- "# Goal:\n",
- "\n",
- "1. Implement full training from dataloading (dailycnn dataset), to model training, evaluation, etc, using HF. \n",
- "* Right now: stuck on on the fly dataset loading, we don't want to cache because this would take a lot of disk space etc.\n",
- "\n",
- "2. After we get step 1) working, we want to go deeper on every step, so download the dataset and load it as a custom dataset rather than using huggingface simple API, in order to make it more general. Compare with loading the ds as a custom HF dataset or using pytorch class together with lightning. Speed difference? Convenience? Also we want to use the lightning Trainer so see how we can integrate that. And then compare HF to the lightning + hf model approach and see what we like the most."
- ]
- },
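- {
- "cell_type": "markdown",
- "id": "5a6b7c8d",
- "metadata": {},
- "source": [
- "One way around the caching concern (a sketch, not what was run here): stream the dataset so nothing is written to the local cache and tokenization happens lazily while iterating. This assumes `preprocess_function` and `tokenizer` from the cells above are already defined."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6b7c8d9e",
- "metadata": {},
- "outputs": [],
- "source": [
- "from datasets import load_dataset\n",
- "\n",
- "# Streaming mode downloads and yields examples on the fly; .map() is applied\n",
- "# lazily per batch, so no tokenized copy of the dataset is cached to disk.\n",
- "streamed_train = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\", streaming=True)\n",
- "streamed_train = streamed_train.map(preprocess_function, batched=True)\n",
- "\n",
- "# Peek at one tokenized example to check that the lazy pipeline works.\n",
- "print(next(iter(streamed_train)).keys())"
- ]
- },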
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "624d49ca",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/ML/Pytorch/huggingface/learninghugg.py b/ML/Pytorch/huggingface/learninghugg.py
deleted file mode 100644
index 236cf3b..0000000
--- a/ML/Pytorch/huggingface/learninghugg.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import numpy as np
-
-import evaluate
-from datasets import load_dataset
-from transformers import AutoTokenizer, DataCollatorWithPadding
-from transformers import Trainer
-
-raw_datasets = load_dataset("glue", "mrpc")
-checkpoint = "bert-base-uncased"
-tokenizer = AutoTokenizer.from_pretrained(checkpoint)
-
-
-def tokenize_function(example):
- return tokenizer(example["sentence1"], example["sentence2"], truncation=True)
-
-
-tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
-data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
-
-
-from transformers import TrainingArguments
-training_args = TrainingArguments("test-trainer")
-
-from transformers import AutoModelForSequenceClassification
-model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
-
-def compute_metrics(eval_preds):
- metric = evaluate.load("glue", "mrpc")
- logits, labels = eval_preds
- predictions = np.argmax(logits, axis=-1)
- return metric.compute(predictions=predictions, references=labels)
-
-training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
-model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
-
-trainer = Trainer(
- model,
- training_args,
- train_dataset=tokenized_datasets["train"],
- eval_dataset=tokenized_datasets["validation"],
- data_collator=data_collator,
- tokenizer=tokenizer,
- compute_metrics=compute_metrics,
-)
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_0/events.out.tfevents.1676993704.mrbeast.566861.0 b/ML/Pytorch/huggingface/lightning_logs/version_0/events.out.tfevents.1676993704.mrbeast.566861.0
deleted file mode 100644
index ab84c26..0000000
Binary files a/ML/Pytorch/huggingface/lightning_logs/version_0/events.out.tfevents.1676993704.mrbeast.566861.0 and /dev/null differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_0/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_0/hparams.yaml
deleted file mode 100644
index 0967ef4..0000000
--- a/ML/Pytorch/huggingface/lightning_logs/version_0/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_1/events.out.tfevents.1676993775.mrbeast.568809.0 b/ML/Pytorch/huggingface/lightning_logs/version_1/events.out.tfevents.1676993775.mrbeast.568809.0
deleted file mode 100644
index 9ac0e2a..0000000
Binary files a/ML/Pytorch/huggingface/lightning_logs/version_1/events.out.tfevents.1676993775.mrbeast.568809.0 and /dev/null differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_1/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_1/hparams.yaml
deleted file mode 100644
index 0967ef4..0000000
--- a/ML/Pytorch/huggingface/lightning_logs/version_1/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_2/events.out.tfevents.1676993814.mrbeast.570170.0 b/ML/Pytorch/huggingface/lightning_logs/version_2/events.out.tfevents.1676993814.mrbeast.570170.0
deleted file mode 100644
index 256eb1d..0000000
Binary files a/ML/Pytorch/huggingface/lightning_logs/version_2/events.out.tfevents.1676993814.mrbeast.570170.0 and /dev/null differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_2/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_2/hparams.yaml
deleted file mode 100644
index 0967ef4..0000000
--- a/ML/Pytorch/huggingface/lightning_logs/version_2/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_3/events.out.tfevents.1676993905.mrbeast.570170.1 b/ML/Pytorch/huggingface/lightning_logs/version_3/events.out.tfevents.1676993905.mrbeast.570170.1
deleted file mode 100644
index a60d00b..0000000
Binary files a/ML/Pytorch/huggingface/lightning_logs/version_3/events.out.tfevents.1676993905.mrbeast.570170.1 and /dev/null differ
diff --git a/ML/Pytorch/huggingface/lightning_logs/version_3/hparams.yaml b/ML/Pytorch/huggingface/lightning_logs/version_3/hparams.yaml
deleted file mode 100644
index 0967ef4..0000000
--- a/ML/Pytorch/huggingface/lightning_logs/version_3/hparams.yaml
+++ /dev/null
@@ -1 +0,0 @@
-{}
diff --git a/ML/Pytorch/huggingface/model.py b/ML/Pytorch/huggingface/model.py
deleted file mode 100644
index e7b2eb1..0000000
--- a/ML/Pytorch/huggingface/model.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import torch
-import pytorch_lightning as pl
-from datasets import load_dataset, load_metric
-from transformers import T5Config, T5ForConditionalGeneration
-
-from transformers import (
- AutoModel,
- AutoModelForSeq2SeqLM,
- AutoTokenizer,
- DataCollatorForSeq2Seq,
- Seq2SeqTrainingArguments,
- Seq2SeqTrainer,
-)
-
-
-class MyLightningModule(pl.LightningModule):
- def __init__(self, model_name, learning_rate, weight_decay):
- super().__init__()
- self.model_name = model_name
- self.learning_rate = learning_rate
- self.weight_decay = weight_decay
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- # Load the pre-trained model and tokenizer
- #self.model = torch.compile(
- # AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
- #)
-
- # Create a T5-small configuration
- config = T5Config.from_pretrained("t5-small")
-
- # Initialize the T5 model with random weights
- self.model = torch.compile(T5ForConditionalGeneration(config))
-
- # Load the ROUGE metric
- self.metric = load_metric("rouge")
- self.logits = []
- self.labels = []
-
- def forward(self, input_ids, attention_mask, labels=None):
- output = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- labels=labels,
- )
- return output.loss, output.logits
-
- def training_step(self, batch, batch_idx):
- input_ids = batch["input_ids"]
- attention_mask = batch["attention_mask"]
- labels = batch["labels"]
- loss, logits = self(input_ids, attention_mask, labels)
- self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
- return {"loss": loss, "logits": logits}
-
- def validation_step(self, batch, batch_idx):
- input_ids = batch["input_ids"]
- attention_mask = batch["attention_mask"]
- labels = batch["labels"]
- loss, logits = self(input_ids, attention_mask, labels)
- self.log("val_loss", loss, on_epoch=True, on_step=False)
-
- # add logits and labels to instance attributes, but make sure to detach them
- # from the computational graph first
- self.logits.append(logits.argmax(dim=-1).detach().cpu())
- self.labels.append(labels.detach().cpu())
- return {"loss": loss, "logits": logits, "labels": labels}
-
- def on_validation_epoch_end(self):
- # Concatenate tensors in logits and labels lists
- pred_token_ids = torch.cat(self.logits, dim=0)
- true_labels = torch.cat(self.labels, dim=0)
-
- # Decode predictions and labels using the saved instance attributes
- decoded_preds = self.tokenizer.batch_decode(
- pred_token_ids, skip_special_tokens=True
- )
- decoded_labels = self.tokenizer.batch_decode(
- true_labels, skip_special_tokens=True
- )
-
- # Compute ROUGE scores
- scores = self.metric.compute(
- predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1"]
- )["rouge1"].mid
-
- self.log("rouge1_precision", scores.precision, prog_bar=True)
- self.log("rouge1_recall", scores.recall, prog_bar=True)
- self.log("rouge1_fmeasure", scores.fmeasure, prog_bar=True)
-
- # Clear logits and labels instance attributes for the next validation epoch
- self.logits.clear()
- self.labels.clear()
-
- def predict(self, article: str, max_input_length: int = 512, max_output_length: int = 150) -> str:
- # Set the model to evaluation mode
- self.model.eval()
-
- # Tokenize the input article
- inputs = self.tokenizer(
- article,
- max_length=max_input_length,
- padding="max_length",
- truncation=True,
- return_tensors="pt"
- )
-
- # Move the input tensors to the same device as the model
- inputs = {key: value.to(self.device) for key, value in inputs.items()}
-
- # Generate summary
- with torch.no_grad():
- output = self.model.generate(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- max_length=max_output_length,
- num_return_sequences=1,
- )
-
- # Decode and return the summary
- summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
- return summary
-
- def configure_optimizers(self):
- optimizer = torch.optim.AdamW(
- self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
- )
- return optimizer
-
-
diff --git a/ML/Pytorch/huggingface/test.py b/ML/Pytorch/huggingface/test.py
deleted file mode 100644
index dd1f9d8..0000000
--- a/ML/Pytorch/huggingface/test.py
+++ /dev/null
@@ -1,2 +0,0 @@
-l = ["cat", "dog"]
-sentence = "The quick brown fox jumps over the lazy dog"
diff --git a/ML/Pytorch/huggingface/train.py b/ML/Pytorch/huggingface/train.py
deleted file mode 100644
index 6b5a91f..0000000
--- a/ML/Pytorch/huggingface/train.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from dataset import MyDataModule
-from model import MyLightningModule
-import pytorch_lightning as pl
-from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.loggers import TensorBoardLogger
-from transformers import (
- AutoModel,
- AutoModelForSeq2SeqLM,
- AutoTokenizer,
- DataCollatorForSeq2Seq,
- Seq2SeqTrainingArguments,
- Seq2SeqTrainer,
-)
-import torch
-
-torch.set_float32_matmul_precision("medium")
-
-if __name__ == "__main__":
- # Define the checkpoint callback
- checkpoint_callback = ModelCheckpoint(
- monitor="val_loss",
- dirpath="checkpoints",
- filename="my_model-{epoch:02d}-{val_loss:.2f}",
- save_top_k=-1,
- every_n_epochs=1,
- verbose=True,
- )
- logger = TensorBoardLogger("tb_logs", name="t5_dailymail")
-
- model_name = "t5-small"
- tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- # File paths
- train_csv = "train.csv"
- val_csv = "validation.csv"
- test_csv = "test.csv"
-
- # Create the data module
- dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=32)
- dm.setup()
-
- model = MyLightningModule(
- model_name="t5-small", learning_rate=1e-4, weight_decay=1e-5
- )
-
-
- #checkpoint_path = "checkpoints/curr.ckpt"
- #checkpoint = torch.load(checkpoint_path)
- #model.load_state_dict(checkpoint["state_dict"])
-
- trainer = pl.Trainer(
- accelerator="gpu",
- devices=[0, 1],
- max_epochs=10,
- precision=16,
- logger=logger,
- callbacks=[checkpoint_callback],
- log_every_n_steps=10,
- )
- trainer.fit(model, dm)
- trainer.validate(model, dm)
-
- #example = """Former President Donald Trump claims in a social media post that he will be arrested next week. The claim comes while a New York prosecutor considers charging Trump in connection with hush money paid to adult film actress Stormy Daniels but there has been no official announcement of any plans for an indictment. What we know about Trump possibly facing criminal indictment in New York City. Trump has been entangled in several criminal investigations but the case related to Daniels is the longest-running of all of them, reaching back to 2016. On his platform Truth Social on Saturday morning, Trump cited "illegal leaks" that he will be arrested Tuesday and he called for protests. Trump, who is running for president in 2024, also defended himself, saying that he has not committed a crime — though he did not disclose what he expects to be charged with — and he accused the Manhattan District Attorney's Office of being "corrupt & highly political.". 'I'M BACK!' Trump posts on Facebook, YouTube for first time in two years. The Manhattan District Attorney's Office declined to comment on whether it will soon be pursing an arrest warrant for Trump. But the Associated Press reported that law enforcement officials in New York are discussing security preparations in anticipation that Trump may be indicted in coming weeks. If it does occur, Trump would become the first former president to be indicted in U.S. history."""
- #print(len(tokenizer(example)["input_ids"]))
- #summary = model.predict(example)
- #print(summary)