From de576296de828d159e3f6fb69bb596c9ed4f3a33 Mon Sep 17 00:00:00 2001 From: rasbt Date: Mon, 25 Mar 2024 08:09:31 -0500 Subject: [PATCH] simplify .view code --- .../01_main-chapter-code/previous_chapters.py | 4 +- ch05/01_main-chapter-code/ch05.ipynb | 118 +++++++++--------- ch05/01_main-chapter-code/train.py | 4 +- .../previous_chapters.py | 4 +- 4 files changed, 60 insertions(+), 70 deletions(-) diff --git a/appendix-D/01_main-chapter-code/previous_chapters.py b/appendix-D/01_main-chapter-code/previous_chapters.py index 719d791..14d3660 100644 --- a/appendix-D/01_main-chapter-code/previous_chapters.py +++ b/appendix-D/01_main-chapter-code/previous_chapters.py @@ -250,10 +250,9 @@ def generate_text_simple(model, idx, max_new_tokens, context_size): def calc_loss_batch(input_batch, target_batch, model, device): input_batch, target_batch = input_batch.to(device), target_batch.to(device) logits = model(input_batch) - logits = logits.view(-1, logits.size(-1)) - loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1)) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten()) return loss diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb index 745c225..fc6df9d 100644 --- a/ch05/01_main-chapter-code/ch05.ipynb +++ b/ch05/01_main-chapter-code/ch05.ipynb @@ -18,7 +18,7 @@ "id": "66dd524e-864c-4012-b0a2-ccfc56e80024" }, "source": [ - "# Chapter 5: Pretraining on Unlabeled Data\n" + "# Chapter 5: Pretraining on Unlabeled Data" ] }, { @@ -260,7 +260,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f", "metadata": { "colab": { @@ -290,7 +290,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "e7b6ec51-6f8c-49bd-a349-95ba38b46fb6", "metadata": {}, "outputs": [ @@ -345,7 +345,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id":
"34ebd76a-16ec-4c17-8958-8a135735cc1c", "metadata": { "colab": { @@ -385,23 +385,22 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "5e777e53-0a1d-4929-bcd5-b0ea2469bac6", + "execution_count": 7, + "id": "c990ead6-53cd-49a7-a6d1-14d8c1518249", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "' Armed heNetflix pressuring empoweredfaith'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "Targets batch 1: effort moves you\n", + "Outputs batch 1: Armed heNetflix\n" + ] } ], "source": [ - "print(token_ids_to_text(token_ids.flatten(), tokenizer))" + "print(f\"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}\")\n", + "print(f\"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}\")" ] }, { @@ -426,12 +425,12 @@ "id": "c7251bf5-a079-4782-901d-68c9225d3157", "metadata": {}, "source": [ - "- In the first input batch, the token probabilities corresponding to the target indices are as follows:" + "- The token probabilities corresponding to the target indices are as follows:" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 19, "id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a", "metadata": { "colab": { @@ -445,48 +444,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n" + "Batch 1: tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n", + "Batch 2: tensor([3.9836e-05, 1.6783e-05, 4.7559e-06])\n" ] } ], "source": [ "batch_idx = 0\n", "target_probas_1 = probas[batch_idx, [0, 1, 2], targets[batch_idx]]\n", - "print(target_probas_1)" - ] - }, - { - "cell_type": "markdown", - "id": "5305e2af-e973-4fc3-b717-a065a0f6ceec", - "metadata": {}, - "source": [ - "- And in the second input batch, the token probabilities corresponding to the target indices are as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": 
"ea8ee24d-04f9-45d4-bef1-088c83cb8e0d", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ea8ee24d-04f9-45d4-bef1-088c83cb8e0d", - "outputId": "5a439322-6e4f-49e9-b462-6d28b92415ab" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([3.9836e-05, 1.6783e-05, 4.7559e-06])\n" - ] - } - ], - "source": [ + "print(\"Batch 1:\", target_probas_1)\n", + "\n", "batch_idx = 1\n", "target_probas_2 = probas[1, [0, 1, 2], targets[1]]\n", - "print(target_probas_2)" + "print(\"Batch 2:\", target_probas_2)" ] }, { @@ -500,7 +470,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 22, "id": "31402a67-a16e-4aeb-977e-70abb9c9949b", "metadata": { "colab": { @@ -534,7 +504,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 23, "id": "9b003797-161b-4d98-81dc-e68320e09fec", "metadata": { "colab": { @@ -573,7 +543,33 @@ "metadata": {}, "source": [ "- In deep learning, instead of maximizing the average log-probability, it's a standard convention to minimize the *negative* average log-probability value; in our case, instead of maximizing -10.7722 so that it approaches 0, in deep learning, we would minimize 10.7722 so that it approaches 0\n", - "- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning\n", + "- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "176ddf35-1c5f-4d7c-bf17-70f3e7069bd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(10.7722)\n" + ] + } + ], + "source": [ + "neg_avg_log_probas = avg_log_probas * -1\n", + "print(neg_avg_log_probas)" + ] + }, + { + "cell_type": "markdown", + "id": "84eeb868-abd8-4028-82db-107546bf7c2c", + "metadata": {}, + "source": [ "- PyTorch already implements a `cross_entropy` function that carries out the 
previous steps" ] }, @@ -615,7 +611,7 @@ } ], "source": [ - "# Logits have shape (batch_size, num_tokens, vocab_size\n", + "# Logits have shape (batch_size, num_tokens, vocab_size)\n", "print(\"Logits shape:\", logits.shape)\n", "\n", "# Targets have shape (batch_size, num_tokens)\n", @@ -652,10 +648,10 @@ } ], "source": [ - "logits_flat = logits.view(-1, logits.shape[-1])\n", - "print(\"Flattened logits:\", logits_flat.shape)\n", + "logits_flat = logits.flatten(0, 1)\n", + "targets_flat = targets.flatten()\n", "\n", - "targets_flat = targets.view(-1)\n", + "print(\"Flattened logits:\", logits_flat.shape)\n", "print(\"Flattened targets:\", targets_flat.shape)" ] }, @@ -1114,8 +1110,8 @@ " input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", "\n", " logits = model(input_batch)\n", - " logits = logits.view(-1, logits.size(-1))\n", - " loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))\n", + " logits = logits.flatten(0, 1)\n", + " loss = torch.nn.functional.cross_entropy(logits, target_batch.flatten())\n", " return loss\n", "\n", "\n", @@ -2496,7 +2492,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/ch05/01_main-chapter-code/train.py b/ch05/01_main-chapter-code/train.py index 2bf7065..1ecc8d5 100644 --- a/ch05/01_main-chapter-code/train.py +++ b/ch05/01_main-chapter-code/train.py @@ -25,10 +25,9 @@ def token_ids_to_text(token_ids, tokenizer): def calc_loss_batch(input_batch, target_batch, model, device): input_batch, target_batch = input_batch.to(device), target_batch.to(device) logits = model(input_batch) - logits = logits.view(-1, logits.size(-1)) - loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1)) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten()) return loss diff --git a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py
b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py index 1887c8b..3fcf0d0 100644 --- a/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py +++ b/ch05/03_bonus_pretraining_on_gutenberg/previous_chapters.py @@ -244,10 +244,9 @@ def generate_text_simple(model, idx, max_new_tokens, context_size): def calc_loss_batch(input_batch, target_batch, model, device): input_batch, target_batch = input_batch.to(device), target_batch.to(device) logits = model(input_batch) - logits = logits.view(-1, logits.size(-1)) - loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1)) + loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten()) return loss