import warnings
warnings.simplefilter("ignore")

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import numpy as np
import torch

import datasets
import pytorch_lightning as pl

from datasets import load_dataset, load_metric

from transformers import (
    AutoModel,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)


class MyDataModule(pl.LightningDataModule):
    """LightningDataModule for CNN/DailyMail summarization.

    Tokenizes articles (max 512 tokens) as model inputs and highlights
    (max 128 tokens) as labels, and serves plain torch DataLoaders.

    Args:
        batch_size: batch size used both for dataset .map() and the DataLoaders.
        model_name: HF hub id used to build the tokenizer. Defaults to
            "t5-small" to match the model the notebook trains.
    """

    def __init__(self, batch_size, model_name="t5-small"):
        super().__init__()
        self.batch_size = batch_size
        # BUG FIX: the original preprocess_function referenced a module-level
        # `tokenizer` that was never created anywhere in the notebook
        # (AutoTokenizer was imported but never instantiated), so .map()
        # raised NameError. The datamodule now owns its tokenizer.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def prepare_data(self):
        # Download/caching only — runs once on a single process; the real
        # load happens in setup() on every process.
        load_dataset("cnn_dailymail", "3.0.0", split="train")
        load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

    def setup(self, stage=None):
        train_data = load_dataset("cnn_dailymail", "3.0.0", split="train")
        val_data = load_dataset("cnn_dailymail", "3.0.0", split="validation[:10%]")

        model_columns = ["input_ids", "attention_mask", "labels"]

        self.train_ds = train_data.map(
            self.preprocess_function,
            batched=True,
            batch_size=self.batch_size,
            remove_columns=["article", "highlights", "id"],
        )
        # BUG FIX: without set_format the plain torch DataLoader receives
        # Python lists and collates them column-wise; format as tensors so
        # batches come out as (batch, seq) LongTensors.
        self.train_ds.set_format(type="torch", columns=model_columns)

        self.val_ds = val_data.map(
            self.preprocess_function,
            batched=True,
            batch_size=self.batch_size,
            remove_columns=["article", "highlights", "id"],
        )
        self.val_ds.set_format(type="torch", columns=model_columns)

    def preprocess_function(self, batch):
        """Tokenize one batch of examples; returns input_ids/attention_mask/labels."""
        inputs = self.tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512)
        outputs = self.tokenizer(batch["highlights"], padding="max_length", truncation=True, max_length=128)
        batch["input_ids"] = inputs.input_ids
        batch["attention_mask"] = inputs.attention_mask
        # Mask label padding with -100 so CrossEntropyLoss ignores it;
        # otherwise the model is trained to emit pad tokens.
        pad_id = self.tokenizer.pad_token_id
        batch["labels"] = [
            [tok if tok != pad_id else -100 for tok in seq]
            for seq in outputs.input_ids
        ]
        return batch

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.val_ds, batch_size=self.batch_size)
class MyLightningModule(pl.LightningModule):
    """Seq2seq summarization module wrapping a HF pre-trained model.

    Logs train/val loss and ROUGE-1 precision/recall/F at validation
    epoch end.

    Args:
        model_name: HF hub id of the pre-trained seq2seq model (e.g. "t5-small").
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
        batch_size: stored for hparam symmetry with the datamodule (not
            used directly here — batching is the DataLoader's job).
    """

    def __init__(self, model_name, learning_rate, weight_decay, batch_size):
        super().__init__()
        self.model_name = model_name
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.batch_size = batch_size

        # Load the pre-trained model.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        # BUG FIX: validation_epoch_end used self.tokenizer, which was never
        # created anywhere (AttributeError at first validation epoch end);
        # load it alongside the model.
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        # Load the ROUGE metric.
        self.metric = load_metric("rouge")

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the model; returns (loss, logits). loss is None when labels is None."""
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        loss, logits = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        self.log('train_loss', loss, on_epoch=True, on_step=True)
        return {'loss': loss, 'logits': logits}

    def validation_step(self, batch, batch_idx):
        labels = batch["labels"]
        loss, logits = self(batch["input_ids"], batch["attention_mask"], labels)
        self.log('val_loss', loss, on_epoch=True, on_step=False)
        # BUG FIX: validation_epoch_end reads output['labels'], but the
        # original returned only loss/logits, causing a KeyError there.
        return {'loss': loss, 'logits': logits, 'labels': labels}

    def validation_epoch_end(self, outputs):
        decoded_preds = []
        decoded_labels = []
        for output in outputs:
            # BUG FIX: logits are float scores of shape (batch, seq, vocab);
            # batch_decode needs token ids, so take the argmax over the
            # vocabulary axis first.
            pred_ids = output['logits'].argmax(dim=-1)
            labels = output['labels']
            # Labels may carry the -100 loss-ignore index; map it back to
            # the pad token so the tokenizer can decode (no-op if absent).
            labels = torch.where(
                labels == -100,
                torch.full_like(labels, self.tokenizer.pad_token_id),
                labels,
            )
            decoded_preds += self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
            decoded_labels += self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        scores = self.metric.compute(
            predictions=decoded_preds,
            references=decoded_labels,
            rouge_types=["rouge1"],
        )["rouge1"].mid

        self.log('rouge1_precision', scores.precision, prog_bar=True)
        self.log('rouge1_recall', scores.recall, prog_bar=True)
        self.log('rouge1_fmeasure', scores.fmeasure, prog_bar=True)

    def configure_optimizers(self):
        # Plain AdamW over all parameters; no scheduler.
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
# Fine-tune t5-small on CNN/DailyMail.
# Seed first so deterministic=True actually yields reproducible runs.
pl.seed_everything(42)

model = MyLightningModule(model_name="t5-small", learning_rate=1e-5, weight_decay=1e-4, batch_size=16)
# BUG FIX: pl.Trainer has no `num_epochs` argument — the TypeError saved in
# this cell's output confirms it; the correct keyword is `max_epochs`.
# NOTE(review): depending on the Lightning version, `devices=[0]` may also
# need accelerator="gpu" to target CUDA — confirm against the installed
# pytorch_lightning.
trainer = pl.Trainer(devices=[0], max_epochs=10, deterministic=True, logger=False)
dm = MyDataModule(batch_size=16)
trainer.fit(model, datamodule=dm)