mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-21 11:18:01 +00:00
remove some old stuff
@@ -1,6 +0,0 @@
{
 "cells": [],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -1,317 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f54ecf0b",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"# HuggingFace Tutorial Series\n",
"- 1. What is HuggingFace?\n",
"- 2. Common tasks we can do with HuggingFace, with brief explanations of each (e.g., what question answering is)\n",
"- 3. Using the HuggingFace Pipeline (high-level feature)\n",
"- 4. How the pipeline works at a lower level\n",
"- 5. HuggingFace Datasets\n",
"- 6. HuggingFace Tokenizer\n",
"- 7. HuggingFace Evaluate\n",
"- 8. HuggingFace Trainer\n",
"- 9. Putting it together to finetune a news article summarizer\n",
"- 10. Making it more general and robust with Lightning and custom data loading\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec1aae37",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import pandas as pd\n",
"import datasets\n",
"import pytorch_lightning as pl\n",
"from datasets import load_dataset, load_metric\n",
"from torch.utils.data import Dataset\n",
"\n",
"from transformers import (\n",
"    AutoModel,\n",
"    AutoModelForSeq2SeqLM,\n",
"    AutoTokenizer,\n",
"    DataCollatorForSeq2Seq,\n",
"    Seq2SeqTrainingArguments,\n",
"    Seq2SeqTrainer,\n",
")\n",
"\n",
"torch.set_float32_matmul_precision(\"medium\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5fd7cb0c",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "418cb03a",
"metadata": {},
"outputs": [],
"source": [
"class cnn_dailymail(Dataset):\n",
"    def __init__(self, csv_file, tokenizer, max_length=512):\n",
"        self.data = pd.read_csv(csv_file)\n",
"        self.tokenizer = tokenizer\n",
"        self.max_length = max_length\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        article = self.data.loc[idx, 'article']\n",
"        highlights = self.data.loc[idx, 'highlights']\n",
"\n",
"        inputs = self.tokenizer(\n",
"            article,\n",
"            truncation=True,\n",
"            padding='max_length',\n",
"            max_length=self.max_length,\n",
"            return_tensors='pt'\n",
"        )\n",
"        targets = self.tokenizer(\n",
"            highlights,\n",
"            truncation=True,\n",
"            padding='max_length',\n",
"            max_length=self.max_length,\n",
"            return_tensors='pt'\n",
"        )\n",
"\n",
"        return {\n",
"            'input_ids': inputs['input_ids'].squeeze(),\n",
"            'attention_mask': inputs['attention_mask'].squeeze(),\n",
"            'labels': targets['input_ids'].squeeze()\n",
"        }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aaa62755",
"metadata": {},
"outputs": [],
"source": [
"class MyDataModule(pl.LightningDataModule):\n",
"    def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):\n",
"        super().__init__()\n",
"        self.train_csv = train_csv\n",
"        self.val_csv = val_csv\n",
"        self.test_csv = test_csv\n",
"        self.tokenizer = tokenizer\n",
"        self.batch_size = batch_size\n",
"        self.max_length = max_length\n",
"\n",
"    def setup(self, stage=None):\n",
"        if stage in ('fit', None):\n",
"            self.train_dataset = cnn_dailymail(self.train_csv, self.tokenizer, self.max_length)\n",
"            self.val_dataset = cnn_dailymail(self.val_csv, self.tokenizer, self.max_length)\n",
"        if stage in ('test', None):\n",
"            self.test_dataset = cnn_dailymail(self.test_csv, self.tokenizer, self.max_length)\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n",
"\n",
"    def test_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
"    def __init__(self, model_name, learning_rate, weight_decay):\n",
"        super().__init__()\n",
"        self.model_name = model_name\n",
"        self.learning_rate = learning_rate\n",
"        self.weight_decay = weight_decay\n",
"\n",
"        # Load the pre-trained model (the tokenizer is the module-level one created above)\n",
"        self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
"\n",
"        # Load the ROUGE metric\n",
"        self.metric = load_metric(\"rouge\")\n",
"\n",
"    def forward(self, input_ids, attention_mask, labels=None):\n",
"        output = self.model(\n",
"            input_ids=input_ids,\n",
"            attention_mask=attention_mask,\n",
"            labels=labels,\n",
"        )\n",
"        return output.loss, output.logits\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
"        return {'loss': loss, 'logits': logits}\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
"\n",
"        # Save logits and labels as instance attributes\n",
"        if not hasattr(self, \"logits\"):\n",
"            self.logits = logits\n",
"        else:\n",
"            self.logits = torch.cat((self.logits, logits), dim=0)\n",
"\n",
"        if not hasattr(self, \"labels\"):\n",
"            self.labels = labels\n",
"        else:\n",
"            self.labels = torch.cat((self.labels, labels), dim=0)\n",
"\n",
"        return {'loss': loss, 'logits': logits, \"labels\": labels}\n",
"\n",
"    def on_validation_epoch_end(self):\n",
"        # Convert logits to predicted token IDs\n",
"        pred_token_ids = self.logits.argmax(dim=-1)\n",
"\n",
"        # Decode predictions and labels using the saved instance attributes\n",
"        decoded_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)\n",
"        decoded_labels = tokenizer.batch_decode(self.labels, skip_special_tokens=True)\n",
"\n",
"        # Compute ROUGE scores\n",
"        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"\n",
"        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
"        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
"        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
"\n",
"        # Clear logits and labels instance attributes for the next validation epoch\n",
"        del self.logits\n",
"        del self.labels\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
"        return optimizer\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd63c628",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# File paths\n",
"train_csv = \"train.csv\"\n",
"val_csv = \"validation.csv\"\n",
"test_csv = \"test.csv\"\n",
"\n",
"# Create the data module\n",
"dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
"dm.setup()\n",
"\n",
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-4, weight_decay=1e-5)\n",
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
"trainer.fit(model, datamodule=dm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b5d3d684",
"metadata": {},
"outputs": [],
"source": [
"http://localhost:18888/notebooks/cnndaily_t5_lightning_customdataloading.ipynb"
]
},
{
"cell_type": "markdown",
"id": "a0494596",
"metadata": {},
"source": [
"### next steps:\n",
"* articles longer than 512 tokens are currently truncated; this may cause issues when the article is much longer\n",
"\n",
"#### what we've done:\n",
"* changed the data loading so it's more general, i.e. on-the-fly loading from disk\n",
"* added torch.compile\n",
"* 1. Clean up the code, make it into scripts instead of a notebook -> train for an epoch (add multi-GPU training?)\n",
"* add tensorboard visualization\n",
"* train from scratch (no pretrained weights) to verify that the training setup works and is actually improving\n",
"* 2. Create an inference step: send in a news article -> get a summary, and check that it works (see the sketch after this notebook's diff)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "80a2efab",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f9b71ab",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
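A quick aside on the last "next steps" item in the notebook above (an inference step that turns a news article into a summary). The sketch below is not part of the repository: it assumes the same t5-small tokenizer and model the notebook uses (swap in a finetuned checkpoint path if one was saved), and the `summarize` helper name is made up for illustration.

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "t5-small"  # or a path to a finetuned checkpoint, if available
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval()

@torch.no_grad()
def summarize(article: str, max_input_length: int = 512, max_new_tokens: int = 128) -> str:
    # T5 was trained with task prefixes, so prepend "summarize: " to the raw article text
    inputs = tokenizer(
        "summarize: " + article,
        truncation=True,
        max_length=max_input_length,
        return_tensors="pt",
    )
    # Beam search usually reads better than greedy decoding for summaries
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        num_beams=4,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(summarize("(paste a news article here)"))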
@@ -1,463 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "ec1aae37",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
"2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
"2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
]
}
],
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import datasets \n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"\n",
|
||||
"from transformers import (\n",
|
||||
" AutoModel,\n",
|
||||
" AutoModelForSeq2SeqLM,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" DataCollatorForSeq2Seq,\n",
|
||||
" Seq2SeqTrainingArguments,\n",
|
||||
" Seq2SeqTrainer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
"cell_type": "code",
"execution_count": 3,
"id": "5fd7cb0c",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "04530b1e",
"metadata": {},
"outputs": [],
"source": [
"# Define the LightningDataModule\n",
"class MyDataModule(pl.LightningDataModule):\n",
"    def __init__(self, batch_size):\n",
"        super().__init__()\n",
"        self.batch_size = batch_size\n",
"\n",
"    def prepare_data(self):\n",
"        # Download and preprocess the data\n",
"        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
"        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
"    def setup(self, stage=None):\n",
"        # Load and preprocess the data\n",
"        train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
"        val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
"        self.train_ds = train_data.map(\n",
"            self.preprocess_function,\n",
"            batched=True,\n",
"            batch_size=self.batch_size,\n",
"            remove_columns=[\"article\", \"highlights\", \"id\"]\n",
"        )\n",
"\n",
"        self.val_ds = val_data.map(\n",
"            self.preprocess_function,\n",
"            batched=True,\n",
"            batch_size=self.batch_size,\n",
"            remove_columns=[\"article\", \"highlights\", \"id\"]\n",
"        )\n",
"\n",
"    def preprocess_function(self, batch):\n",
"        inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
"        outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
"        batch[\"input_ids\"] = inputs.input_ids\n",
"        batch[\"attention_mask\"] = inputs.attention_mask\n",
"        batch[\"labels\"] = outputs.input_ids.copy()\n",
"        return batch\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fbb699e1",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
"    def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
"        super().__init__()\n",
"        self.model_name = model_name\n",
"        self.learning_rate = learning_rate\n",
"        self.weight_decay = weight_decay\n",
"        self.batch_size = batch_size\n",
"\n",
"        # Load the pre-trained model and tokenizer\n",
"        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
"\n",
"        # Load the ROUGE metric\n",
"        self.metric = load_metric(\"rouge\")\n",
"\n",
"    def forward(self, input_ids, attention_mask, labels=None):\n",
"        output = self.model(\n",
"            input_ids=input_ids,\n",
"            attention_mask=attention_mask,\n",
"            labels=labels,\n",
"        )\n",
"        return output.loss, output.logits\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
"        return {'loss': loss, 'logits': logits}\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
"        return {'loss': loss, 'logits': logits, \"labels\": labels}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        decoded_preds = []\n",
"        decoded_labels = []\n",
"        for output in outputs:\n",
"            logits = output['logits']\n",
"            labels = output['labels']\n",
"            decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
"            decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"\n",
"        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"\n",
"        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
"        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
"        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
"        return optimizer\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dd63c628",
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (cuda), used: True\n",
"TPU available: False, using: 0 TPU cores\n",
"IPU available: False, using: 0 IPUs\n",
"HPU available: False, using: 0 HPUs\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
"\n",
" 0%| | 0/1795 [00:00<?, ?ba/s]\u001b[A\n",
|
||||
" 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
|
||||
" 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
|
||||
" 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
|
||||
" 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
|
||||
" 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
|
||||
" 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
|
||||
" 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
|
||||
" 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
|
||||
" 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
|
||||
" 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
|
||||
" 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
|
||||
" 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
|
||||
" 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
|
||||
" 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
|
||||
" 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
|
||||
" 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
|
||||
" 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
|
||||
" 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
|
||||
" 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
|
||||
" 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
|
||||
" 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
|
||||
" 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
|
||||
" 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
|
||||
" 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
|
||||
" 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
|
||||
" 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
|
||||
" 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
|
||||
" 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
|
||||
" 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
|
||||
" 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
|
||||
" 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
|
||||
" 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
|
||||
" 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
|
||||
" 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
|
||||
" 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
|
||||
" 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
|
||||
" 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
|
||||
" 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
|
||||
" 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
|
||||
" 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
|
||||
" 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
|
||||
" 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
|
||||
" 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
|
||||
" 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
|
||||
" 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
|
||||
" 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
|
||||
" 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
|
||||
" 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
|
||||
" 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
|
||||
" 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
|
||||
" 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
|
||||
" 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
|
||||
" 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
|
||||
" 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
|
||||
" 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
|
||||
" 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
|
||||
" 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
|
||||
" 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
|
||||
" 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
|
||||
" 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
|
||||
" 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
|
||||
" 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
|
||||
" 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
|
||||
" 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
|
||||
" 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
|
||||
" 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
|
||||
" 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
|
||||
" 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
|
||||
" 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
|
||||
" 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
|
||||
" 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
|
||||
" 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
|
||||
" 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
|
||||
" 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
|
||||
" 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
|
||||
" 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
|
||||
" 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
|
||||
" 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
|
||||
" 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
|
||||
" 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
|
||||
" 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
|
||||
" 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
|
||||
" 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
|
||||
" 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
|
||||
" 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
|
||||
" 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
|
||||
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
|
||||
" 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
|
||||
" 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
|
||||
" 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
|
||||
" 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
|
||||
" 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
|
||||
" 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
|
||||
" 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
|
||||
" 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
|
||||
" 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
|
||||
" 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
|
||||
" 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
|
||||
" 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
|
||||
" 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
|
||||
" 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
|
||||
" 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
|
||||
" 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
|
||||
" 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
|
||||
" 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
|
||||
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
|
||||
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
|
||||
" 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
|
||||
" 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
|
||||
" 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
|
||||
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
|
||||
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
|
||||
" 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
|
||||
" 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
|
||||
" 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
|
||||
" 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
|
||||
" 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
|
||||
" 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
|
||||
" 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
|
||||
" 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
|
||||
" 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
|
||||
" 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
|
||||
" 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
|
||||
" 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
|
||||
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
|
||||
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
|
||||
" 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
|
||||
" 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
|
||||
" 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
|
||||
" 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
|
||||
" 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
|
||||
" 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
|
||||
" 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
|
||||
" 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
|
||||
" 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
|
||||
" 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
|
||||
" 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
|
||||
" 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
|
||||
" 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
|
||||
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
|
||||
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
|
||||
" 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
|
||||
" 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
|
||||
" 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
|
||||
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
|
||||
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
|
||||
" 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
|
||||
" 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
|
||||
" 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
|
||||
" 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
|
||||
" 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
|
||||
" 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
|
||||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
|
||||
"\n",
|
||||
" 0%| | 0/84 [00:00<?, ?ba/s]\u001b[A\n",
|
||||
" 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
|
||||
" 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
|
||||
" 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
|
||||
" 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
|
||||
" 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
|
||||
" 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
|
||||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
|
||||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
|
||||
"\n",
|
||||
" | Name | Type | Params\n",
|
||||
"-----------------------------------------------------\n",
|
||||
"0 | model | T5ForConditionalGeneration | 60.5 M\n",
|
||||
"-----------------------------------------------------\n",
|
||||
"60.5 M Trainable params\n",
|
||||
"0 Non-trainable params\n",
|
||||
"60.5 M Total params\n",
|
||||
"242.026 Total estimated model params size (MB)\n"
|
||||
]
|
||||
},
|
||||
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]"
]
},
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute 'size'",
"output_type": "error",
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.set_float32_matmul_precision(\"medium\")\n",
|
||||
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
|
||||
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
|
||||
"dm = MyDataModule(batch_size=16)\n",
|
||||
"trainer.fit(model, datamodule=dm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "1395d5d2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80a2efab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
File diff suppressed because it is too large
@@ -1,644 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "7d5e92c6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
|
||||
"model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
|
||||
"example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
|
||||
"\n",
|
||||
"ner_entity_results = pipe(example)\n",
|
||||
"print(ner_entity_results)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "bf67ee76",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Turtle s Original Cara mel Chocolate P eca n luster\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ner_entity_results = pipe(example)\n",
|
||||
"\n",
|
||||
"# Initialize the entity words list with an empty string\n",
|
||||
"entity_words = [\"\"]\n",
|
||||
"\n",
|
||||
"# Loop through each dictionary in the list and extract the entity word\n",
|
||||
"for result in ner_entity_results:\n",
|
||||
" if result[\"entity\"] == \"B-FOOD\":\n",
|
||||
" entity_words.append(result[\"word\"])\n",
|
||||
" elif result[\"entity\"] == \"I-FOOD\":\n",
|
||||
" entity_words[-1] += \" \" + result[\"word\"]\n",
|
||||
"\n",
|
||||
"# Remove any remaining ## symbols and extra spaces\n",
|
||||
"entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
|
||||
"\n",
|
||||
"# Join the entity words into a single string\n",
|
||||
"output = \" \".join(entity_words)\n",
|
||||
"\n",
|
||||
"print(output)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc8e5ea0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"print(torch.cuda.is_available())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d8a1e039",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ad73024",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = pipeline(\"zero-shot-classification\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "04f7e02c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier(\n",
|
||||
" \"This is a course about the Transformers library\",\n",
|
||||
" candidate_labels=[\"machine learning\", \"gym\", \"food\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6fb246c2",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"generator = pipeline(task=\"text-generation\", model=\"bigscience/bloom-1b7\", device=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4e174f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"# Define input text and pre-trained model checkpoint\n",
|
||||
"text = \"My name is wolfgang and I live in berlin\"\n",
|
||||
"checkpoint = \"Jean-Baptiste/roberta-large-ner-english\"\n",
|
||||
"\n",
|
||||
"# Instantiate tokenizer and encode input text\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
|
||||
"\n",
|
||||
"# Instantiate model and generate output\n",
|
||||
"model = AutoModel.from_pretrained(checkpoint)\n",
|
||||
"outputs = model(**inputs)\n",
|
||||
"print(outputs[0].shape)\n",
|
||||
"\n",
|
||||
"# Instantiate token classification model and generate predictions\n",
|
||||
"model = AutoModelForTokenClassification.from_pretrained(checkpoint)\n",
|
||||
"outputs = model(**inputs)\n",
|
||||
"predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
|
||||
"print(predictions)\n",
|
||||
"print(model.config.id2label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8212bbaa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
|
||||
"\n",
|
||||
"# prepare input\n",
|
||||
"text = \"Replace me by any text you'd like.\"\n",
|
||||
"encoded_input = tokenizer(text, return_tensors='pt')\n",
|
||||
"\n",
|
||||
"# forward pass\n",
|
||||
"output = model(**encoded_input)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "314cba41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
|
||||
"\n",
|
||||
"# Load the pre-trained tokenizer and model\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
|
||||
"\n",
|
||||
"# Define the input sentence with a masked token\n",
|
||||
"text = \"I want to <mask> a new car tomorrow.\"\n",
|
||||
"\n",
|
||||
"# Tokenize the input sentence, replacing the masked token with a special [MASK] token\n",
|
||||
"encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')\n",
|
||||
"\n",
|
||||
"print(output.logits.shape)\n",
|
||||
"print(encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id))\n",
|
||||
"\n",
|
||||
"# Extract the predicted probabilities for the masked token\n",
|
||||
"predicted_probabilities = output.logits[0, encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id)]\n",
|
||||
"predicted_probabilities = torch.nn.functional.softmax(predicted_probabilities, dim=-1)\n",
|
||||
"\n",
|
||||
"# Get the top-k most probable predictions for the masked token\n",
|
||||
"k = 5\n",
|
||||
"top_k = torch.topk(predicted_probabilities, k)\n",
|
||||
"for i in range(k):\n",
|
||||
" token = tokenizer.convert_ids_to_tokens(top_k.indices[i].item())\n",
|
||||
" score = top_k.values[i].item()\n",
|
||||
" print(f\"Prediction {i+1}: '{token}' with probability {score:.5f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6187e77e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
|
||||
"\n",
|
||||
"sequences = [\n",
|
||||
" \"Using a Transformer network is simple\",\n",
|
||||
" \"The quick brown fox jumps over the lazy dog\",\n",
|
||||
" \"To be or not to be, that is the question\"\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Tokenize the input sequences and convert them to padded and truncated integer token IDs\n",
|
||||
"inputs = tokenizer(\n",
|
||||
" sequences,\n",
|
||||
" padding=True,\n",
|
||||
" truncation=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Print the resulting input IDs and attention masks\n",
|
||||
"print(inputs['input_ids'])\n",
|
||||
"print(inputs['attention_mask'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc259c5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43466db6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Huggingface:\n",
|
||||
"\n",
|
||||
"1. Understanding how to use the Pipeline (probably most useful) for various tasks, easy to use, and the different subtasks it can do like translation, QA, zero shot, sentiment analysis, token classification, etc. \n",
|
||||
"2. Understood how pipeline works in more detail by using AutoModel for various tasks as well as AutoTokenizer\n",
|
||||
"3. Load dataset\n",
|
||||
"4. How to finetune\n",
|
||||
"5. How to evaluate\n",
|
||||
"6. "
|
||||
]
|
||||
},
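{
"cell_type": "markdown",
"id": "pipeline-extra-md",
"metadata": {},
"source": [
"A minimal sketch (added note, not from the original notebook) of two of the pipeline subtasks listed above, sentiment analysis and question answering, using whatever default checkpoints the pipeline picks when no model name is given:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "pipeline-extra-code",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"# Sentiment analysis: one label (e.g. POSITIVE/NEGATIVE) and a score per input\n",
"sentiment = pipeline(\"sentiment-analysis\")\n",
"print(sentiment(\"I've been waiting for a HuggingFace course my whole life.\"))\n",
"\n",
"# Question answering: extract the answer span from the provided context\n",
"qa = pipeline(\"question-answering\")\n",
"print(qa(question=\"What does the pipeline wrap?\",\n",
"         context=\"The pipeline wraps a tokenizer and a model behind a single call.\"))"
]
},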
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97c474f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ed5d8c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification\n",
|
||||
"\n",
|
||||
"# Same as before\n",
|
||||
"checkpoint = \"bert-base-uncased\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
|
||||
"sequences = [\n",
|
||||
" \"I've been waiting for a HuggingFace course my whole life.\",\n",
|
||||
" \"This course is amazing!\",\n",
|
||||
"]\n",
|
||||
"batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
|
||||
"\n",
|
||||
"# This is new\n",
|
||||
"batch[\"labels\"] = torch.tensor([1, 1])\n",
|
||||
"\n",
|
||||
"optimizer = AdamW(model.parameters())\n",
|
||||
"loss = model(**batch).loss\n",
|
||||
"loss.backward()\n",
|
||||
"optimizer.step()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c598624f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"raw_datasets = load_dataset(\"glue\", \"mrpc\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cd296227",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"raw_train_dataset = raw_datasets[\"train\"]\n",
|
||||
"raw_train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e462947a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
|
||||
"raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
|
||||
"\n",
|
||||
"checkpoint = \"bert-base-uncased\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"\n",
|
||||
"def tokenize_function(example):\n",
|
||||
" return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
|
||||
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from transformers import TrainingArguments\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\")\n",
|
||||
"\n",
|
||||
"from transformers import AutoModelForSequenceClassification\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import evaluate\n",
|
||||
"\n",
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
||||
" logits, labels = eval_preds\n",
|
||||
" predictions = np.argmax(logits, axis=-1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)\n",
|
||||
"\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e2795dc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import TrainingArguments\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3af29cd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSequenceClassification\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "817f644e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import evaluate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "42819a6c",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
||||
" logits, labels = eval_preds\n",
|
||||
" predictions = np.argmax(logits, axis=-1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)\n",
|
||||
"\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eb5986b0",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"batch_size=32\n",
|
||||
"\n",
|
||||
"# Define the generator function to preprocess the data in batches\n",
|
||||
"def preprocess_generator(examples):\n",
|
||||
" for i in range(0, len(examples[\"article\"]), batch_size):\n",
|
||||
" batch = examples[\"article\"][i:i+batch_size]\n",
|
||||
" targets = examples[\"highlights\"][i:i+batch_size]\n",
|
||||
" model_inputs = tokenizer(batch, max_length=512, padding=\"max_length\", truncation=True)\n",
|
||||
" with tokenizer.as_target_tokenizer():\n",
|
||||
" model_targets = tokenizer(targets, max_length=128, padding=\"max_length\", truncation=True)\n",
|
||||
" model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
|
||||
" yield model_inputs\n",
|
||||
"\n",
|
||||
"def preprocess_function(examples):\n",
|
||||
" articles = [ex for ex in examples[\"article\"]]\n",
|
||||
" summaries = [ex for ex in examples[\"highlights\"]]\n",
|
||||
"\n",
|
||||
" model_inputs = tokenizer(articles, max_length=512, padding=\"max_length\", truncation=True)\n",
|
||||
" with tokenizer.as_target_tokenizer():\n",
|
||||
" model_targets = tokenizer(summaries, max_length=128, padding=\"max_length\", truncation=True)\n",
|
||||
" \n",
|
||||
" model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
|
||||
" return model_inputs\n",
|
||||
" \n",
|
||||
"# Load the dataset\n",
|
||||
"raw_datasets = load_dataset(\"cnn_dailymail\", \"3.0.0\")\n",
|
||||
"preprocessed_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4)\n",
|
||||
"\n",
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" evaluation_strategy = \"epoch\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=batch_size,\n",
|
||||
" max_steps=1000,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" push_to_hub=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_ds,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7d62583e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_metric"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d310a7b3",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preprocessed_datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "99d422cc",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=batch_size,\n",
|
||||
" max_steps=5000,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" push_to_hub=False,\n",
|
||||
" evaluation_strategy = \"steps\",\n",
|
||||
" eval_steps = 50,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load the ROUGE metric\n",
|
||||
"metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
"# Define the evaluation function\n",
|
||||
"def compute_metrics(pred):\n",
|
||||
" labels = pred.label_ids\n",
|
||||
" preds = pred.predictions\n",
|
||||
" \n",
|
||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
" \n",
|
||||
" scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
" \n",
|
||||
" return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=preprocessed_datasets[\"train\"],\n",
|
||||
" eval_dataset=preprocessed_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a5e97b57",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install nltk\n",
|
||||
"!pip install rouge_score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "558c3e66",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Goal:\n",
|
||||
"\n",
|
||||
"1. Implement full training from dataloading (dailycnn dataset), to model training, evaluation, etc, using HF. \n",
|
||||
"* Right now: stuck on on the fly dataset loading, we don't want to cache because this would take a lot of disk space etc.\n",
|
||||
"\n",
|
||||
"2. After we get step 1) working, we want to go deeper on every step, so download the dataset and load it as a custom dataset rather than using huggingface simple API, in order to make it more general. Compare with loading the ds as a custom HF dataset or using pytorch class together with lightning. Speed difference? Convenience? Also we want to use the lightning Trainer so see how we can integrate that. And then compare HF to the lightning + hf model approach and see what we like the most."
|
||||
]
|
||||
},
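{
"cell_type": "markdown",
"id": "streaming-sketch-md",
"metadata": {},
"source": [
"A rough sketch (added here, not part of the original notebook) of on-the-fly loading with the `datasets` streaming mode, which avoids caching a tokenized copy of the dataset on disk. Column names and max lengths mirror the cells above; treat the exact `with_format`/DataLoader combination as an assumption to verify against the installed `datasets` version:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "streaming-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"t5-small\")\n",
"\n",
"# streaming=True yields examples lazily instead of materializing the dataset on disk\n",
"streamed = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\", streaming=True)\n",
"\n",
"def tokenize_on_the_fly(example):\n",
"    model_inputs = tokenizer(example[\"article\"], max_length=512, padding=\"max_length\", truncation=True)\n",
"    targets = tokenizer(example[\"highlights\"], max_length=128, padding=\"max_length\", truncation=True)\n",
"    model_inputs[\"labels\"] = targets[\"input_ids\"]\n",
"    return model_inputs\n",
"\n",
"streamed = streamed.map(tokenize_on_the_fly, remove_columns=[\"article\", \"highlights\", \"id\"])\n",
"streamed = streamed.with_format(\"torch\")\n",
"\n",
"loader = torch.utils.data.DataLoader(streamed, batch_size=4)\n",
"batch = next(iter(loader))\n",
"print(batch[\"input_ids\"].shape)"
]
},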
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "624d49ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,317 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f54ecf0b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\"\"\"\n",
|
||||
"# HuggingFace Tutorial Series\n",
|
||||
"- 1. What is Huggingface?\n",
|
||||
"- 2. Common tasks we can do with HuggingFace & explain the tasks briefly, like what is question answering etc\n",
|
||||
"- 3. Using the HuggingFace Pipeline (High level feature)\n",
|
||||
"- 4. How the pipeline works at a lower level\n",
|
||||
"- 5. HuggingFace Datasets\n",
|
||||
"- 6. HuggingFace Tokenizer\n",
|
||||
"- 7. HuggingFace Evaluate\n",
|
||||
"- 8. HuggingFace Trainer\n",
|
||||
"- 9. Putting it together to finetune a news article summarizer\n",
|
||||
"- 10. Making it more general and robust with Lightning and custom data loading\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec1aae37",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"import datasets \n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"\n",
|
||||
"from transformers import (\n",
|
||||
" AutoModel,\n",
|
||||
" AutoModelForSeq2SeqLM,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" DataCollatorForSeq2Seq,\n",
|
||||
" Seq2SeqTrainingArguments,\n",
|
||||
" Seq2SeqTrainer,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import pandas as pd\n",
|
||||
"from torch.utils.data import Dataset\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"\n",
|
||||
"torch.set_float32_matmul_precision(\"medium\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5fd7cb0c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "418cb03a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class cnn_dailymail(Dataset):\n",
|
||||
" def __init__(self, csv_file, tokenizer, max_length=512):\n",
|
||||
" self.data = pd.read_csv(csv_file)\n",
|
||||
" self.tokenizer = tokenizer\n",
|
||||
" self.max_length = max_length\n",
|
||||
"\n",
|
||||
" def __len__(self):\n",
|
||||
" return len(self.data)\n",
|
||||
"\n",
|
||||
" def __getitem__(self, idx):\n",
|
||||
" article = self.data.loc[idx, 'article']\n",
|
||||
" highlights = self.data.loc[idx, 'highlights']\n",
|
||||
"\n",
|
||||
" inputs = self.tokenizer(\n",
|
||||
" article,\n",
|
||||
" truncation=True,\n",
|
||||
" padding='max_length',\n",
|
||||
" max_length=self.max_length,\n",
|
||||
" return_tensors='pt'\n",
|
||||
" )\n",
|
||||
" targets = self.tokenizer(\n",
|
||||
" highlights,\n",
|
||||
" truncation=True,\n",
|
||||
" padding='max_length',\n",
|
||||
" max_length=self.max_length,\n",
|
||||
" return_tensors='pt'\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return {\n",
|
||||
" 'input_ids': inputs['input_ids'].squeeze(),\n",
|
||||
" 'attention_mask': inputs['attention_mask'].squeeze(),\n",
|
||||
" 'labels': targets['input_ids'].squeeze()\n",
|
||||
" }"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "aaa62755",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MyDataModule(pl.LightningDataModule):\n",
|
||||
" def __init__(self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512):\n",
|
||||
" super().__init__()\n",
|
||||
" self.train_csv = train_csv\n",
|
||||
" self.val_csv = val_csv\n",
|
||||
" self.test_csv = test_csv\n",
|
||||
" self.tokenizer = tokenizer\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.max_length = max_length\n",
|
||||
"\n",
|
||||
" def setup(self, stage=None):\n",
|
||||
" if stage in ('fit', None):\n",
|
||||
" self.train_dataset = cnn_dailymail(self.train_csv, self.tokenizer, self.max_length)\n",
|
||||
" self.val_dataset = cnn_dailymail(self.val_csv, self.tokenizer, self.max_length)\n",
|
||||
" if stage in ('test', None):\n",
|
||||
" self.test_dataset = cnn_dailymail(self.test_csv, self.tokenizer, self.max_length)\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n",
|
||||
"\n",
|
||||
" def test_dataloader(self):\n",
|
||||
" return torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fbb699e1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MyLightningModule(pl.LightningModule):\n",
|
||||
" def __init__(self, model_name, learning_rate, weight_decay):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model_name = model_name\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
" self.weight_decay = weight_decay\n",
|
||||
" \n",
|
||||
" # Load the pre-trained model and tokenizer\n",
|
||||
" self.model = torch.compile(AutoModelForSeq2SeqLM.from_pretrained(self.model_name))\n",
|
||||
" \n",
|
||||
" # Load the ROUGE metric\n",
|
||||
" self.metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, attention_mask, labels=None):\n",
|
||||
" output = self.model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" labels=labels,\n",
|
||||
" )\n",
|
||||
" return output.loss, output.logits\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" loss, logits = self(input_ids, attention_mask, labels)\n",
|
||||
" self.log('train_loss', loss, on_epoch=True, on_step=True, prog_bar=True)\n",
|
||||
" return {'loss': loss, 'logits': logits}\n",
|
||||
" \n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" loss, logits = self(input_ids, attention_mask, labels)\n",
|
||||
" self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
|
||||
" \n",
|
||||
" # Save logits and labels as instance attributes\n",
|
||||
" if not hasattr(self, \"logits\"):\n",
|
||||
" self.logits = logits\n",
|
||||
" else:\n",
|
||||
" self.logits = torch.cat((self.logits, logits), dim=0)\n",
|
||||
" \n",
|
||||
" if not hasattr(self, \"labels\"):\n",
|
||||
" self.labels = labels\n",
|
||||
" else:\n",
|
||||
" self.labels = torch.cat((self.labels, labels), dim=0)\n",
|
||||
" \n",
|
||||
" return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
|
||||
" \n",
|
||||
" def on_validation_epoch_end(self):\n",
|
||||
" # Convert logits to predicted token IDs\n",
|
||||
" pred_token_ids = self.logits.argmax(dim=-1)\n",
|
||||
"\n",
|
||||
" # Decode predictions and labels using the saved instance attributes\n",
|
||||
" decoded_preds = tokenizer.batch_decode(pred_token_ids, skip_special_tokens=True)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(self.labels, skip_special_tokens=True)\n",
|
||||
"\n",
|
||||
" # Compute ROUGE scores\n",
|
||||
" scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
"\n",
|
||||
" self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
|
||||
" self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
|
||||
" self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
|
||||
"\n",
|
||||
" # Clear logits and labels instance attributes for the next validation epoch\n",
|
||||
" del self.logits\n",
|
||||
" del self.labels\n",
|
||||
" \n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
|
||||
" return optimizer\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "dd63c628",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# File paths\n",
|
||||
"train_csv = \"train.csv\"\n",
|
||||
"val_csv = \"validation.csv\"\n",
|
||||
"test_csv = \"test.csv\"\n",
|
||||
"\n",
|
||||
"# Create the data module\n",
|
||||
"dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=16)\n",
|
||||
"dm.setup()\n",
|
||||
"\n",
|
||||
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-4, weight_decay=1e-5)\n",
|
||||
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=1, precision=16)\n",
|
||||
"trainer.fit(model, datamodule=dm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b5d3d684",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"http://localhost:18888/notebooks/cnndaily_t5_lightning_customdataloading.ipynb"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a0494596",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### next steps:\n",
|
||||
"* if article is > 512, because now we are truncating maybe it causes issues if the article is much longer?\n",
|
||||
"\n",
|
||||
"#### what we've done:\n",
|
||||
"* Change the data loading so it's more general, meaning on the fly loading from disk\n",
|
||||
"* add torch.compile\n",
|
||||
"* 1. Clean up the code, make it into scripts instead of notebook -> Train for an epoch (add multi-gpu training?)\n",
|
||||
"* add tensorboard visualization\n",
|
||||
"* not use pretrained weights but from scratch to ensure that training setup works and actually improving\n",
|
||||
"* 2. Create an inference step, send in news article -> get summary, check that it works\n"
|
||||
]
|
||||
},
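{
"cell_type": "markdown",
"id": "inference-sketch-md",
"metadata": {},
"source": [
"A minimal inference sketch for step 2 above (added note, not from the original notebook): load a seq2seq checkpoint, feed in one article, and generate a summary. `t5-small` is used as a stand-in for the finetuned weights, and the `summarize:` prefix follows the usual T5 convention:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "inference-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"\n",
"# Stand-in checkpoint; in practice point this at the weights saved after finetuning\n",
"tokenizer = AutoTokenizer.from_pretrained(\"t5-small\")\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")\n",
"\n",
"article = \"...\"  # paste a news article here\n",
"\n",
"inputs = tokenizer(\"summarize: \" + article, truncation=True, max_length=512, return_tensors=\"pt\")\n",
"summary_ids = model.generate(\n",
"    inputs[\"input_ids\"],\n",
"    attention_mask=inputs[\"attention_mask\"],\n",
"    max_length=128,\n",
"    num_beams=4,\n",
")\n",
"print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))"
]
},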
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80a2efab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0f9b71ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
import pandas as pd
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import Dataset
|
||||
import torch
|
||||
|
||||
|
||||
class cnn_dailymail(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=512):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
|
||||
# if the csv_file is "train.csv" then only take out 10% of the data. make sure to reset indices etc
|
||||
#if csv_file == "train.csv":
|
||||
# self.data = self.data.sample(frac=0.05, random_state=42).reset_index(drop=True)
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
article = self.data.loc[idx, "article"]
|
||||
highlights = self.data.loc[idx, "highlights"]
|
||||
|
||||
inputs = self.tokenizer(
|
||||
article,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
targets = self.tokenizer(
|
||||
highlights,
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
max_length=self.max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
return {
|
||||
"input_ids": inputs["input_ids"].squeeze(),
|
||||
"attention_mask": inputs["attention_mask"].squeeze(),
|
||||
"labels": targets["input_ids"].squeeze(),
|
||||
}
|
||||
|
||||
|
||||
class MyDataModule(pl.LightningDataModule):
|
||||
def __init__(
|
||||
self, train_csv, val_csv, test_csv, tokenizer, batch_size=16, max_length=512
|
||||
):
|
||||
super().__init__()
|
||||
self.train_csv = train_csv
|
||||
self.val_csv = val_csv
|
||||
self.test_csv = test_csv
|
||||
self.tokenizer = tokenizer
|
||||
self.batch_size = batch_size
|
||||
self.max_length = max_length
|
||||
|
||||
def setup(self, stage=None):
|
||||
if stage in ("fit", None):
|
||||
self.train_dataset = cnn_dailymail(
|
||||
self.train_csv, self.tokenizer, self.max_length
|
||||
)
|
||||
self.val_dataset = cnn_dailymail(
|
||||
self.val_csv, self.tokenizer, self.max_length
|
||||
)
|
||||
if stage in ("test", None):
|
||||
self.test_dataset = cnn_dailymail(
|
||||
self.test_csv, self.tokenizer, self.max_length
|
||||
)
|
||||
|
||||
def train_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.batch_size,
|
||||
pin_memory=True,
|
||||
shuffle=True,
|
||||
num_workers=6,
|
||||
)
|
||||
|
||||
def val_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=2
|
||||
)
|
||||
|
||||
def test_dataloader(self):
|
||||
return torch.utils.data.DataLoader(
|
||||
self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1
|
||||
)
|
||||
@@ -1,470 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "ec1aae37",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2023-02-21 16:36:20.707209: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
|
||||
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
||||
"2023-02-21 16:36:21.233575: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
|
||||
"2023-02-21 16:36:21.233623: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
|
||||
"2023-02-21 16:36:21.233628: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import warnings\n",
|
||||
"warnings.simplefilter(\"ignore\")\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import datasets \n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"\n",
|
||||
"from transformers import (\n",
|
||||
" AutoModel,\n",
|
||||
" AutoModelForSeq2SeqLM,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" DataCollatorForSeq2Seq,\n",
|
||||
" Seq2SeqTrainingArguments,\n",
|
||||
" Seq2SeqTrainer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "5fd7cb0c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "04530b1e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Define the LightningDataModule\n",
|
||||
"class MyDataModule(pl.LightningDataModule):\n",
|
||||
" def __init__(self, batch_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" \n",
|
||||
" def prepare_data(self):\n",
|
||||
" # Download and preprocess the data\n",
|
||||
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
|
||||
" load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
|
||||
" \n",
|
||||
" def setup(self, stage=None):\n",
|
||||
" # Load and preprocess the data\n",
|
||||
" train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train[:10%]\")\n",
|
||||
" val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
|
||||
"\n",
|
||||
" self.train_ds = train_data.map(\n",
|
||||
" self.preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=self.batch_size, \n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.val_ds = val_data.map(\n",
|
||||
" self.preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=self.batch_size,\n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" def preprocess_function(self, batch):\n",
|
||||
" inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
|
||||
" outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
|
||||
" batch[\"input_ids\"] = inputs.input_ids\n",
|
||||
" batch[\"attention_mask\"] = inputs.attention_mask\n",
|
||||
" batch[\"labels\"] = outputs.input_ids.copy()\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
" def train_dataloader(self):\n",
|
||||
" return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
|
||||
"\n",
|
||||
" def val_dataloader(self):\n",
|
||||
" return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "fbb699e1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MyLightningModule(pl.LightningModule):\n",
|
||||
" def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model_name = model_name\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
" self.weight_decay = weight_decay\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" \n",
|
||||
" # Load the pre-trained model and tokenizer\n",
|
||||
" self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
|
||||
"\n",
|
||||
" # Load the ROUGE metric\n",
|
||||
" self.metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, attention_mask, labels=None):\n",
|
||||
" output = self.model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" labels=labels,\n",
|
||||
" )\n",
|
||||
" return output.loss, output.logits\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" loss, logits = self(input_ids, attention_mask, labels)\n",
|
||||
" self.log('train_loss', loss, on_epoch=True, on_step=False)\n",
|
||||
" return {'loss': loss, 'logits': logits}\n",
|
||||
" \n",
|
||||
" def validation_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" loss, logits = self(input_ids, attention_mask, labels)\n",
|
||||
" self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
|
||||
" return {'loss': loss, 'logits': logits, \"labels\":labels}\n",
|
||||
" \n",
|
||||
" def validation_epoch_end(self, outputs):\n",
|
||||
" decoded_preds = []\n",
|
||||
" decoded_labels = []\n",
|
||||
" for output in outputs:\n",
|
||||
" logits = output['logits']\n",
|
||||
" labels = output['labels']\n",
|
||||
" decoded_preds += self.tokenizer.batch_decode(logits, skip_special_tokens=True)\n",
|
||||
" decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
" \n",
|
||||
" scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
" \n",
|
||||
" self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
|
||||
" self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
|
||||
" self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
|
||||
" \n",
|
||||
" def configure_optimizers(self):\n",
|
||||
" optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
|
||||
" return optimizer\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "dd63c628",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GPU available: True (cuda), used: True\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"IPU available: False, using: 0 IPUs\n",
|
||||
"HPU available: False, using: 0 HPUs\n",
|
||||
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
|
||||
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
|
||||
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
|
||||
"Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n",
|
||||
"\n",
|
||||
" 0%| | 0/1795 [00:00<?, ?ba/s]\u001b[A\n",
|
||||
" 1%|▉ | 13/1795 [00:00<00:14, 121.44ba/s]\u001b[A\n",
|
||||
" 1%|█▉ | 26/1795 [00:00<00:15, 117.31ba/s]\u001b[A\n",
|
||||
" 2%|██▊ | 38/1795 [00:00<00:15, 114.50ba/s]\u001b[A\n",
|
||||
" 3%|███▋ | 50/1795 [00:00<00:15, 114.43ba/s]\u001b[A\n",
|
||||
" 3%|████▌ | 62/1795 [00:00<00:15, 115.53ba/s]\u001b[A\n",
|
||||
" 4%|█████▍ | 74/1795 [00:00<00:15, 113.50ba/s]\u001b[A\n",
|
||||
" 5%|██████▎ | 86/1795 [00:00<00:15, 111.92ba/s]\u001b[A\n",
|
||||
" 5%|███████▎ | 98/1795 [00:00<00:15, 111.38ba/s]\u001b[A\n",
|
||||
" 6%|████████ | 110/1795 [00:00<00:15, 112.08ba/s]\u001b[A\n",
|
||||
" 7%|████████▉ | 122/1795 [00:01<00:14, 113.73ba/s]\u001b[A\n",
|
||||
" 7%|█████████▊ | 134/1795 [00:01<00:14, 113.43ba/s]\u001b[A\n",
|
||||
" 8%|██████████▋ | 146/1795 [00:01<00:14, 111.37ba/s]\u001b[A\n",
|
||||
" 9%|███████████▌ | 158/1795 [00:01<00:14, 111.32ba/s]\u001b[A\n",
|
||||
" 9%|████████████▌ | 170/1795 [00:01<00:14, 110.29ba/s]\u001b[A\n",
|
||||
" 10%|█████████████▍ | 182/1795 [00:01<00:14, 110.06ba/s]\u001b[A\n",
|
||||
" 11%|██████████████▎ | 194/1795 [00:01<00:14, 111.06ba/s]\u001b[A\n",
|
||||
" 11%|███████████████▏ | 206/1795 [00:01<00:14, 111.15ba/s]\u001b[A\n",
|
||||
" 12%|████████████████ | 218/1795 [00:01<00:14, 110.27ba/s]\u001b[A\n",
|
||||
" 13%|████████████████▉ | 230/1795 [00:02<00:14, 109.17ba/s]\u001b[A\n",
|
||||
" 13%|█████████████████▋ | 241/1795 [00:02<00:14, 107.81ba/s]\u001b[A\n",
|
||||
" 14%|██████████████████▌ | 252/1795 [00:02<00:14, 107.84ba/s]\u001b[A\n",
|
||||
" 15%|███████████████████▎ | 263/1795 [00:02<00:14, 107.73ba/s]\u001b[A\n",
|
||||
" 15%|████████████████████▏ | 274/1795 [00:02<00:14, 107.06ba/s]\u001b[A\n",
|
||||
" 16%|█████████████████████ | 286/1795 [00:02<00:13, 108.37ba/s]\u001b[A\n",
|
||||
" 17%|█████████████████████▊ | 297/1795 [00:02<00:13, 107.89ba/s]\u001b[A\n",
|
||||
" 17%|██████████████████████▋ | 309/1795 [00:02<00:13, 108.63ba/s]\u001b[A\n",
|
||||
" 18%|███████████████████████▌ | 320/1795 [00:02<00:13, 106.85ba/s]\u001b[A\n",
|
||||
" 18%|████████████████████████▎ | 331/1795 [00:03<00:13, 105.16ba/s]\u001b[A\n",
|
||||
" 19%|█████████████████████████▏ | 342/1795 [00:03<00:13, 105.20ba/s]\u001b[A\n",
|
||||
" 20%|█████████████████████████▉ | 353/1795 [00:03<00:13, 106.52ba/s]\u001b[A\n",
|
||||
" 20%|██████████████████████████▊ | 364/1795 [00:03<00:13, 106.07ba/s]\u001b[A\n",
|
||||
" 21%|███████████████████████████▌ | 375/1795 [00:03<00:13, 106.21ba/s]\u001b[A\n",
|
||||
" 22%|████████████████████████████▍ | 386/1795 [00:03<00:13, 106.57ba/s]\u001b[A\n",
|
||||
" 22%|█████████████████████████████▎ | 398/1795 [00:03<00:12, 108.52ba/s]\u001b[A\n",
|
||||
" 23%|██████████████████████████████ | 409/1795 [00:03<00:12, 108.42ba/s]\u001b[A\n",
|
||||
" 23%|██████████████████████████████▉ | 421/1795 [00:03<00:12, 110.30ba/s]\u001b[A\n",
|
||||
" 24%|███████████████████████████████▊ | 433/1795 [00:03<00:12, 108.73ba/s]\u001b[A\n",
|
||||
" 25%|████████████████████████████████▋ | 444/1795 [00:04<00:12, 106.43ba/s]\u001b[A\n",
|
||||
" 25%|█████████████████████████████████▍ | 455/1795 [00:04<00:12, 106.82ba/s]\u001b[A\n",
|
||||
" 26%|██████████████████████████████████▎ | 466/1795 [00:04<00:12, 105.85ba/s]\u001b[A\n",
|
||||
" 27%|███████████████████████████████████ | 477/1795 [00:04<00:12, 107.02ba/s]\u001b[A\n",
|
||||
" 27%|███████████████████████████████████▉ | 488/1795 [00:04<00:12, 106.66ba/s]\u001b[A\n",
|
||||
" 28%|████████████████████████████████████▊ | 500/1795 [00:04<00:11, 108.59ba/s]\u001b[A\n",
|
||||
" 28%|█████████████████████████████████████▌ | 511/1795 [00:04<00:12, 106.49ba/s]\u001b[A\n",
|
||||
" 29%|██████████████████████████████████████▍ | 523/1795 [00:04<00:11, 109.26ba/s]\u001b[A\n",
|
||||
" 30%|███████████████████████████████████████▎ | 535/1795 [00:04<00:11, 109.78ba/s]\u001b[A\n",
|
||||
" 30%|████████████████████████████████████████▏ | 546/1795 [00:04<00:11, 108.30ba/s]\u001b[A\n",
|
||||
" 31%|████████████████████████████████████████▉ | 557/1795 [00:05<00:11, 107.77ba/s]\u001b[A\n",
|
||||
" 32%|█████████████████████████████████████████▊ | 569/1795 [00:05<00:11, 108.36ba/s]\u001b[A\n",
|
||||
" 32%|██████████████████████████████████████████▋ | 580/1795 [00:05<00:11, 107.05ba/s]\u001b[A\n",
|
||||
" 33%|███████████████████████████████████████████▌ | 592/1795 [00:05<00:11, 108.48ba/s]\u001b[A\n",
|
||||
" 34%|████████████████████████████████████████████▎ | 603/1795 [00:05<00:11, 108.25ba/s]\u001b[A\n",
|
||||
" 34%|█████████████████████████████████████████████▏ | 615/1795 [00:05<00:10, 110.59ba/s]\u001b[A\n",
|
||||
" 35%|██████████████████████████████████████████████ | 627/1795 [00:05<00:10, 111.44ba/s]\u001b[A\n",
|
||||
" 36%|██████████████████████████████████████████████▉ | 639/1795 [00:05<00:10, 109.07ba/s]\u001b[A\n",
|
||||
" 36%|███████████████████████████████████████████████▊ | 651/1795 [00:05<00:10, 109.77ba/s]\u001b[A\n",
|
||||
" 37%|████████████████████████████████████████████████▋ | 662/1795 [00:06<00:10, 109.69ba/s]\u001b[A\n",
|
||||
" 37%|█████████████████████████████████████████████████▍ | 673/1795 [00:06<00:10, 109.08ba/s]\u001b[A\n",
|
||||
" 38%|██████████████████████████████████████████████████▎ | 685/1795 [00:06<00:10, 109.77ba/s]\u001b[A\n",
|
||||
" 39%|███████████████████████████████████████████████████▎ | 697/1795 [00:06<00:10, 109.54ba/s]\u001b[A\n",
|
||||
" 39%|████████████████████████████████████████████████████ | 708/1795 [00:06<00:09, 109.08ba/s]\u001b[A\n",
|
||||
" 40%|████████████████████████████████████████████████████▉ | 720/1795 [00:06<00:09, 110.53ba/s]\u001b[A\n",
|
||||
" 41%|█████████████████████████████████████████████████████▊ | 732/1795 [00:06<00:09, 108.30ba/s]\u001b[A\n",
|
||||
" 41%|██████████████████████████████████████████████████████▋ | 744/1795 [00:06<00:09, 110.04ba/s]\u001b[A\n",
|
||||
" 42%|███████████████████████████████████████████████████████▌ | 756/1795 [00:06<00:09, 112.10ba/s]\u001b[A\n",
|
||||
" 43%|████████████████████████████████████████████████████████▍ | 768/1795 [00:07<00:09, 111.21ba/s]\u001b[A\n",
|
||||
" 43%|█████████████████████████████████████████████████████████▎ | 780/1795 [00:07<00:09, 111.99ba/s]\u001b[A\n",
|
||||
" 44%|██████████████████████████████████████████████████████████▏ | 792/1795 [00:07<00:08, 112.21ba/s]\u001b[A\n",
|
||||
" 45%|███████████████████████████████████████████████████████████ | 804/1795 [00:07<00:09, 109.31ba/s]\u001b[A\n",
|
||||
" 46%|████████████████████████████████████████████████████████████ | 817/1795 [00:07<00:08, 113.17ba/s]\u001b[A\n",
|
||||
" 46%|████████████████████████████████████████████████████████████▉ | 829/1795 [00:07<00:08, 113.26ba/s]\u001b[A\n",
|
||||
" 47%|█████████████████████████████████████████████████████████████▊ | 841/1795 [00:07<00:08, 113.69ba/s]\u001b[A\n",
|
||||
" 48%|██████████████████████████████████████████████████████████████▋ | 853/1795 [00:07<00:08, 114.08ba/s]\u001b[A\n",
|
||||
" 48%|███████████████████████████████████████████████████████████████▌ | 865/1795 [00:07<00:08, 112.82ba/s]\u001b[A\n",
|
||||
" 49%|████████████████████████████████████████████████████████████████▍ | 877/1795 [00:07<00:08, 113.22ba/s]\u001b[A\n",
|
||||
" 50%|█████████████████████████████████████████████████████████████████▍ | 890/1795 [00:08<00:07, 115.71ba/s]\u001b[A\n",
|
||||
" 50%|██████████████████████████████████████████████████████████████████▎ | 902/1795 [00:08<00:07, 115.77ba/s]\u001b[A\n",
|
||||
" 51%|███████████████████████████████████████████████████████████████████▏ | 914/1795 [00:08<00:07, 114.07ba/s]\u001b[A\n",
|
||||
" 52%|████████████████████████████████████████████████████████████████████ | 926/1795 [00:08<00:07, 114.19ba/s]\u001b[A\n",
|
||||
" 52%|████████████████████████████████████████████████████████████████████▉ | 938/1795 [00:08<00:07, 115.57ba/s]\u001b[A\n",
|
||||
" 53%|█████████████████████████████████████████████████████████████████████▊ | 950/1795 [00:08<00:07, 115.94ba/s]\u001b[A\n",
|
||||
" 54%|██████████████████████████████████████████████████████████████████████▋ | 962/1795 [00:08<00:07, 116.65ba/s]\u001b[A\n",
|
||||
" 54%|███████████████████████████████████████████████████████████████████████▋ | 974/1795 [00:08<00:07, 113.94ba/s]\u001b[A\n",
|
||||
" 55%|████████████████████████████████████████████████████████████████████████▌ | 986/1795 [00:08<00:07, 111.71ba/s]\u001b[A\n",
|
||||
" 56%|█████████████████████████████████████████████████████████████████████████▍ | 998/1795 [00:09<00:07, 107.78ba/s]\u001b[A\n",
|
||||
" 56%|█████████████████████████████████████████████████████████████████████████▋ | 1009/1795 [00:09<00:07, 105.28ba/s]\u001b[A\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" 57%|██████████████████████████████████████████████████████████████████████████▌ | 1021/1795 [00:09<00:07, 107.16ba/s]\u001b[A\n",
|
||||
" 57%|███████████████████████████████████████████████████████████████████████████▎ | 1032/1795 [00:09<00:07, 107.83ba/s]\u001b[A\n",
|
||||
" 58%|████████████████████████████████████████████████████████████████████████████▏ | 1044/1795 [00:09<00:06, 109.92ba/s]\u001b[A\n",
|
||||
" 59%|█████████████████████████████████████████████████████████████████████████████ | 1056/1795 [00:09<00:06, 112.47ba/s]\u001b[A\n",
|
||||
" 59%|█████████████████████████████████████████████████████████████████████████████▉ | 1068/1795 [00:09<00:06, 113.56ba/s]\u001b[A\n",
|
||||
" 60%|██████████████████████████████████████████████████████████████████████████████▊ | 1080/1795 [00:09<00:06, 111.84ba/s]\u001b[A\n",
|
||||
" 61%|███████████████████████████████████████████████████████████████████████████████▋ | 1092/1795 [00:09<00:06, 111.27ba/s]\u001b[A\n",
|
||||
" 62%|████████████████████████████████████████████████████████████████████████████████▌ | 1104/1795 [00:10<00:06, 110.39ba/s]\u001b[A\n",
|
||||
" 62%|█████████████████████████████████████████████████████████████████████████████████▍ | 1116/1795 [00:10<00:06, 111.33ba/s]\u001b[A\n",
|
||||
" 63%|██████████████████████████████████████████████████████████████████████████████████▎ | 1128/1795 [00:10<00:05, 111.32ba/s]\u001b[A\n",
|
||||
" 64%|███████████████████████████████████████████████████████████████████████████████████▏ | 1140/1795 [00:10<00:05, 112.20ba/s]\u001b[A\n",
|
||||
" 64%|████████████████████████████████████████████████████████████████████████████████████▏ | 1153/1795 [00:10<00:05, 115.15ba/s]\u001b[A\n",
|
||||
" 65%|█████████████████████████████████████████████████████████████████████████████████████ | 1165/1795 [00:10<00:05, 114.07ba/s]\u001b[A\n",
|
||||
" 66%|█████████████████████████████████████████████████████████████████████████████████████▉ | 1177/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
|
||||
" 66%|██████████████████████████████████████████████████████████████████████████████████████▊ | 1189/1795 [00:10<00:05, 110.61ba/s]\u001b[A\n",
|
||||
" 67%|███████████████████████████████████████████████████████████████████████████████████████▋ | 1201/1795 [00:10<00:05, 112.56ba/s]\u001b[A\n",
|
||||
" 68%|████████████████████████████████████████████████████████████████████████████████████████▌ | 1213/1795 [00:10<00:05, 112.74ba/s]\u001b[A\n",
|
||||
" 68%|█████████████████████████████████████████████████████████████████████████████████████████▍ | 1225/1795 [00:11<00:05, 111.53ba/s]\u001b[A\n",
|
||||
" 69%|██████████████████████████████████████████████████████████████████████████████████████████▎ | 1237/1795 [00:11<00:05, 110.36ba/s]\u001b[A\n",
|
||||
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▏ | 1249/1795 [00:11<00:04, 109.75ba/s]\u001b[A\n",
|
||||
" 70%|███████████████████████████████████████████████████████████████████████████████████████████▉ | 1260/1795 [00:11<00:04, 107.40ba/s]\u001b[A\n",
|
||||
" 71%|████████████████████████████████████████████████████████████████████████████████████████████▊ | 1271/1795 [00:11<00:04, 106.67ba/s]\u001b[A\n",
|
||||
" 71%|█████████████████████████████████████████████████████████████████████████████████████████████▌ | 1282/1795 [00:11<00:04, 106.95ba/s]\u001b[A\n",
|
||||
" 72%|██████████████████████████████████████████████████████████████████████████████████████████████▎ | 1293/1795 [00:11<00:04, 107.69ba/s]\u001b[A\n",
|
||||
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▏ | 1304/1795 [00:11<00:04, 107.86ba/s]\u001b[A\n",
|
||||
" 73%|███████████████████████████████████████████████████████████████████████████████████████████████▉ | 1315/1795 [00:11<00:04, 107.71ba/s]\u001b[A\n",
|
||||
" 74%|████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1326/1795 [00:12<00:04, 107.71ba/s]\u001b[A\n",
|
||||
" 74%|█████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1337/1795 [00:12<00:04, 108.29ba/s]\u001b[A\n",
|
||||
" 75%|██████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1349/1795 [00:12<00:04, 109.37ba/s]\u001b[A\n",
|
||||
" 76%|███████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1361/1795 [00:12<00:03, 110.19ba/s]\u001b[A\n",
|
||||
" 76%|████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1373/1795 [00:12<00:03, 110.42ba/s]\u001b[A\n",
|
||||
" 77%|█████████████████████████████████████████████████████████████████████████████████████████████████████ | 1385/1795 [00:12<00:03, 111.32ba/s]\u001b[A\n",
|
||||
" 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1397/1795 [00:12<00:03, 112.54ba/s]\u001b[A\n",
|
||||
" 78%|██████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1409/1795 [00:12<00:03, 112.91ba/s]\u001b[A\n",
|
||||
" 79%|███████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1421/1795 [00:12<00:03, 111.93ba/s]\u001b[A\n",
|
||||
" 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1433/1795 [00:12<00:03, 109.91ba/s]\u001b[A\n",
|
||||
" 81%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1445/1795 [00:13<00:03, 109.29ba/s]\u001b[A\n",
|
||||
" 81%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1456/1795 [00:13<00:03, 107.81ba/s]\u001b[A\n",
|
||||
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1467/1795 [00:13<00:03, 107.59ba/s]\u001b[A\n",
|
||||
" 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1479/1795 [00:13<00:02, 107.83ba/s]\u001b[A\n",
|
||||
" 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1491/1795 [00:13<00:02, 108.92ba/s]\u001b[A\n",
|
||||
" 84%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1502/1795 [00:13<00:02, 108.64ba/s]\u001b[A\n",
|
||||
" 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1514/1795 [00:13<00:02, 110.24ba/s]\u001b[A\n",
|
||||
" 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1526/1795 [00:13<00:02, 111.64ba/s]\u001b[A\n",
|
||||
" 86%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1538/1795 [00:13<00:02, 110.08ba/s]\u001b[A\n",
|
||||
" 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1550/1795 [00:14<00:02, 108.01ba/s]\u001b[A\n",
|
||||
" 87%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1562/1795 [00:14<00:02, 109.96ba/s]\u001b[A\n",
|
||||
" 88%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1574/1795 [00:14<00:02, 109.67ba/s]\u001b[A\n",
|
||||
" 88%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1585/1795 [00:14<00:01, 107.92ba/s]\u001b[A\n",
|
||||
" 89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1596/1795 [00:14<00:01, 108.38ba/s]\u001b[A\n",
|
||||
" 90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1609/1795 [00:14<00:01, 112.44ba/s]\u001b[A\n",
|
||||
" 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1621/1795 [00:14<00:01, 110.29ba/s]\u001b[A\n",
|
||||
" 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1633/1795 [00:14<00:01, 110.18ba/s]\u001b[A\n",
|
||||
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1645/1795 [00:14<00:01, 108.21ba/s]\u001b[A\n",
|
||||
" 92%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1656/1795 [00:15<00:01, 107.62ba/s]\u001b[A\n",
|
||||
" 93%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 1667/1795 [00:15<00:01, 106.66ba/s]\u001b[A\n",
|
||||
" 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1678/1795 [00:15<00:01, 104.97ba/s]\u001b[A\n",
|
||||
" 94%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1689/1795 [00:15<00:01, 105.67ba/s]\u001b[A\n",
|
||||
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1700/1795 [00:15<00:00, 106.08ba/s]\u001b[A\n",
|
||||
" 95%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 1712/1795 [00:15<00:00, 107.07ba/s]\u001b[A\n",
|
||||
" 96%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 1724/1795 [00:15<00:00, 108.53ba/s]\u001b[A\n",
|
||||
" 97%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | 1735/1795 [00:15<00:00, 108.05ba/s]\u001b[A\n",
|
||||
" 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍ | 1747/1795 [00:15<00:00, 110.64ba/s]\u001b[A\n",
|
||||
" 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 1759/1795 [00:15<00:00, 111.38ba/s]\u001b[A\n",
|
||||
" 99%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 1771/1795 [00:16<00:00, 110.67ba/s]\u001b[A\n",
|
||||
" 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 1783/1795 [00:16<00:00, 110.52ba/s]\u001b[A\n",
|
||||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1795/1795 [00:16<00:00, 109.98ba/s]\u001b[A\n",
|
||||
"\n",
|
||||
" 0%| | 0/84 [00:00<?, ?ba/s]\u001b[A\n",
|
||||
" 14%|███████████████████▎ | 12/84 [00:00<00:00, 110.99ba/s]\u001b[A\n",
|
||||
" 29%|██████████████████████████████████████▌ | 24/84 [00:00<00:00, 110.80ba/s]\u001b[A\n",
|
||||
" 43%|█████████████████████████████████████████████████████████▊ | 36/84 [00:00<00:00, 107.75ba/s]\u001b[A\n",
|
||||
" 56%|███████████████████████████████████████████████████████████████████████████▌ | 47/84 [00:00<00:00, 103.83ba/s]\u001b[A\n",
|
||||
" 69%|█████████████████████████████████████████████████████████████████████████████████████████████▏ | 58/84 [00:00<00:00, 102.87ba/s]\u001b[A\n",
|
||||
" 82%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 69/84 [00:00<00:00, 104.54ba/s]\u001b[A\n",
|
||||
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 106.09ba/s]\u001b[A\n",
|
||||
"LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]\n",
|
||||
"\n",
|
||||
" | Name | Type | Params\n",
|
||||
"-----------------------------------------------------\n",
|
||||
"0 | model | T5ForConditionalGeneration | 60.5 M\n",
|
||||
"-----------------------------------------------------\n",
|
||||
"60.5 M Trainable params\n",
|
||||
"0 Non-trainable params\n",
|
||||
"60.5 M Total params\n",
|
||||
"242.026 Total estimated model params size (MB)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "AttributeError",
|
||||
"evalue": "'list' object has no attribute 'size'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[8], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(accelerator\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mgpu\u001b[39m\u001b[38;5;124m\"\u001b[39m, devices\u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m)\n\u001b[1;32m 4\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdm\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:608\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 606\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Trainer.fit()` requires a `LightningModule`, got: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mmodel\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__qualname__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 607\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m_lightning_module \u001b[38;5;241m=\u001b[39m model\n\u001b[0;32m--> 608\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 609\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 610\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:38\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 41\u001b[0m trainer\u001b[38;5;241m.\u001b[39m_call_teardown_hook()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:650\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 643\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m ckpt_path \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mresume_from_checkpoint\n\u001b[1;32m 644\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_set_ckpt_path(\n\u001b[1;32m 645\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 646\u001b[0m ckpt_path, \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[1;32m 647\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 648\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 649\u001b[0m )\n\u001b[0;32m--> 650\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 652\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 653\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1103\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 1099\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mrestore_training_state()\n\u001b[1;32m 1101\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39mresume_end()\n\u001b[0;32m-> 1103\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1105\u001b[0m log\u001b[38;5;241m.\u001b[39mdetail(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1106\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_teardown()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1182\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1180\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpredicting:\n\u001b[1;32m 1181\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_predict()\n\u001b[0;32m-> 1182\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_train\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1195\u001b[0m, in \u001b[0;36mTrainer._run_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pre_training_routine()\n\u001b[1;32m 1194\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m isolate_rng():\n\u001b[0;32m-> 1195\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_sanity_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1197\u001b[0m \u001b[38;5;66;03m# enable train mode\u001b[39;00m\n\u001b[1;32m 1198\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1267\u001b[0m, in \u001b[0;36mTrainer._run_sanity_check\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1265\u001b[0m \u001b[38;5;66;03m# run eval step\u001b[39;00m\n\u001b[1;32m 1266\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mno_grad():\n\u001b[0;32m-> 1267\u001b[0m \u001b[43mval_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1269\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_callback_hooks(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mon_sanity_check_end\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1271\u001b[0m \u001b[38;5;66;03m# reset logger connector\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:152\u001b[0m, in \u001b[0;36mEvaluationLoop.advance\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_dataloaders \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 151\u001b[0m kwargs[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataloader_idx\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m dataloader_idx\n\u001b[0;32m--> 152\u001b[0m dl_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdl_max_batches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[38;5;66;03m# store batch level output per dataloader\u001b[39;00m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs\u001b[38;5;241m.\u001b[39mappend(dl_outputs)\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py:199\u001b[0m, in \u001b[0;36mLoop.run\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 198\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 199\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:137\u001b[0m, in \u001b[0;36mEvaluationEpochLoop.advance\u001b[0;34m(self, data_fetcher, dl_max_batches, kwargs)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_started()\n\u001b[1;32m 136\u001b[0m \u001b[38;5;66;03m# lightning module methods\u001b[39;00m\n\u001b[0;32m--> 137\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_evaluation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 138\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_evaluation_step_end(output)\n\u001b[1;32m 140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_progress\u001b[38;5;241m.\u001b[39mincrement_processed()\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py:234\u001b[0m, in \u001b[0;36mEvaluationEpochLoop._evaluation_step\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"The evaluation step (validation_step or test_step depending on the trainer's state).\u001b[39;00m\n\u001b[1;32m 224\u001b[0m \n\u001b[1;32m 225\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 231\u001b[0m \u001b[38;5;124;03m the outputs of the step\u001b[39;00m\n\u001b[1;32m 232\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 233\u001b[0m hook_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtest_step\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mtesting \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvalidation_step\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 234\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhook_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1485\u001b[0m, in \u001b[0;36mTrainer._call_strategy_hook\u001b[0;34m(self, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1482\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 1484\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m-> 1485\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1487\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 1488\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.validation_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mprecision_plugin\u001b[38;5;241m.\u001b[39mval_step_context():\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, ValidationStep)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalidation_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"Cell \u001b[0;32mIn[7], line 36\u001b[0m, in \u001b[0;36mMyLightningModule.validation_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 34\u001b[0m attention_mask \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattention_mask\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 35\u001b[0m labels \u001b[38;5;241m=\u001b[39m batch[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[0;32m---> 36\u001b[0m loss, logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval_loss\u001b[39m\u001b[38;5;124m'\u001b[39m, loss, on_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, on_step\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m 38\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mloss\u001b[39m\u001b[38;5;124m'\u001b[39m: loss, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m'\u001b[39m: logits, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m:labels}\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"Cell \u001b[0;32mIn[7], line 16\u001b[0m, in \u001b[0;36mMyLightningModule.forward\u001b[0;34m(self, input_ids, attention_mask, labels)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, input_ids, attention_mask, labels\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m---> 16\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 17\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 18\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mlabels\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output\u001b[38;5;241m.\u001b[39mloss, output\u001b[38;5;241m.\u001b[39mlogits\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:1624\u001b[0m, in \u001b[0;36mT5ForConditionalGeneration.forward\u001b[0;34m(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1621\u001b[0m \u001b[38;5;66;03m# Encode if needed (training, first prediction pass)\u001b[39;00m\n\u001b[1;32m 1622\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoder_outputs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1623\u001b[0m \u001b[38;5;66;03m# Convert encoder inputs in embeddings if needed\u001b[39;00m\n\u001b[0;32m-> 1624\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencoder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1625\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1626\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1627\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1628\u001b[0m \u001b[43m \u001b[49m\u001b[43mhead_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhead_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1629\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1630\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1631\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1632\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1633\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m return_dict \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(encoder_outputs, BaseModelOutput):\n\u001b[1;32m 1634\u001b[0m encoder_outputs \u001b[38;5;241m=\u001b[39m BaseModelOutput(\n\u001b[1;32m 1635\u001b[0m last_hidden_state\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m0\u001b[39m],\n\u001b[1;32m 1636\u001b[0m hidden_states\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m1\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1637\u001b[0m attentions\u001b[38;5;241m=\u001b[39mencoder_outputs[\u001b[38;5;241m2\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(encoder_outputs) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1638\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/nn/modules/module.py:1194\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1190\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1191\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1192\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1193\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1194\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1195\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1196\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
||||
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py:944\u001b[0m, in \u001b[0;36mT5Stack.forward\u001b[0;34m(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 940\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 941\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou cannot specify both \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minput_ids and \u001b[39m\u001b[38;5;132;01m{\u001b[39;00merr_msg_prefix\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124minputs_embeds at the same time\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 942\u001b[0m )\n\u001b[1;32m 943\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m input_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 944\u001b[0m input_shape \u001b[38;5;241m=\u001b[39m \u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msize\u001b[49m()\n\u001b[1;32m 945\u001b[0m input_ids \u001b[38;5;241m=\u001b[39m input_ids\u001b[38;5;241m.\u001b[39mview(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, input_shape[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m])\n\u001b[1;32m 946\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
||||
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute 'size'"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.set_float32_matmul_precision(\"medium\")\n",
|
||||
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
|
||||
"trainer = pl.Trainer(accelerator=\"gpu\", devices=[0], max_epochs=10)\n",
|
||||
"dm = MyDataModule(batch_size=16)\n",
|
||||
"trainer.fit(model, datamodule=dm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "aa7b1ab0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Recap of what we did:\n",
|
||||
"* Finetuned T5-Small on DailyCNN (summarize news articles) using HF Trainer and data loading\n",
|
||||
"* Converted to Lightning code \n",
|
||||
"\n",
|
||||
"### To do next:\n",
|
||||
"* Make it work with the evaluation somethings wrong now, don't think it's a big issue\n",
|
||||
"* Clean up the code a bit\n",
|
||||
"* Compare it with HF, add predict function, modify data loading so it's from scratch / more general way of doing it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "80a2efab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,237 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5372055b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jupyterthemes.stylefx import set_nb_theme\n",
|
||||
"set_nb_theme('chesterish')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "11214a4a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
|
||||
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f45eb6b0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import datasets \n",
|
||||
"\n",
|
||||
"from datasets import load_dataset, load_metric\n",
|
||||
"\n",
|
||||
"from transformers import (\n",
|
||||
" AutoModel,\n",
|
||||
" AutoModelForMaskedLM,\n",
|
||||
" AutoModelForSeq2SeqLM,\n",
|
||||
" AutoModelForTokenClassification,\n",
|
||||
" AutoTokenizer,\n",
|
||||
" DataCollatorForSeq2Seq,\n",
|
||||
" pipeline,\n",
|
||||
" Seq2SeqTrainingArguments,\n",
|
||||
" Seq2SeqTrainer,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "b2d26af4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "363045f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def preprocess_function(batch):\n",
|
||||
" inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
|
||||
" outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
|
||||
" batch[\"input_ids\"] = inputs.input_ids\n",
|
||||
" batch[\"attention_mask\"] = inputs.attention_mask\n",
|
||||
" batch[\"labels\"] = outputs.input_ids.copy()\n",
|
||||
" return batch\n",
|
||||
"\n",
|
||||
"# Load the dataset\n",
|
||||
"train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
|
||||
"val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
|
||||
"\n",
|
||||
"train_ds = train_data.map(\n",
|
||||
" preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=256, \n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_ds = val_data.map(\n",
|
||||
" preprocess_function, \n",
|
||||
" batched=True, \n",
|
||||
" batch_size=256, \n",
|
||||
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0d58818f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class MyLightningModule(pl.LightningModule):\n",
|
||||
" def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
|
||||
" super().__init__()\n",
|
||||
" self.model_name = model_name\n",
|
||||
" self.learning_rate = learning_rate\n",
|
||||
" self.weight_decay = weight_decay\n",
|
||||
" self.batch_size = batch_size\n",
|
||||
" self.num_training_steps = num_training_steps\n",
|
||||
" \n",
|
||||
" # Load the pre-trained model and tokenizer\n",
|
||||
" self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
|
||||
" self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
|
||||
"\n",
|
||||
" def forward(self, input_ids, attention_mask, labels=None):\n",
|
||||
" output = self.model(\n",
|
||||
" input_ids=input_ids,\n",
|
||||
" attention_mask=attention_mask,\n",
|
||||
" labels=labels,\n",
|
||||
" )\n",
|
||||
" return output.loss, output.logits\n",
|
||||
" \n",
|
||||
" def training_step(self, batch, batch_idx):\n",
|
||||
" input_ids = batch[\"input_ids\"]\n",
|
||||
" attention_mask = batch[\"attention_mask\"]\n",
|
||||
" labels = batch[\"labels\"]\n",
|
||||
" \n",
|
||||
" loss\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" learning_rate=1e-5,\n",
|
||||
" per_device_train_batch_size=16,\n",
|
||||
" per_device_eval_batch_size=16,\n",
|
||||
" max_steps=5000,\n",
|
||||
" weight_decay=1e-4,\n",
|
||||
" push_to_hub=False,\n",
|
||||
" evaluation_strategy = \"steps\",\n",
|
||||
" eval_steps = 50,\n",
|
||||
" generation_max_length=128,\n",
|
||||
" predict_with_generate=True,\n",
|
||||
" logging_steps=100,\n",
|
||||
" gradient_accumulation_steps=1,\n",
|
||||
" fp16=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load the ROUGE metric\n",
|
||||
"metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
"# Define the evaluation function\n",
|
||||
"def compute_metrics(pred):\n",
|
||||
" labels = pred.label_ids\n",
|
||||
" preds = pred.predictions\n",
|
||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
" scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
" return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_data,\n",
|
||||
" eval_dataset=val_data,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5148159b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Steps:\n",
|
||||
"1. Rewrite code to be more general\n",
|
||||
"\n",
|
||||
"a) Data loading should be from disk rather than their load_dataset, and should be on the fly\n",
|
||||
"\n",
|
||||
"b) Rewrite to Lightning code, Trainer etc using Lightning, compute metric fine that we use huggingface"
|
||||
]
|
||||
},
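{
"cell_type": "code",
"execution_count": null,
"id": "3b1f2c8d",
"metadata": {},
"outputs": [],
"source": [
"# Rough sketch for step (a) above (not run yet): dump the CNN/DailyMail splits to CSV files\n",
"# on disk once, so later runs can read train.csv / validation.csv / test.csv with pandas or\n",
"# a custom Dataset instead of going through load_dataset every time.\n",
"from datasets import load_dataset\n",
"\n",
"for split, path in [(\"train\", \"train.csv\"), (\"validation\", \"validation.csv\"), (\"test\", \"test.csv\")]:\n",
"    split_ds = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=split)\n",
"    split_ds.to_csv(path)"
]
},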
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95e33e40",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!nvidia-smi"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "4c0348c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,644 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "7d5e92c6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[{'entity': 'I-FOOD', 'score': 0.49999642, 'index': 5, 'word': 'Turtle', 'start': 8, 'end': 14}, {'entity': 'I-FOOD', 'score': 0.6096488, 'index': 6, 'word': '##s', 'start': 14, 'end': 15}, {'entity': 'B-FOOD', 'score': 0.45608267, 'index': 7, 'word': 'Original', 'start': 16, 'end': 24}, {'entity': 'I-FOOD', 'score': 0.6613699, 'index': 8, 'word': 'Cara', 'start': 25, 'end': 29}, {'entity': 'I-FOOD', 'score': 0.5776781, 'index': 9, 'word': '##mel', 'start': 29, 'end': 32}, {'entity': 'I-FOOD', 'score': 0.86556953, 'index': 10, 'word': 'Chocolate', 'start': 33, 'end': 42}, {'entity': 'I-FOOD', 'score': 0.96111995, 'index': 11, 'word': 'P', 'start': 43, 'end': 44}, {'entity': 'I-FOOD', 'score': 0.8003402, 'index': 12, 'word': '##eca', 'start': 44, 'end': 47}, {'entity': 'I-FOOD', 'score': 0.9277613, 'index': 13, 'word': '##n', 'start': 47, 'end': 48}, {'entity': 'I-FOOD', 'score': 0.9217512, 'index': 15, 'word': '##luster', 'start': 50, 'end': 56}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
|
||||
"from transformers import pipeline\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
|
||||
"model = AutoModelForTokenClassification.from_pretrained(\"Dizex/FoodBaseBERT\")\n",
|
||||
"\n",
|
||||
"pipe = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
|
||||
"example = \"Demet's Turtles Original Caramel Chocolate Pecan Clusters 9.3 oz Holiday Gift Box\"\n",
|
||||
"\n",
|
||||
"ner_entity_results = pipe(example)\n",
|
||||
"print(ner_entity_results)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"id": "bf67ee76",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Turtle s Original Cara mel Chocolate P eca n luster\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ner_entity_results = pipe(example)\n",
|
||||
"\n",
|
||||
"# Initialize the entity words list with an empty string\n",
|
||||
"entity_words = [\"\"]\n",
|
||||
"\n",
|
||||
"# Loop through each dictionary in the list and extract the entity word\n",
|
||||
"for result in ner_entity_results:\n",
|
||||
" if result[\"entity\"] == \"B-FOOD\":\n",
|
||||
" entity_words.append(result[\"word\"])\n",
|
||||
" elif result[\"entity\"] == \"I-FOOD\":\n",
|
||||
" entity_words[-1] += \" \" + result[\"word\"]\n",
|
||||
"\n",
|
||||
"# Remove any remaining ## symbols and extra spaces\n",
|
||||
"entity_words = [word.replace(\"##\", \"\").strip() for word in entity_words]\n",
|
||||
"\n",
|
||||
"# Join the entity words into a single string\n",
|
||||
"output = \" \".join(entity_words)\n",
|
||||
"\n",
|
||||
"print(output)\n"
|
||||
]
|
||||
},
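{
"cell_type": "code",
"execution_count": null,
"id": "7a90c3e1",
"metadata": {},
"outputs": [],
"source": [
"# Alternative to the manual B-/I- merging above: let the pipeline group sub-word tokens into\n",
"# whole entities itself. With aggregation_strategy=\"simple\" (recent transformers versions)\n",
"# we should get words like 'Turtles' instead of the 'Turtle s' seen in the output above.\n",
"pipe_grouped = pipeline(\"ner\", model=model, tokenizer=tokenizer, aggregation_strategy=\"simple\")\n",
"print(pipe_grouped(example))"
]
},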
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc8e5ea0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"print(torch.cuda.is_available())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d8a1e039",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"import numpy as np"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6ad73024",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier = pipeline(\"zero-shot-classification\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "04f7e02c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"classifier(\n",
|
||||
" \"This is a course about the Transformers library\",\n",
|
||||
" candidate_labels=[\"machine learning\", \"gym\", \"food\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6fb246c2",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import pipeline\n",
|
||||
"generator = pipeline(task=\"text-generation\", model=\"bigscience/bloom-1b7\", device=0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c4e174f0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForTokenClassification, AutoModel, AutoTokenizer\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"# Define input text and pre-trained model checkpoint\n",
|
||||
"text = \"My name is wolfgang and I live in berlin\"\n",
|
||||
"checkpoint = \"Jean-Baptiste/roberta-large-ner-english\"\n",
|
||||
"\n",
|
||||
"# Instantiate tokenizer and encode input text\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"inputs = tokenizer(text, padding=True, truncation=True, return_tensors=\"pt\")\n",
|
||||
"\n",
|
||||
"# Instantiate model and generate output\n",
|
||||
"model = AutoModel.from_pretrained(checkpoint)\n",
|
||||
"outputs = model(**inputs)\n",
|
||||
"print(outputs[0].shape)\n",
|
||||
"\n",
|
||||
"# Instantiate token classification model and generate predictions\n",
|
||||
"model = AutoModelForTokenClassification.from_pretrained(checkpoint)\n",
|
||||
"outputs = model(**inputs)\n",
|
||||
"predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)\n",
|
||||
"print(predictions)\n",
|
||||
"print(model.config.id2label)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8212bbaa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
|
||||
"\n",
|
||||
"# prepare input\n",
|
||||
"text = \"Replace me by any text you'd like.\"\n",
|
||||
"encoded_input = tokenizer(text, return_tensors='pt')\n",
|
||||
"\n",
|
||||
"# forward pass\n",
|
||||
"output = model(**encoded_input)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "314cba41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
|
||||
"\n",
|
||||
"# Load the pre-trained tokenizer and model\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-large')\n",
|
||||
"model = AutoModelForMaskedLM.from_pretrained(\"xlm-roberta-large\")\n",
|
||||
"\n",
|
||||
"# Define the input sentence with a masked token\n",
|
||||
"text = \"I want to <mask> a new car tomorrow.\"\n",
|
||||
"\n",
|
||||
"# Tokenize the input sentence, replacing the masked token with a special [MASK] token\n",
|
||||
"encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')\n",
|
||||
"\n",
|
||||
"print(output.logits.shape)\n",
|
||||
"print(encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id))\n",
|
||||
"\n",
|
||||
"# Extract the predicted probabilities for the masked token\n",
|
||||
"predicted_probabilities = output.logits[0, encoded_input['input_ids'][0].tolist().index(tokenizer.mask_token_id)]\n",
|
||||
"predicted_probabilities = torch.nn.functional.softmax(predicted_probabilities, dim=-1)\n",
|
||||
"\n",
|
||||
"# Get the top-k most probable predictions for the masked token\n",
|
||||
"k = 5\n",
|
||||
"top_k = torch.topk(predicted_probabilities, k)\n",
|
||||
"for i in range(k):\n",
|
||||
" token = tokenizer.convert_ids_to_tokens(top_k.indices[i].item())\n",
|
||||
" score = top_k.values[i].item()\n",
|
||||
" print(f\"Prediction {i+1}: '{token}' with probability {score:.5f}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "6187e77e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n",
|
||||
"\n",
|
||||
"sequences = [\n",
|
||||
" \"Using a Transformer network is simple\",\n",
|
||||
" \"The quick brown fox jumps over the lazy dog\",\n",
|
||||
" \"To be or not to be, that is the question\"\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"# Tokenize the input sequences and convert them to padded and truncated integer token IDs\n",
|
||||
"inputs = tokenizer(\n",
|
||||
" sequences,\n",
|
||||
" padding=True,\n",
|
||||
" truncation=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Print the resulting input IDs and attention masks\n",
|
||||
"print(inputs['input_ids'])\n",
|
||||
"print(inputs['attention_mask'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fc259c5a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43466db6",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Huggingface:\n",
|
||||
"\n",
|
||||
"1. Understanding how to use the Pipeline (probably most useful) for various tasks, easy to use, and the different subtasks it can do like translation, QA, zero shot, sentiment analysis, token classification, etc. \n",
|
||||
"2. Understood how pipeline works in more detail by using AutoModel for various tasks as well as AutoTokenizer\n",
|
||||
"3. Load dataset\n",
|
||||
"4. How to finetune\n",
|
||||
"5. How to evaluate\n",
|
||||
"6. "
|
||||
]
|
||||
},
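{
"cell_type": "code",
"execution_count": null,
"id": "1d4e9f27",
"metadata": {},
"outputs": [],
"source": [
"# Quick sketch of two more pipeline subtasks from the list above (sentiment analysis and\n",
"# question answering); pipeline() picks a default checkpoint for each task.\n",
"from transformers import pipeline\n",
"\n",
"sentiment = pipeline(\"sentiment-analysis\")\n",
"print(sentiment(\"This tutorial series is really helpful!\"))\n",
"\n",
"qa = pipeline(\"question-answering\")\n",
"print(qa(question=\"Where does Wolfgang live?\", context=\"My name is Wolfgang and I live in Berlin.\"))"
]
},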
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "97c474f2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3ed5d8c2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"from transformers import AdamW, AutoTokenizer, AutoModelForSequenceClassification\n",
|
||||
"\n",
|
||||
"# Same as before\n",
|
||||
"checkpoint = \"bert-base-uncased\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint)\n",
|
||||
"sequences = [\n",
|
||||
" \"I've been waiting for a HuggingFace course my whole life.\",\n",
|
||||
" \"This course is amazing!\",\n",
|
||||
"]\n",
|
||||
"batch = tokenizer(sequences, padding=True, truncation=True, return_tensors=\"pt\")\n",
|
||||
"\n",
|
||||
"# This is new\n",
|
||||
"batch[\"labels\"] = torch.tensor([1, 1])\n",
|
||||
"\n",
|
||||
"optimizer = AdamW(model.parameters())\n",
|
||||
"loss = model(**batch).loss\n",
|
||||
"loss.backward()\n",
|
||||
"optimizer.step()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "c598624f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"raw_datasets = load_dataset(\"glue\", \"mrpc\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "cd296227",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"raw_train_dataset = raw_datasets[\"train\"]\n",
|
||||
"raw_train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e462947a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"from transformers import AutoTokenizer, DataCollatorWithPadding\n",
|
||||
"raw_datasets = load_dataset(\"glue\", \"mrpc\")\n",
|
||||
"\n",
|
||||
"checkpoint = \"bert-base-uncased\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n",
|
||||
"\n",
|
||||
"def tokenize_function(example):\n",
|
||||
" return tokenizer(example[\"sentence1\"], example[\"sentence2\"], truncation=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)\n",
|
||||
"data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"from transformers import TrainingArguments\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\")\n",
|
||||
"\n",
|
||||
"from transformers import AutoModelForSequenceClassification\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"import numpy as np\n",
|
||||
"import evaluate\n",
|
||||
"\n",
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
||||
" logits, labels = eval_preds\n",
|
||||
" predictions = np.argmax(logits, axis=-1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)\n",
|
||||
"\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "0e2795dc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import TrainingArguments\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "3af29cd5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSequenceClassification\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "817f644e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import evaluate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "42819a6c",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def compute_metrics(eval_preds):\n",
|
||||
" metric = evaluate.load(\"glue\", \"mrpc\")\n",
|
||||
" logits, labels = eval_preds\n",
|
||||
" predictions = np.argmax(logits, axis=-1)\n",
|
||||
" return metric.compute(predictions=predictions, references=labels)\n",
|
||||
"\n",
|
||||
"training_args = TrainingArguments(\"test-trainer\", evaluation_strategy=\"epoch\")\n",
|
||||
"model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)\n",
|
||||
"\n",
|
||||
"trainer = Trainer(\n",
|
||||
" model,\n",
|
||||
" training_args,\n",
|
||||
" train_dataset=tokenized_datasets[\"train\"],\n",
|
||||
" eval_dataset=tokenized_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "eb5986b0",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"batch_size=32\n",
|
||||
"\n",
|
||||
"# Define the generator function to preprocess the data in batches\n",
|
||||
"def preprocess_generator(examples):\n",
|
||||
" for i in range(0, len(examples[\"article\"]), batch_size):\n",
|
||||
" batch = examples[\"article\"][i:i+batch_size]\n",
|
||||
" targets = examples[\"highlights\"][i:i+batch_size]\n",
|
||||
" model_inputs = tokenizer(batch, max_length=512, padding=\"max_length\", truncation=True)\n",
|
||||
" with tokenizer.as_target_tokenizer():\n",
|
||||
" model_targets = tokenizer(targets, max_length=128, padding=\"max_length\", truncation=True)\n",
|
||||
" model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
|
||||
" yield model_inputs\n",
|
||||
"\n",
|
||||
"def preprocess_function(examples):\n",
|
||||
" articles = [ex for ex in examples[\"article\"]]\n",
|
||||
" summaries = [ex for ex in examples[\"highlights\"]]\n",
|
||||
"\n",
|
||||
" model_inputs = tokenizer(articles, max_length=512, padding=\"max_length\", truncation=True)\n",
|
||||
" with tokenizer.as_target_tokenizer():\n",
|
||||
" model_targets = tokenizer(summaries, max_length=128, padding=\"max_length\", truncation=True)\n",
|
||||
" \n",
|
||||
" model_inputs[\"labels\"] = model_targets[\"input_ids\"]\n",
|
||||
" return model_inputs\n",
|
||||
" \n",
|
||||
"# Load the dataset\n",
|
||||
"raw_datasets = load_dataset(\"cnn_dailymail\", \"3.0.0\")\n",
|
||||
"preprocessed_datasets = raw_datasets.map(preprocess_function, batched=True, num_proc=4)\n",
|
||||
"\n",
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" evaluation_strategy = \"epoch\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=batch_size,\n",
|
||||
" max_steps=1000,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" push_to_hub=False,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_ds,\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7d62583e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_metric"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d310a7b3",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"preprocessed_datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "99d422cc",
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load the pre-trained model and tokenizer\n",
|
||||
"model_name = \"t5-small\"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
|
||||
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"# Define the data collator\n",
|
||||
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
|
||||
"\n",
|
||||
"# Initialize the trainer arguments\n",
|
||||
"training_args = Seq2SeqTrainingArguments(\n",
|
||||
" output_dir=\"./results\",\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" per_device_train_batch_size=batch_size,\n",
|
||||
" max_steps=5000,\n",
|
||||
" weight_decay=0.01,\n",
|
||||
" push_to_hub=False,\n",
|
||||
" evaluation_strategy = \"steps\",\n",
|
||||
" eval_steps = 50,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Load the ROUGE metric\n",
|
||||
"metric = load_metric(\"rouge\")\n",
|
||||
"\n",
|
||||
"# Define the evaluation function\n",
|
||||
"def compute_metrics(pred):\n",
|
||||
" labels = pred.label_ids\n",
|
||||
" preds = pred.predictions\n",
|
||||
" \n",
|
||||
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
|
||||
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
|
||||
" \n",
|
||||
" scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
|
||||
" \n",
|
||||
" return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# Initialize the trainer\n",
|
||||
"trainer = Seq2SeqTrainer(\n",
|
||||
" model=model,\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=preprocessed_datasets[\"train\"],\n",
|
||||
" eval_dataset=preprocessed_datasets[\"validation\"],\n",
|
||||
" data_collator=data_collator,\n",
|
||||
" tokenizer=tokenizer,\n",
|
||||
" compute_metrics=compute_metrics,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# Start the training\n",
|
||||
"trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a5e97b57",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install nltk\n",
|
||||
"!pip install rouge_score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "558c3e66",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Goal:\n",
|
||||
"\n",
|
||||
"1. Implement full training from dataloading (dailycnn dataset), to model training, evaluation, etc, using HF. \n",
|
||||
"* Right now: stuck on on the fly dataset loading, we don't want to cache because this would take a lot of disk space etc.\n",
|
||||
"\n",
|
||||
"2. After we get step 1) working, we want to go deeper on every step, so download the dataset and load it as a custom dataset rather than using huggingface simple API, in order to make it more general. Compare with loading the ds as a custom HF dataset or using pytorch class together with lightning. Speed difference? Convenience? Also we want to use the lightning Trainer so see how we can integrate that. And then compare HF to the lightning + hf model approach and see what we like the most."
|
||||
]
|
||||
},
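{
"cell_type": "code",
"execution_count": null,
"id": "8c2b5d40",
"metadata": {},
"outputs": [],
"source": [
"# Sketch for the caching problem above: datasets can run the tokenization on the fly with\n",
"# set_transform, so nothing is written to the cache. Assumes the t5-small `tokenizer` from\n",
"# the cells above; the transform receives a batch (dict of lists) whenever rows are accessed.\n",
"def tokenize_on_the_fly(batch):\n",
"    model_inputs = tokenizer(batch[\"article\"], max_length=512, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
"    targets = tokenizer(batch[\"highlights\"], max_length=128, padding=\"max_length\", truncation=True, return_tensors=\"pt\")\n",
"    model_inputs[\"labels\"] = targets[\"input_ids\"]\n",
"    return model_inputs\n",
"\n",
"raw_train = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
"raw_train.set_transform(tokenize_on_the_fly)\n",
"print(raw_train[:2][\"input_ids\"].shape)"
]
},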
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "624d49ca",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.9"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -1,41 +0,0 @@
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, DataCollatorWithPadding
|
||||
from transformers import Trainer

import numpy as np
import evaluate
|
||||
|
||||
raw_datasets = load_dataset("glue", "mrpc")
|
||||
checkpoint = "bert-base-uncased"
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
|
||||
|
||||
def tokenize_function(example):
|
||||
return tokenizer(example["sentence1"], example["sentence2"], truncation=True)
|
||||
|
||||
|
||||
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
|
||||
from transformers import TrainingArguments
|
||||
training_args = TrainingArguments("test-trainer")
|
||||
|
||||
from transformers import AutoModelForSequenceClassification
|
||||
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
|
||||
|
||||
def compute_metrics(eval_preds):
|
||||
metric = evaluate.load("glue", "mrpc")
|
||||
logits, labels = eval_preds
|
||||
predictions = np.argmax(logits, axis=-1)
|
||||
return metric.compute(predictions=predictions, references=labels)
|
||||
|
||||
training_args = TrainingArguments("test-trainer", evaluation_strategy="epoch")
|
||||
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
|
||||
|
||||
trainer = Trainer(
|
||||
model,
|
||||
training_args,
|
||||
train_dataset=tokenized_datasets["train"],
|
||||
eval_dataset=tokenized_datasets["validation"],
|
||||
data_collator=data_collator,
|
||||
tokenizer=tokenizer,
|
||||
compute_metrics=compute_metrics,
|
||||
)
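

# Launch the fine-tuning run (assumed final step for this script)
trainer.train()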
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
Binary file not shown.
@@ -1 +0,0 @@
|
||||
{}
|
||||
@@ -1,130 +0,0 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
from datasets import load_dataset, load_metric
|
||||
from transformers import T5Config, T5ForConditionalGeneration
|
||||
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
Seq2SeqTrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
)
|
||||
|
||||
|
||||
class MyLightningModule(pl.LightningModule):
|
||||
def __init__(self, model_name, learning_rate, weight_decay):
|
||||
super().__init__()
|
||||
self.model_name = model_name
|
||||
self.learning_rate = learning_rate
|
||||
self.weight_decay = weight_decay
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Load the pre-trained model and tokenizer
|
||||
#self.model = torch.compile(
|
||||
# AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
|
||||
#)
|
||||
|
||||
# Create a T5-small configuration
|
||||
config = T5Config.from_pretrained("t5-small")
|
||||
|
||||
# Initialize the T5 model with random weights
|
||||
self.model = torch.compile(T5ForConditionalGeneration(config))
|
||||
|
||||
# Load the ROUGE metric
|
||||
self.metric = load_metric("rouge")
|
||||
self.logits = []
|
||||
self.labels = []
|
||||
|
||||
def forward(self, input_ids, attention_mask, labels=None):
|
||||
output = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
labels=labels,
|
||||
)
|
||||
return output.loss, output.logits
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch["attention_mask"]
|
||||
labels = batch["labels"]
|
||||
loss, logits = self(input_ids, attention_mask, labels)
|
||||
self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
|
||||
return {"loss": loss, "logits": logits}
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
input_ids = batch["input_ids"]
|
||||
attention_mask = batch["attention_mask"]
|
||||
labels = batch["labels"]
|
||||
loss, logits = self(input_ids, attention_mask, labels)
|
||||
self.log("val_loss", loss, on_epoch=True, on_step=False)
|
||||
|
||||
# add logits and labels to instance attributes, but make sure to detach them
|
||||
# from the computational graph first
|
||||
self.logits.append(logits.argmax(dim=-1).detach().cpu())
|
||||
self.labels.append(labels.detach().cpu())
|
||||
return {"loss": loss, "logits": logits, "labels": labels}
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
# Concatenate tensors in logits and labels lists
|
||||
pred_token_ids = torch.cat(self.logits, dim=0)
|
||||
true_labels = torch.cat(self.labels, dim=0)
|
||||
|
||||
# Decode predictions and labels using the saved instance attributes
|
||||
decoded_preds = self.tokenizer.batch_decode(
|
||||
pred_token_ids, skip_special_tokens=True
|
||||
)
|
||||
decoded_labels = self.tokenizer.batch_decode(
|
||||
true_labels, skip_special_tokens=True
|
||||
)
|
||||
|
||||
# Compute ROUGE scores
|
||||
scores = self.metric.compute(
|
||||
predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1"]
|
||||
)["rouge1"].mid
|
||||
|
||||
self.log("rouge1_precision", scores.precision, prog_bar=True)
|
||||
self.log("rouge1_recall", scores.recall, prog_bar=True)
|
||||
self.log("rouge1_fmeasure", scores.fmeasure, prog_bar=True)
|
||||
|
||||
# Clear logits and labels instance attributes for the next validation epoch
|
||||
self.logits.clear()
|
||||
self.labels.clear()
|
||||
|
||||
def predict(self, article: str, max_input_length: int = 512, max_output_length: int = 150) -> str:
|
||||
# Set the model to evaluation mode
|
||||
self.model.eval()
|
||||
|
||||
# Tokenize the input article
|
||||
inputs = self.tokenizer(
|
||||
article,
|
||||
max_length=max_input_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Move the input tensors to the same device as the model
|
||||
inputs = {key: value.to(self.device) for key, value in inputs.items()}
|
||||
|
||||
# Generate summary
|
||||
with torch.no_grad():
|
||||
output = self.model.generate(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
max_length=max_output_length,
|
||||
num_return_sequences=1,
|
||||
)
|
||||
|
||||
# Decode and return the summary
|
||||
summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
return summary
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.AdamW(
|
||||
self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
@@ -1,2 +0,0 @@
|
||||
l = ["cat", "dog"]
|
||||
sentence = "The quick brown fox jumps over the lazy dog"
|
||||
@@ -1,67 +0,0 @@
|
||||
from dataset import MyDataModule
|
||||
from model import MyLightningModule
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoTokenizer,
|
||||
DataCollatorForSeq2Seq,
|
||||
Seq2SeqTrainingArguments,
|
||||
Seq2SeqTrainer,
|
||||
)
|
||||
import torch
|
||||
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define the checkpoint callback
|
||||
checkpoint_callback = ModelCheckpoint(
|
||||
monitor="val_loss",
|
||||
dirpath="checkpoints",
|
||||
filename="my_model-{epoch:02d}-{val_loss:.2f}",
|
||||
save_top_k=-1,
|
||||
every_n_epochs=1,
|
||||
verbose=True,
|
||||
)
|
||||
logger = TensorBoardLogger("tb_logs", name="t5_dailymail")
|
||||
|
||||
model_name = "t5-small"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# File paths
|
||||
train_csv = "train.csv"
|
||||
val_csv = "validation.csv"
|
||||
test_csv = "test.csv"
|
||||
|
||||
# Create the data module
|
||||
dm = MyDataModule(train_csv, val_csv, test_csv, tokenizer, batch_size=32)
|
||||
dm.setup()
|
||||
|
||||
model = MyLightningModule(
|
||||
model_name="t5-small", learning_rate=1e-4, weight_decay=1e-5
|
||||
)
|
||||
|
||||
|
||||
#checkpoint_path = "checkpoints/curr.ckpt"
|
||||
#checkpoint = torch.load(checkpoint_path)
|
||||
#model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
trainer = pl.Trainer(
|
||||
accelerator="gpu",
|
||||
devices=[0, 1],
|
||||
max_epochs=10,
|
||||
precision=16,
|
||||
logger=logger,
|
||||
callbacks=[checkpoint_callback],
|
||||
log_every_n_steps=10,
|
||||
)
|
||||
trainer.fit(model, dm)
|
||||
trainer.validate(model, dm)
|
||||
|
||||
#example = """Former President Donald Trump claims in a social media post that he will be arrested next week. The claim comes while a New York prosecutor considers charging Trump in connection with hush money paid to adult film actress Stormy Daniels but there has been no official announcement of any plans for an indictment. What we know about Trump possibly facing criminal indictment in New York City. Trump has been entangled in several criminal investigations but the case related to Daniels is the longest-running of all of them, reaching back to 2016. On his platform Truth Social on Saturday morning, Trump cited "illegal leaks" that he will be arrested Tuesday and he called for protests. Trump, who is running for president in 2024, also defended himself, saying that he has not committed a crime — though he did not disclose what he expects to be charged with — and he accused the Manhattan District Attorney's Office of being "corrupt & highly political.". 'I'M BACK!' Trump posts on Facebook, YouTube for first time in two years. The Manhattan District Attorney's Office declined to comment on whether it will soon be pursing an arrest warrant for Trump. But the Associated Press reported that law enforcement officials in New York are discussing security preparations in anticipation that Trump may be indicted in coming weeks. If it does occur, Trump would become the first former president to be indicted in U.S. history."""
|
||||
#print(len(tokenizer(example)["input_ids"]))
|
||||
#summary = model.predict(example)
|
||||
#print(summary)
|
||||