From 3cb5a52a1b1b60580765437191044628d6374dcb Mon Sep 17 00:00:00 2001 From: rasbt Date: Tue, 26 Mar 2024 20:34:50 -0500 Subject: [PATCH] simplify calc_loss_loader --- .../01_main-chapter-code/previous_chapters.py | 5 +- ch05/01_main-chapter-code/ch05.ipynb | 76 +++++-------------- ch05/01_main-chapter-code/train.py | 5 +- .../previous_chapters.py | 5 +- ch05/05_bonus_hparam_tuning/hparam_search.py | 13 ++-- 5 files changed, 33 insertions(+), 71 deletions(-) diff --git a/appendix-D/01_main-chapter-code/previous_chapters.py b/appendix-D/01_main-chapter-code/previous_chapters.py index 14d3660..559c6b9 100644 --- a/appendix-D/01_main-chapter-code/previous_chapters.py +++ b/appendix-D/01_main-chapter-code/previous_chapters.py @@ -256,17 +256,16 @@ def calc_loss_batch(input_batch, target_batch, model, device): def calc_loss_loader(data_loader, model, device, num_batches=None): - total_loss, batches_seen = 0., 0. + total_loss = 0. if num_batches is None: num_batches = len(data_loader) for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: loss = calc_loss_batch(input_batch, target_batch, model, device) total_loss += loss.item() - batches_seen += 1 else: break - return total_loss / batches_seen + return total_loss / num_batches def evaluate_model(model, train_loader, val_loader, device, eval_iter): diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb index 0962766..a916ae8 100644 --- a/ch05/01_main-chapter-code/ch05.ipynb +++ b/ch05/01_main-chapter-code/ch05.ipynb @@ -764,7 +764,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 23, "id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff", "metadata": {}, "outputs": [], @@ -785,43 +785,6 @@ " text_data = file.read()" ] }, - { - "cell_type": "code", - "execution_count": 18, - "id": "0959c855-f860-4358-8b98-bc654f047578", - "metadata": {}, - "outputs": [], - "source": [ - "from previous_chapters import create_dataloader_v1\n", - "\n", - "# Train/validation ratio\n", - "train_ratio = 0.90\n", - "split_idx = int(train_ratio * len(text_data))\n", - "train_data = text_data[:split_idx]\n", - "val_data = text_data[split_idx:]\n", - "\n", - "\n", - "torch.manual_seed(123)\n", - "\n", - "train_loader = create_dataloader_v1(\n", - " train_data,\n", - " batch_size=2,\n", - " max_length=GPT_CONFIG_124M[\"ctx_len\"],\n", - " stride=GPT_CONFIG_124M[\"ctx_len\"],\n", - " drop_last=True,\n", - " shuffle=True\n", - ")\n", - "\n", - "val_loader = create_dataloader_v1(\n", - " val_data,\n", - " batch_size=2,\n", - " max_length=GPT_CONFIG_124M[\"ctx_len\"],\n", - " stride=GPT_CONFIG_124M[\"ctx_len\"],\n", - " drop_last=False,\n", - " shuffle=False\n", - ")" - ] - }, { "cell_type": "markdown", "id": "379330f1-80f4-4e34-8724-41d892b04cee", @@ -832,7 +795,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 24, "id": "6kgJbe4ehI4q", "metadata": { "colab": { @@ -858,7 +821,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 25, "id": "j2XPde_ThM_e", "metadata": { "colab": { @@ -884,7 +847,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 27, "id": "6b46a952-d50a-4837-af09-4095698f7fd1", "metadata": { "colab": { @@ -940,22 +903,24 @@ }, { "cell_type": "code", - "execution_count": 22, - "id": "fd0e963b-d282-4c97-b004-8772f4b1bd8f", + "execution_count": 35, + "id": "0959c855-f860-4358-8b98-bc654f047578", "metadata": {}, "outputs": [], "source": [ "from previous_chapters import create_dataloader_v1\n", "\n", - "\n", "# Train/validation ratio\n", "train_ratio = 0.90\n", "split_idx = int(train_ratio * len(text_data))\n", + "train_data = text_data[:split_idx]\n", + "val_data = text_data[split_idx:]\n", + "\n", "\n", "torch.manual_seed(123)\n", "\n", "train_loader = create_dataloader_v1(\n", - " text_data[:split_idx],\n", + " train_data,\n", " batch_size=2,\n", " max_length=GPT_CONFIG_124M[\"ctx_len\"],\n", " stride=GPT_CONFIG_124M[\"ctx_len\"],\n", @@ -964,7 +929,7 @@ ")\n", "\n", "val_loader = create_dataloader_v1(\n", - " text_data[split_idx:],\n", + " val_data,\n", " batch_size=2,\n", " max_length=GPT_CONFIG_124M[\"ctx_len\"],\n", " stride=GPT_CONFIG_124M[\"ctx_len\"],\n", @@ -975,7 +940,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 36, "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e", "metadata": {}, "outputs": [], @@ -1012,7 +977,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 37, "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e", "metadata": {}, "outputs": [ @@ -1056,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 38, "id": "eb860488-5453-41d7-9870-23b723f742a0", "metadata": { "colab": { @@ -1101,7 +1066,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 49, "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc", "metadata": { "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc" @@ -1118,17 +1083,16 @@ "\n", "\n", "def calc_loss_loader(data_loader, model, device, num_batches=None):\n", - " total_loss, batches_seen = 0., 0.\n", + " total_loss = 0.\n", " if num_batches is None:\n", " num_batches = len(data_loader)\n", " for i, (input_batch, target_batch) in enumerate(data_loader):\n", " if i < num_batches:\n", " loss = calc_loss_batch(input_batch, target_batch, model, device)\n", " total_loss += loss.item()\n", - " batches_seen += 1\n", " else:\n", " break\n", - " return total_loss / batches_seen" + " return total_loss / num_batches" ] }, { @@ -1142,7 +1106,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 55, "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a", "metadata": {}, "outputs": [ @@ -1150,7 +1114,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Training loss: 11.002305030822754\n", + "Training loss: 10.98758347829183\n", "Validation loss: 10.98110580444336\n" ] } @@ -1159,6 +1123,8 @@ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n", "\n", + "\n", + "torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n", "train_loss = calc_loss_loader(train_loader, model, device, num_batches=1)\n", "val_loss = calc_loss_loader(val_loader, model, device, num_batches=1)\n", "\n", diff --git a/ch05/01_main-chapter-code/train.py b/ch05/01_main-chapter-code/train.py index 1ecc8d5..fec04be 100644 --- a/ch05/01_main-chapter-code/train.py +++ b/ch05/01_main-chapter-code/train.py @@ -31,17 +31,16 @@ def calc_loss_batch(input_batch, target_batch, model, device): def calc_loss_loader(data_loader, model, device, num_batches=None): - total_loss, batches_seen = 0., 0. + total_loss = 0. if num_batches is None: num_batches = len(data_loader) for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: loss = calc_loss_batch(input_batch, target_batch, model, device) total_loss += loss.item() - batches_seen += 1 else: break - return total_loss / batches_seen + return total_loss / num_batches def evaluate_model(model, train_loader, val_loader, device, eval_iter): diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py index cf3d31b..f7b9578 100644 --- a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py +++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py @@ -250,17 +250,16 @@ def calc_loss_batch(input_batch, target_batch, model, device): def calc_loss_loader(data_loader, model, device, num_batches=None): - total_loss, batches_seen = 0., 0. + total_loss = 0. if num_batches is None: num_batches = len(data_loader) for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: loss = calc_loss_batch(input_batch, target_batch, model, device) total_loss += loss.item() - batches_seen += 1 else: break - return total_loss / batches_seen + return total_loss / num_batches def evaluate_model(model, train_loader, val_loader, device, eval_iter): diff --git a/ch05/05_bonus_hparam_tuning/hparam_search.py b/ch05/05_bonus_hparam_tuning/hparam_search.py index 569d7e2..eea44ac 100644 --- a/ch05/05_bonus_hparam_tuning/hparam_search.py +++ b/ch05/05_bonus_hparam_tuning/hparam_search.py @@ -23,18 +23,17 @@ HPARAM_GRID = { } -def calc_loss_loader(data_loader, model, device, num_iters=None): - total_loss, num_batches = 0., 0 - if num_iters is None: - num_iters = len(data_loader) +def calc_loss_loader(data_loader, model, device, num_batches=None): + total_loss = 0. + if num_batches is None: + num_batches = len(data_loader) for i, (input_batch, target_batch) in enumerate(data_loader): - if i < num_iters: + if i < num_batches: loss = calc_loss_batch(input_batch, target_batch, model, device) total_loss += loss.item() - num_batches += 1 else: break - return total_loss + return total_loss / num_batches def calc_loss_batch(input_batch, target_batch, model, device):