{ "cells": [ { "cell_type": "code", "execution_count": 7, "id": "bd8e3b95", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jupyterthemes.stylefx import set_nb_theme\n", "set_nb_theme('chesterish')" ] }, { "cell_type": "code", "execution_count": 2, "id": "8c2a24cb", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "f45eb6b0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "2023-02-21 15:40:52.888700: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", "2023-02-21 15:40:53.473104: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n", "2023-02-21 15:40:53.473149: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n", "2023-02-21 15:40:53.473154: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" ] } ], "source": [ "import numpy as np\n", "import torch\n", "\n", "import datasets \n", "\n", "from datasets import load_dataset, load_metric\n", "\n", "from transformers import (\n", " AutoModel,\n", " AutoModelForMaskedLM,\n", " AutoModelForSeq2SeqLM,\n", " AutoModelForTokenClassification,\n", " AutoTokenizer,\n", " DataCollatorForSeq2Seq,\n", " pipeline,\n", " Seq2SeqTrainingArguments,\n", " Seq2SeqTrainer,\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "id": "7fc4eb40", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", "- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.\n", "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", " warnings.warn(\n" ] } ], "source": [ "# Load the pre-trained model and tokenizer\n", "model_name = \"t5-small\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", "model = AutoModelForSeq2SeqLM.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": 5, "id": "363045f5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n", "Found cached dataset cnn_dailymail (/home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)\n", "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1122/1122 [02:06<00:00, 8.88ba/s]\n", "Loading cached processed dataset at /home/mrbean/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de/cache-2d3b7edd75fb1188.arrow\n" ] } ], "source": [ "def preprocess_function(batch):\n", " inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n", " outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n", " batch[\"input_ids\"] = inputs.input_ids\n", " batch[\"attention_mask\"] = inputs.attention_mask\n", " batch[\"labels\"] = outputs.input_ids.copy()\n", " return batch\n", "\n", "# Load the dataset\n", "train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n", "val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n", "\n", "train_ds = train_data.map(\n", " preprocess_function, \n", " batched=True, \n", " batch_size=256, \n", " remove_columns=[\"article\", \"highlights\", \"id\"]\n", ")\n", "\n", "val_ds = val_data.map(\n", " preprocess_function, \n", " batched=True, \n", " batch_size=256, \n", " remove_columns=[\"article\", \"highlights\", \"id\"]\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "id": "6faa8c86", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_478601/1088570042.py:23: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n", " metric = load_metric(\"rouge\")\n", "max_steps is given, it will override any value given in num_train_epochs\n", "Using cuda_amp half precision backend\n", "The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: id, article, highlights. If id, article, highlights are not expected by `T5ForConditionalGeneration.forward`, you can safely ignore this message.\n", "/home/mrbean/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 0\n", " Num Epochs = 1\n", " Instantaneous batch size per device = 16\n", " Total train batch size (w. parallel, distributed & accumulation) = 16\n", " Gradient Accumulation steps = 1\n", " Total optimization steps = 5000\n", " Number of trainable parameters = 60506624\n" ] }, { "ename": "IndexError", "evalue": "Invalid key: 90427 is out of bounds for size 0", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[6], line 47\u001b[0m\n\u001b[1;32m 36\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Seq2SeqTrainer(\n\u001b[1;32m 37\u001b[0m model\u001b[38;5;241m=\u001b[39mmodel,\n\u001b[1;32m 38\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 43\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics,\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;66;03m# Start the training\u001b[39;00m\n\u001b[0;32m---> 47\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel_wrapped \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel\n\u001b[1;32m 1536\u001b[0m inner_training_loop \u001b[38;5;241m=\u001b[39m find_executable_batch_size(\n\u001b[1;32m 1537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_inner_training_loop, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_train_batch_size, args\u001b[38;5;241m.\u001b[39mauto_find_batch_size\n\u001b[1;32m 1538\u001b[0m )\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/transformers/trainer.py:1761\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1758\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_load_rng_state(resume_from_checkpoint)\n\u001b[1;32m 1760\u001b[0m step \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m\n\u001b[0;32m-> 1761\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step, inputs \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(epoch_iterator):\n\u001b[1;32m 1762\u001b[0m \n\u001b[1;32m 1763\u001b[0m \u001b[38;5;66;03m# Skip past any already trained steps if resuming training\u001b[39;00m\n\u001b[1;32m 1764\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m steps_trained_in_current_epoch \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 1765\u001b[0m steps_trained_in_current_epoch \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:628\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 625\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 626\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 627\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 628\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 629\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 631\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/dataloader.py:671\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 669\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 670\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 671\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 672\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 673\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 56\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[idx] \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:58\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 56\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 58\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 59\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 60\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/arrow_dataset.py:2601\u001b[0m, in \u001b[0;36mDataset.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 2599\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, key): \u001b[38;5;66;03m# noqa: F811\u001b[39;00m\n\u001b[1;32m 2600\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools).\"\"\"\u001b[39;00m\n\u001b[0;32m-> 2601\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_getitem\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2602\u001b[0m \u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2603\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/arrow_dataset.py:2585\u001b[0m, in \u001b[0;36mDataset._getitem\u001b[0;34m(self, key, **kwargs)\u001b[0m\n\u001b[1;32m 2583\u001b[0m format_kwargs \u001b[38;5;241m=\u001b[39m format_kwargs \u001b[38;5;28;01mif\u001b[39;00m format_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m {}\n\u001b[1;32m 2584\u001b[0m formatter \u001b[38;5;241m=\u001b[39m get_formatter(format_type, features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfeatures, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mformat_kwargs)\n\u001b[0;32m-> 2585\u001b[0m pa_subtable \u001b[38;5;241m=\u001b[39m \u001b[43mquery_table\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_indices\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mis\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m 2586\u001b[0m formatted_output \u001b[38;5;241m=\u001b[39m format_table(\n\u001b[1;32m 2587\u001b[0m pa_subtable, key, formatter\u001b[38;5;241m=\u001b[39mformatter, format_columns\u001b[38;5;241m=\u001b[39mformat_columns, output_all_columns\u001b[38;5;241m=\u001b[39moutput_all_columns\n\u001b[1;32m 2588\u001b[0m )\n\u001b[1;32m 2589\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m formatted_output\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/formatting/formatting.py:588\u001b[0m, in \u001b[0;36mquery_table\u001b[0;34m(table, key, indices)\u001b[0m\n\u001b[1;32m 586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 587\u001b[0m size \u001b[38;5;241m=\u001b[39m indices\u001b[38;5;241m.\u001b[39mnum_rows \u001b[38;5;28;01mif\u001b[39;00m indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m table\u001b[38;5;241m.\u001b[39mnum_rows\n\u001b[0;32m--> 588\u001b[0m \u001b[43m_check_valid_index_key\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msize\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 589\u001b[0m \u001b[38;5;66;03m# Query the main table\u001b[39;00m\n\u001b[1;32m 590\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m indices \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", "File \u001b[0;32m~/.conda/envs/whisper_lightning/lib/python3.10/site-packages/datasets/formatting/formatting.py:531\u001b[0m, in \u001b[0;36m_check_valid_index_key\u001b[0;34m(key, size)\u001b[0m\n\u001b[1;32m 529\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mint\u001b[39m):\n\u001b[1;32m 530\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (key \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m key \u001b[38;5;241m+\u001b[39m size \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (key \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m size):\n\u001b[0;32m--> 531\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mIndexError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInvalid key: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkey\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m is out of bounds for size \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msize\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 532\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m\n\u001b[1;32m 533\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, \u001b[38;5;28mslice\u001b[39m):\n", "\u001b[0;31mIndexError\u001b[0m: Invalid key: 90427 is out of bounds for size 0" ] } ], "source": [ "class MyLightningModule(pl.LightningModule):\n", " def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n", " super().__init__()\n", " self.model_name = model_name\n", " self.learning_rate = learning_rate\n", " self.weight_decay = weight_decay\n", " self.batch_size = batch_size\n", " self.num_training_steps = num_training_steps\n", " \n", " # Load the pre-trained model and tokenizer\n", " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n", " self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n", "\n", " def forward(self, input_ids, attention_mask, labels=None):\n", " output = self.model(\n", " input_ids=input_ids,\n", " attention_mask=attention_mask,\n", " labels=labels,\n", " )\n", " return output.loss, output.logits\n", " \n", " def training_step(self, batch, batch_idx):\n", " input_ids = batch[\"input_ids\"]\n", " attention_mask = batch[\"attention_mask\"]\n", " labels = batch[\"labels\"]\n", " \n", " loss\n", "\n", "# Define the data collator\n", "data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n", "\n", "# Initialize the trainer arguments\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./results\",\n", " learning_rate=1e-5,\n", " per_device_train_batch_size=16,\n", " per_device_eval_batch_size=16,\n", " max_steps=5000,\n", " weight_decay=1e-4,\n", " push_to_hub=False,\n", " evaluation_strategy = \"steps\",\n", " eval_steps = 50,\n", " generation_max_length=128,\n", " predict_with_generate=True,\n", " logging_steps=100,\n", " gradient_accumulation_steps=1,\n", " fp16=True,\n", ")\n", "\n", "# Load the ROUGE metric\n", "metric = load_metric(\"rouge\")\n", "\n", "# Define the evaluation function\n", "def compute_metrics(pred):\n", " labels = pred.label_ids\n", " preds = pred.predictions\n", " decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n", " decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n", " scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n", " return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n", "\n", "\n", "# Initialize the trainer\n", "trainer = Seq2SeqTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_data,\n", " eval_dataset=val_data,\n", " data_collator=data_collator,\n", " tokenizer=tokenizer,\n", " compute_metrics=compute_metrics,\n", ")\n", "\n", "# Start the training\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "1b0f9a76", "metadata": {}, "source": [ "# Steps:\n", "1. Rewrite code to be more general\n", "\n", "a) Data loading should be from disk rather than their load_dataset, and should be on the fly\n", "\n", "b) Rewrite to Lightning code, Trainer etc using Lightning, compute metric fine that we use huggingface" ] }, { "cell_type": "code", "execution_count": null, "id": "ff03c8bb", "metadata": {}, "outputs": [], "source": [ "!nvidia-smi" ] }, { "cell_type": "code", "execution_count": null, "id": "aafc4b27", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }