{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "87ef8027",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
""
],
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Apply the 'chesterish' dark theme to this notebook's styling.\n",
"from jupyterthemes import stylefx\n",
"stylefx.set_nb_theme('chesterish')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "225eab36",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.simplefilter(\"ignore\")\n",
"\n",
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"import datasets \n",
"import pytorch_lightning as pl\n",
"\n",
"from datasets import load_dataset, load_metric\n",
"\n",
"from transformers import (\n",
" AutoModel,\n",
" AutoModelForSeq2SeqLM,\n",
" AutoTokenizer,\n",
" DataCollatorForSeq2Seq,\n",
" Seq2SeqTrainingArguments,\n",
" Seq2SeqTrainer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "9f7d2829",
"metadata": {},
"outputs": [],
"source": [
"# Define the LightningDataModule\n",
"class MyDataModule(pl.LightningDataModule):\n",
"    \"\"\"DataModule that downloads, tokenizes and serves CNN/DailyMail.\n",
"\n",
"    Args:\n",
"        batch_size: batch size used for tokenization and the dataloaders.\n",
"        model_name: HF checkpoint whose tokenizer is used; default matches\n",
"            the t5-small model trained in this notebook.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, batch_size, model_name=\"t5-small\"):\n",
"        super().__init__()\n",
"        self.batch_size = batch_size\n",
"        # Fix: preprocess_function previously relied on a global `tokenizer`\n",
"        # that is never defined anywhere in this notebook (NameError in\n",
"        # setup()). The datamodule now owns its tokenizer.\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"    def prepare_data(self):\n",
"        # Download the data once (Lightning calls this on a single process).\n",
"        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
"        load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
"    def setup(self, stage=None):\n",
"        # Load and tokenize the data.\n",
"        train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
"        val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
"        self.train_ds = train_data.map(\n",
"            self.preprocess_function,\n",
"            batched=True,\n",
"            batch_size=self.batch_size,\n",
"            remove_columns=[\"article\", \"highlights\", \"id\"],\n",
"        )\n",
"\n",
"        self.val_ds = val_data.map(\n",
"            self.preprocess_function,\n",
"            batched=True,\n",
"            batch_size=self.batch_size,\n",
"            remove_columns=[\"article\", \"highlights\", \"id\"],\n",
"        )\n",
"\n",
"        # Fix: without set_format the plain DataLoader below yields lists of\n",
"        # Python ints rather than stacked torch tensors, which the model\n",
"        # cannot consume.\n",
"        columns = [\"input_ids\", \"attention_mask\", \"labels\"]\n",
"        self.train_ds.set_format(type=\"torch\", columns=columns)\n",
"        self.val_ds.set_format(type=\"torch\", columns=columns)\n",
"\n",
"    def preprocess_function(self, batch):\n",
"        # Tokenize articles (encoder input) and highlights (decoder labels).\n",
"        inputs = self.tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
"        outputs = self.tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
"        batch[\"input_ids\"] = inputs.input_ids\n",
"        batch[\"attention_mask\"] = inputs.attention_mask\n",
"        batch[\"labels\"] = outputs.input_ids.copy()\n",
"        return batch\n",
"\n",
"    def train_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)\n",
"\n",
"    def val_dataloader(self):\n",
"        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a99bdbb0",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
"    \"\"\"Seq2seq summarization module wrapping a HF model (e.g. t5-small).\n",
"\n",
"    Args:\n",
"        model_name: HF checkpoint name passed to AutoModelForSeq2SeqLM.\n",
"        learning_rate: AdamW learning rate.\n",
"        weight_decay: AdamW weight decay.\n",
"        batch_size: stored for reference; batching is done by the datamodule.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, model_name, learning_rate, weight_decay, batch_size):\n",
"        super().__init__()\n",
"        self.model_name = model_name\n",
"        self.learning_rate = learning_rate\n",
"        self.weight_decay = weight_decay\n",
"        self.batch_size = batch_size\n",
"\n",
"        # Load the pre-trained model and tokenizer.\n",
"        # Fix: validation_epoch_end uses self.tokenizer, which was never\n",
"        # assigned anywhere (AttributeError on the first validation epoch).\n",
"        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
"\n",
"        # Load the ROUGE metric\n",
"        self.metric = load_metric(\"rouge\")\n",
"\n",
"    def forward(self, input_ids, attention_mask, labels=None):\n",
"        output = self.model(\n",
"            input_ids=input_ids,\n",
"            attention_mask=attention_mask,\n",
"            labels=labels,\n",
"        )\n",
"        return output.loss, output.logits\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('train_loss', loss, on_epoch=True, on_step=True)\n",
"        return {'loss': loss, 'logits': logits}\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"        loss, logits = self(input_ids, attention_mask, labels)\n",
"        self.log('val_loss', loss, on_epoch=True, on_step=False)\n",
"        # Fix: validation_epoch_end reads output['labels'], so labels must be\n",
"        # part of the returned dict (previously a KeyError).\n",
"        # NOTE(review): keeping full logits for every batch is memory-heavy;\n",
"        # consider returning argmax ids only if validation OOMs.\n",
"        return {'loss': loss, 'logits': logits, 'labels': labels}\n",
"\n",
"    def validation_epoch_end(self, outputs):\n",
"        decoded_preds = []\n",
"        decoded_labels = []\n",
"        for output in outputs:\n",
"            # Fix: decode predicted token ids (argmax over the vocab axis),\n",
"            # not the raw (batch, seq_len, vocab_size) logits tensor.\n",
"            pred_ids = output['logits'].argmax(dim=-1)\n",
"            labels = output['labels']\n",
"            decoded_preds += self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
"            decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"\n",
"        # Aggregate ROUGE-1 over the whole validation set.\n",
"        scores = self.metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"\n",
"        self.log('rouge1_precision', scores.precision, prog_bar=True)\n",
"        self.log('rouge1_recall', scores.recall, prog_bar=True)\n",
"        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)\n",
"\n",
"    def configure_optimizers(self):\n",
"        # Single AdamW optimizer, no LR scheduler.\n",
"        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
"        return optimizer\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "3c28da7c",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "Trainer.__init__() got an unexpected keyword argument 'num_epochs'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[20], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m MyLightningModule(model_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt5-small\u001b[39m\u001b[38;5;124m\"\u001b[39m, learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-5\u001b[39m, weight_decay\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1e-4\u001b[39m, batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTrainer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdevices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m10\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlogger\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m dm \u001b[38;5;241m=\u001b[39m MyDataModule(batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m)\n\u001b[1;32m 4\u001b[0m trainer\u001b[38;5;241m.\u001b[39mfit(model, datamodule\u001b[38;5;241m=\u001b[39mdm)\n",
"File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/pytorch_lightning/utilities/argparse.py:348\u001b[0m, in \u001b[0;36m_defaults_from_env_vars..insert_env_defaults\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 345\u001b[0m kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m(\u001b[38;5;28mlist\u001b[39m(env_variables\u001b[38;5;241m.\u001b[39mitems()) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mitems()))\n\u001b[1;32m 347\u001b[0m \u001b[38;5;66;03m# all args were already moved to kwargs\u001b[39;00m\n\u001b[0;32m--> 348\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mTypeError\u001b[0m: Trainer.__init__() got an unexpected keyword argument 'num_epochs'"
]
}
],
"source": [
"model = MyLightningModule(model_name=\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)\n",
"# Fix: the Trainer kwarg is max_epochs, not num_epochs — the recorded\n",
"# traceback shows TypeError: unexpected keyword argument 'num_epochs'.\n",
"trainer = pl.Trainer(devices=[0], max_epochs=10, deterministic=True, logger=False)\n",
"dm = MyDataModule(batch_size=16)\n",
"trainer.fit(model, datamodule=dm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "55729d94",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}