{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "5372055b",
"metadata": {},
"outputs": [],
"source": [
"from jupyterthemes.stylefx import set_nb_theme\n",
"set_nb_theme('chesterish')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11214a4a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
]
},
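{
"cell_type": "markdown",
"id": "3f9a1c2e",
"metadata": {},
"source": [
"A quick sanity check, added as a sketch (not in the original notebook): confirm that PyTorch actually sees the GPU selected via `CUDA_VISIBLE_DEVICES` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d4b8e2a",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: verify the GPU restricted to above is visible to PyTorch\n",
"import torch\n",
"\n",
"print(torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
"    print(torch.cuda.get_device_name(0))"
]
},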
{
"cell_type": "code",
"execution_count": null,
"id": "f45eb6b0",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"import datasets \n",
"\n",
"from datasets import load_dataset, load_metric\n",
"\n",
"from transformers import (\n",
" AutoModel,\n",
" AutoModelForMaskedLM,\n",
" AutoModelForSeq2SeqLM,\n",
" AutoModelForTokenClassification,\n",
" AutoTokenizer,\n",
" DataCollatorForSeq2Seq,\n",
" pipeline,\n",
" Seq2SeqTrainingArguments,\n",
" Seq2SeqTrainer,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2d26af4",
"metadata": {},
"outputs": [],
"source": [
"# Load the pre-trained model and tokenizer\n",
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
]
},
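{
"cell_type": "markdown",
"id": "8c1d2e3f",
"metadata": {},
"source": [
"Before fine-tuning, a small smoke test of the pre-trained checkpoint using the `pipeline` helper imported above. This is a sketch for illustration; the sample text is just a placeholder."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9e2f3a4b",
"metadata": {},
"outputs": [],
"source": [
"# Smoke test (sketch): summarize a placeholder text with the raw t5-small\n",
"summarizer = pipeline(\"summarization\", model=model, tokenizer=tokenizer)\n",
"\n",
"sample_text = (\n",
"    \"The tower is 324 metres tall, about the same height as an 81-storey \"\n",
"    \"building, and the tallest structure in Paris.\"\n",
")\n",
"print(summarizer(sample_text, max_length=40, min_length=5)[0][\"summary_text\"])"
]
},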
{
"cell_type": "code",
"execution_count": null,
"id": "363045f5",
"metadata": {},
"outputs": [],
"source": [
"def preprocess_function(batch):\n",
" inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
" outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
" batch[\"input_ids\"] = inputs.input_ids\n",
" batch[\"attention_mask\"] = inputs.attention_mask\n",
" batch[\"labels\"] = outputs.input_ids.copy()\n",
" return batch\n",
"\n",
"# Load the dataset\n",
"train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
"val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
"train_ds = train_data.map(\n",
" preprocess_function, \n",
" batched=True, \n",
" batch_size=256, \n",
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
")\n",
"\n",
"val_ds = val_data.map(\n",
" preprocess_function, \n",
" batched=True, \n",
" batch_size=256, \n",
" remove_columns=[\"article\", \"highlights\", \"id\"]\n",
")"
]
},
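{
"cell_type": "markdown",
"id": "b4c5d6e7",
"metadata": {},
"source": [
"A quick look at one preprocessed example (a sketch added for illustration): decoding the ids back to text is a cheap way to verify the tokenization."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c5d6e7f8",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: decode one preprocessed example back to text as a sanity check\n",
"example = train_ds[0]\n",
"print(tokenizer.decode(example[\"input_ids\"], skip_special_tokens=True)[:300])\n",
"print(tokenizer.decode(example[\"labels\"], skip_special_tokens=True))"
]
},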
{
"cell_type": "code",
"execution_count": null,
"id": "0d58818f",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
" def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
" super().__init__()\n",
" self.model_name = model_name\n",
" self.learning_rate = learning_rate\n",
" self.weight_decay = weight_decay\n",
" self.batch_size = batch_size\n",
" self.num_training_steps = num_training_steps\n",
" \n",
" # Load the pre-trained model and tokenizer\n",
" self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
" self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
"\n",
" def forward(self, input_ids, attention_mask, labels=None):\n",
" output = self.model(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" labels=labels,\n",
" )\n",
" return output.loss, output.logits\n",
" \n",
" def training_step(self, batch, batch_idx):\n",
" input_ids = batch[\"input_ids\"]\n",
" attention_mask = batch[\"attention_mask\"]\n",
" labels = batch[\"labels\"]\n",
" \n",
" loss\n",
"\n",
"# Define the data collator\n",
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
"\n",
"# Initialize the trainer arguments\n",
"training_args = Seq2SeqTrainingArguments(\n",
" output_dir=\"./results\",\n",
" learning_rate=1e-5,\n",
" per_device_train_batch_size=16,\n",
" per_device_eval_batch_size=16,\n",
" max_steps=5000,\n",
" weight_decay=1e-4,\n",
" push_to_hub=False,\n",
" evaluation_strategy = \"steps\",\n",
" eval_steps = 50,\n",
" generation_max_length=128,\n",
" predict_with_generate=True,\n",
" logging_steps=100,\n",
" gradient_accumulation_steps=1,\n",
" fp16=True,\n",
")\n",
"\n",
"# Load the ROUGE metric\n",
"metric = load_metric(\"rouge\")\n",
"\n",
"# Define the evaluation function\n",
"def compute_metrics(pred):\n",
" labels = pred.label_ids\n",
" preds = pred.predictions\n",
" decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
" decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
" scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
" return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
"\n",
"\n",
"# Initialize the trainer\n",
"trainer = Seq2SeqTrainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_data,\n",
" eval_dataset=val_data,\n",
" data_collator=data_collator,\n",
" tokenizer=tokenizer,\n",
" compute_metrics=compute_metrics,\n",
")\n",
"\n",
"# Start the training\n",
"trainer.train()"
]
},
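{
"cell_type": "markdown",
"id": "d6e7f8a9",
"metadata": {},
"source": [
"Once training finishes, a hedged example of generating a summary with the fine-tuned model on one validation article (the generation settings here are illustrative, not tuned)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7f8a9b0",
"metadata": {},
"outputs": [],
"source": [
"# Sketch: summarize one validation article with the fine-tuned model\n",
"article = val_data[0][\"article\"]\n",
"inputs = tokenizer(article, return_tensors=\"pt\", truncation=True, max_length=512).to(model.device)\n",
"summary_ids = model.generate(**inputs, max_length=128, num_beams=4)\n",
"print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))"
]
},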
{
"cell_type": "markdown",
"id": "5148159b",
"metadata": {},
"source": [
"# Steps:\n",
"1. Rewrite code to be more general\n",
"\n",
"a) Data loading should be from disk rather than their load_dataset, and should be on the fly\n",
"\n",
"b) Rewrite to Lightning code, Trainer etc using Lightning, compute metric fine that we use huggingface"
]
},
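{
"cell_type": "markdown",
"id": "f8a9b0c1",
"metadata": {},
"source": [
"A minimal sketch of step 1a (the path is hypothetical): save the raw dataset to disk once with `save_to_disk`, reload it with `load_from_disk`, and use `set_transform` so tokenization happens lazily per accessed batch instead of in an up-front `map`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a9b0c1d2",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of step 1a; assumes the dataset was saved once with\n",
"# train_data.save_to_disk(\"data/cnn_dailymail_train\") (hypothetical path)\n",
"from datasets import load_from_disk\n",
"\n",
"train_from_disk = load_from_disk(\"data/cnn_dailymail_train\")\n",
"\n",
"def tokenize_on_the_fly(batch):\n",
"    # Same tokenization as preprocess_function above, applied lazily\n",
"    inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
"    outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
"    return {\n",
"        \"input_ids\": inputs.input_ids,\n",
"        \"attention_mask\": inputs.attention_mask,\n",
"        \"labels\": outputs.input_ids,\n",
"    }\n",
"\n",
"# set_transform runs the function on each accessed batch, on the fly\n",
"train_from_disk.set_transform(tokenize_on_the_fly)"
]
},
{
"cell_type": "markdown",
"id": "b0c1d2e3",
"metadata": {},
"source": [
"And a sketch of step 1b, assuming the `MyLightningModule` completed above: put the tokenized dataset into a `DataLoader` and train with Lightning's `Trainer`. The hyperparameters mirror the `Seq2SeqTrainingArguments`; the module loads its own copy of the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1d2e3f4",
"metadata": {},
"outputs": [],
"source": [
"# Sketch of step 1b: train MyLightningModule with Lightning's Trainer.\n",
"# Hyperparameters mirror the Seq2SeqTrainingArguments above.\n",
"from torch.utils.data import DataLoader\n",
"\n",
"train_ds.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"labels\"])\n",
"\n",
"module = MyLightningModule(\n",
"    model_name=\"t5-small\",\n",
"    learning_rate=1e-5,\n",
"    weight_decay=1e-4,\n",
"    batch_size=16,\n",
"    num_training_steps=5000,\n",
")\n",
"\n",
"train_loader = DataLoader(train_ds, batch_size=module.batch_size, shuffle=True)\n",
"\n",
"pl_trainer = pl.Trainer(max_steps=5000, precision=16, accelerator=\"gpu\", devices=1)\n",
"pl_trainer.fit(module, train_loader)"
]
},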
{
"cell_type": "code",
"execution_count": null,
"id": "95e33e40",
"metadata": {},
"outputs": [],
"source": [
"!nvidia-smi"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}