mirror of
https://github.com/aladdinpersson/Machine-Learning-Collection.git
synced 2026-02-20 13:50:41 +00:00
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "5372055b",
"metadata": {},
"outputs": [],
"source": [
"from jupyterthemes.stylefx import set_nb_theme\n",
"set_nb_theme('chesterish')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11214a4a",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_DEVICE_ORDER\"]=\"PCI_BUS_ID\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f45eb6b0",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"import datasets\n",
"\n",
"from datasets import load_dataset, load_metric\n",
"\n",
"# Needed for the LightningModule defined further down\n",
"import pytorch_lightning as pl\n",
"\n",
"from transformers import (\n",
"    AutoModel,\n",
"    AutoModelForMaskedLM,\n",
"    AutoModelForSeq2SeqLM,\n",
"    AutoModelForTokenClassification,\n",
"    AutoTokenizer,\n",
"    DataCollatorForSeq2Seq,\n",
"    pipeline,\n",
"    Seq2SeqTrainingArguments,\n",
"    Seq2SeqTrainer,\n",
")"
]
},
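{
"cell_type": "code",
"execution_count": null,
"id": "a1c9e3f0",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: confirm that PyTorch actually sees the single GPU selected\n",
"# via CUDA_VISIBLE_DEVICES above. Purely informational.\n",
"print(torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
"    print(torch.cuda.get_device_name(0))"
]
},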
{
"cell_type": "code",
"execution_count": null,
"id": "b2d26af4",
"metadata": {},
"outputs": [],
"source": [
"# Load the pre-trained model and tokenizer\n",
"model_name = \"t5-small\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name)"
]
},
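{
"cell_type": "code",
"execution_count": null,
"id": "c4b7d9e2",
"metadata": {},
"outputs": [],
"source": [
"# Optional smoke test: summarize a toy sentence with the freshly loaded t5-small\n",
"# checkpoint. T5 expects a task prefix such as \"summarize: \"; the text below is just\n",
"# an illustrative example, not part of the dataset.\n",
"sample = \"summarize: The tower is 324 metres tall, about the same height as an 81-storey building, and is the tallest structure in Paris.\"\n",
"sample_inputs = tokenizer(sample, return_tensors=\"pt\", truncation=True, max_length=512)\n",
"summary_ids = model.generate(**sample_inputs, max_length=64, num_beams=4)\n",
"print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))"
]
},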
{
"cell_type": "code",
"execution_count": null,
"id": "363045f5",
"metadata": {},
"outputs": [],
"source": [
"def preprocess_function(batch):\n",
"    # Tokenize articles (encoder inputs) and highlights (decoder targets)\n",
"    inputs = tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=512)\n",
"    outputs = tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=128)\n",
"    batch[\"input_ids\"] = inputs.input_ids\n",
"    batch[\"attention_mask\"] = inputs.attention_mask\n",
"    # Replace padding token ids in the labels with -100 so they are ignored by the loss\n",
"    batch[\"labels\"] = [\n",
"        [(token if token != tokenizer.pad_token_id else -100) for token in seq]\n",
"        for seq in outputs.input_ids\n",
"    ]\n",
"    return batch\n",
"\n",
"# Load the dataset\n",
"train_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
"val_data = load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:10%]\")\n",
"\n",
"train_ds = train_data.map(\n",
"    preprocess_function,\n",
"    batched=True,\n",
"    batch_size=256,\n",
"    remove_columns=[\"article\", \"highlights\", \"id\"]\n",
")\n",
"\n",
"val_ds = val_data.map(\n",
"    preprocess_function,\n",
"    batched=True,\n",
"    batch_size=256,\n",
"    remove_columns=[\"article\", \"highlights\", \"id\"]\n",
")"
]
},
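{
"cell_type": "code",
"execution_count": null,
"id": "e8a1b5c7",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check on the preprocessing above: look at one tokenized training\n",
"# example and decode its input ids back to text (labels are left encoded because the\n",
"# -100 positions cannot be decoded).\n",
"example = train_ds[0]\n",
"print(len(example[\"input_ids\"]), len(example[\"attention_mask\"]), len(example[\"labels\"]))\n",
"print(tokenizer.decode(example[\"input_ids\"], skip_special_tokens=True)[:200])"
]
},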
{
"cell_type": "code",
"execution_count": null,
"id": "0d58818f",
"metadata": {},
"outputs": [],
"source": [
"class MyLightningModule(pl.LightningModule):\n",
"    # Work-in-progress Lightning wrapper; the Hugging Face Seq2SeqTrainer below is\n",
"    # what this notebook actually uses for training (see the rewrite plan further down).\n",
"    def __init__(self, model_name, learning_rate, weight_decay, batch_size, num_training_steps):\n",
"        super().__init__()\n",
"        self.model_name = model_name\n",
"        self.learning_rate = learning_rate\n",
"        self.weight_decay = weight_decay\n",
"        self.batch_size = batch_size\n",
"        self.num_training_steps = num_training_steps\n",
"\n",
"        # Load the pre-trained model and tokenizer\n",
"        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n",
"        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)\n",
"\n",
"    def forward(self, input_ids, attention_mask, labels=None):\n",
"        output = self.model(\n",
"            input_ids=input_ids,\n",
"            attention_mask=attention_mask,\n",
"            labels=labels,\n",
"        )\n",
"        return output.loss, output.logits\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        input_ids = batch[\"input_ids\"]\n",
"        attention_mask = batch[\"attention_mask\"]\n",
"        labels = batch[\"labels\"]\n",
"\n",
"        loss, _ = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels)\n",
"        self.log(\"train_loss\", loss)\n",
"        return loss\n",
"\n",
"    def configure_optimizers(self):\n",
"        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
"\n",
"\n",
"# Define the data collator\n",
"data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)\n",
"\n",
"# Initialize the trainer arguments\n",
"training_args = Seq2SeqTrainingArguments(\n",
"    output_dir=\"./results\",\n",
"    learning_rate=1e-5,\n",
"    per_device_train_batch_size=16,\n",
"    per_device_eval_batch_size=16,\n",
"    max_steps=5000,\n",
"    weight_decay=1e-4,\n",
"    push_to_hub=False,\n",
"    evaluation_strategy=\"steps\",\n",
"    eval_steps=50,\n",
"    generation_max_length=128,\n",
"    predict_with_generate=True,\n",
"    logging_steps=100,\n",
"    gradient_accumulation_steps=1,\n",
"    fp16=True,\n",
")\n",
"\n",
"# Load the ROUGE metric\n",
"metric = load_metric(\"rouge\")\n",
"\n",
"# Define the evaluation function\n",
"def compute_metrics(pred):\n",
"    labels = pred.label_ids\n",
"    preds = pred.predictions\n",
"    # Put the pad token back where labels were masked with -100 before decoding\n",
"    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
"    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n",
"    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
"    scores = metric.compute(predictions=decoded_preds, references=decoded_labels, rouge_types=[\"rouge1\"])[\"rouge1\"].mid\n",
"    return {\"rouge1_precision\": scores.precision, \"rouge1_recall\": scores.recall, \"rouge1_fmeasure\": scores.fmeasure}\n",
"\n",
"\n",
"# Initialize the trainer on the tokenized datasets\n",
"trainer = Seq2SeqTrainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_ds,\n",
"    eval_dataset=val_ds,\n",
"    data_collator=data_collator,\n",
"    tokenizer=tokenizer,\n",
"    compute_metrics=compute_metrics,\n",
")\n",
"\n",
"# Start the training\n",
"trainer.train()"
]
},
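{
"cell_type": "code",
"execution_count": null,
"id": "f2d8a6b9",
"metadata": {},
"outputs": [],
"source": [
"# Possible follow-up after training: run a final evaluation pass, save the fine-tuned\n",
"# checkpoint, and try it on one validation article via the summarization pipeline.\n",
"# The output directory \"./results/final\" is just an illustrative choice.\n",
"print(trainer.evaluate())\n",
"trainer.save_model(\"./results/final\")\n",
"\n",
"summarizer = pipeline(\"summarization\", model=\"./results/final\", tokenizer=tokenizer)\n",
"print(summarizer(val_data[0][\"article\"][:2000], max_length=128)[0][\"summary_text\"])"
]
},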
{
"cell_type": "markdown",
"id": "5148159b",
"metadata": {},
"source": [
"# Steps:\n",
"1. Rewrite the code to be more general\n",
"\n",
"a) Data loading should read from disk rather than going through Hugging Face's `load_dataset`, and tokenization should happen on the fly\n",
"\n",
"b) Rewrite to PyTorch Lightning (LightningModule, Trainer, etc.); for computing the metrics it is fine to keep using the Hugging Face implementation (a minimal sketch of both points follows in the next cell)"
]
},
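{
"cell_type": "code",
"execution_count": null,
"id": "9b4e7c21",
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the two steps above, not a finished implementation. It assumes the\n",
"# raw training split has been saved once with train_data.save_to_disk(\"./cnn_dm_train\")\n",
"# (the path is hypothetical) and reuses the MyLightningModule class defined earlier.\n",
"from torch.utils.data import DataLoader\n",
"from datasets import load_from_disk\n",
"\n",
"raw_train = load_from_disk(\"./cnn_dm_train\")\n",
"\n",
"def collate_on_the_fly(examples):\n",
"    # Tokenize a batch of raw examples at load time instead of pre-tokenizing the whole dataset\n",
"    inputs = tokenizer([ex[\"article\"] for ex in examples], padding=\"max_length\", truncation=True, max_length=512, return_tensors=\"pt\")\n",
"    targets = tokenizer([ex[\"highlights\"] for ex in examples], padding=\"max_length\", truncation=True, max_length=128, return_tensors=\"pt\")\n",
"    labels = targets.input_ids.masked_fill(targets.input_ids == tokenizer.pad_token_id, -100)\n",
"    return {\"input_ids\": inputs.input_ids, \"attention_mask\": inputs.attention_mask, \"labels\": labels}\n",
"\n",
"train_loader = DataLoader(raw_train, batch_size=16, shuffle=True, collate_fn=collate_on_the_fly)\n",
"\n",
"module = MyLightningModule(\"t5-small\", learning_rate=1e-5, weight_decay=1e-4, batch_size=16, num_training_steps=5000)\n",
"pl_trainer = pl.Trainer(max_steps=5000, accelerator=\"gpu\", devices=1)\n",
"pl_trainer.fit(module, train_dataloaders=train_loader)"
]
},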
{
"cell_type": "code",
"execution_count": null,
"id": "95e33e40",
"metadata": {},
"outputs": [],
"source": [
"!nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c0348c2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}