From d4754f1bdda95033ad3750ca928e1e5524936dad Mon Sep 17 00:00:00 2001 From: rasbt Date: Mon, 4 Mar 2024 18:54:43 -0600 Subject: [PATCH] change dim=1 to dim=-1 --- ch03/01_main-chapter-code/ch03.ipynb | 88 +++++++++---------- .../multihead-attention.ipynb | 12 +-- 2 files changed, 46 insertions(+), 54 deletions(-) diff --git a/ch03/01_main-chapter-code/ch03.ipynb b/ch03/01_main-chapter-code/ch03.ipynb index cae40f0..614ae08 100644 --- a/ch03/01_main-chapter-code/ch03.ipynb +++ b/ch03/01_main-chapter-code/ch03.ipynb @@ -26,7 +26,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "torch version: 2.1.0\n" + "torch version: 2.2.1\n" ] } ], @@ -1389,19 +1389,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[-0.0844, 0.0414],\n", - " [-0.2264, -0.0039],\n", - " [-0.4163, -0.0564],\n", - " [-0.5014, -0.1011],\n", - " [-0.7754, -0.1867],\n", - " [-1.1632, -0.3303]],\n", + "tensor([[[-0.4519, 0.2216],\n", + " [-0.5874, 0.0058],\n", + " [-0.6300, -0.0632],\n", + " [-0.5675, -0.0843],\n", + " [-0.5526, -0.0981],\n", + " [-0.5299, -0.1081]],\n", "\n", - " [[-0.0844, 0.0414],\n", - " [-0.2264, -0.0039],\n", - " [-0.4163, -0.0564],\n", - " [-0.5014, -0.1011],\n", - " [-0.7754, -0.1867],\n", - " [-1.1632, -0.3303]]], grad_fn=)\n", + " [[-0.4519, 0.2216],\n", + " [-0.5874, 0.0058],\n", + " [-0.6300, -0.0632],\n", + " [-0.5675, -0.0843],\n", + " [-0.5526, -0.0981],\n", + " [-0.5299, -0.1081]]], grad_fn=)\n", "context_vecs.shape: torch.Size([2, 6, 2])\n" ] } @@ -1427,7 +1427,7 @@ " attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n", " attn_scores.masked_fill_( # New, _ ops are in-place\n", " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " attn_weights = self.dropout(attn_weights) # New\n", "\n", " context_vec = attn_weights @ values\n", @@ -1496,19 +1496,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[-0.0844, 0.0414, 0.0766, 0.0171],\n", - " [-0.2264, -0.0039, 0.2143, 0.1185],\n", - " [-0.4163, -0.0564, 0.3878, 0.2453],\n", - " [-0.5014, -0.1011, 0.4992, 0.3401],\n", - " [-0.7754, -0.1867, 0.7387, 0.4868],\n", - " [-1.1632, -0.3303, 1.1224, 0.8460]],\n", + "tensor([[[-0.4519, 0.2216, 0.4772, 0.1063],\n", + " [-0.5874, 0.0058, 0.5891, 0.3257],\n", + " [-0.6300, -0.0632, 0.6202, 0.3860],\n", + " [-0.5675, -0.0843, 0.5478, 0.3589],\n", + " [-0.5526, -0.0981, 0.5321, 0.3428],\n", + " [-0.5299, -0.1081, 0.5077, 0.3493]],\n", "\n", - " [[-0.0844, 0.0414, 0.0766, 0.0171],\n", - " [-0.2264, -0.0039, 0.2143, 0.1185],\n", - " [-0.4163, -0.0564, 0.3878, 0.2453],\n", - " [-0.5014, -0.1011, 0.4992, 0.3401],\n", - " [-0.7754, -0.1867, 0.7387, 0.4868],\n", - " [-1.1632, -0.3303, 1.1224, 0.8460]]], grad_fn=)\n", + " [[-0.4519, 0.2216, 0.4772, 0.1063],\n", + " [-0.5874, 0.0058, 0.5891, 0.3257],\n", + " [-0.6300, -0.0632, 0.6202, 0.3860],\n", + " [-0.5675, -0.0843, 0.5478, 0.3589],\n", + " [-0.5526, -0.0981, 0.5321, 0.3428],\n", + " [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=)\n", "context_vecs.shape: torch.Size([2, 6, 4])\n" ] } @@ -1559,19 +1559,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[[-9.1476e-02, 3.4164e-02],\n", - " [-2.6796e-01, -1.3427e-03],\n", - " [-4.8421e-01, -4.8909e-02],\n", - " [-6.4808e-01, -1.0625e-01],\n", - " [-8.8380e-01, -1.7140e-01],\n", - " [-1.4744e+00, -3.4327e-01]],\n", + "tensor([[[-0.5740, 0.2216],\n", + " [-0.7320, 0.0155],\n", + " [-0.7774, -0.0546],\n", + " [-0.6979, -0.0817],\n", + " [-0.6538, -0.0957],\n", + " [-0.6424, -0.1065]],\n", "\n", - " [[-9.1476e-02, 3.4164e-02],\n", - " [-2.6796e-01, -1.3427e-03],\n", - " [-4.8421e-01, -4.8909e-02],\n", - " [-6.4808e-01, -1.0625e-01],\n", - " [-8.8380e-01, -1.7140e-01],\n", - " [-1.4744e+00, -3.4327e-01]]], grad_fn=)\n", + " [[-0.5740, 0.2216],\n", + " [-0.7320, 0.0155],\n", + " [-0.7774, -0.0546],\n", + " [-0.6979, -0.0817],\n", + " [-0.6538, -0.0957],\n", + " [-0.6424, -0.1065]]], grad_fn=)\n", "context_vecs.shape: torch.Size([2, 6, 2])\n" ] } @@ -1608,7 +1608,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 37, "id": "110b0188-6e9e-4e56-a988-10523c6c8538", "metadata": {}, "outputs": [ @@ -1729,7 +1729,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 38, "id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9", "metadata": {}, "outputs": [ @@ -1772,7 +1772,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 39, "id": "053760f1-1a02-42f0-b3bf-3d939e407039", "metadata": {}, "outputs": [ @@ -1804,7 +1804,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 40, "id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937", "metadata": {}, "outputs": [ @@ -1814,7 +1814,7 @@ "2360064" ] }, - "execution_count": 42, + "execution_count": 40, "metadata": {}, "output_type": "execute_result" } @@ -1865,7 +1865,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/ch03/01_main-chapter-code/multihead-attention.ipynb b/ch03/01_main-chapter-code/multihead-attention.ipynb index 857316f..c981b2b 100644 --- a/ch03/01_main-chapter-code/multihead-attention.ipynb +++ b/ch03/01_main-chapter-code/multihead-attention.ipynb @@ -173,7 +173,7 @@ " attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n", " attn_scores.masked_fill_( # New, _ ops are in-place\n", " self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " attn_weights = self.dropout(attn_weights) # New\n", "\n", " context_vec = attn_weights @ values\n", @@ -324,14 +324,6 @@ "\n", "print(\"context_vecs.shape:\", context_vecs.shape)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8d4be84-28bb-41d5-996c-4936acffd411", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -350,7 +342,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.11.4" } }, "nbformat": 4,