mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
change dim=1 to dim=-1
This commit is contained in:
@@ -26,7 +26,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch version: 2.1.0\n"
|
||||
"torch version: 2.2.1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -1389,19 +1389,19 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[[-0.0844, 0.0414],\n",
|
||||
" [-0.2264, -0.0039],\n",
|
||||
" [-0.4163, -0.0564],\n",
|
||||
" [-0.5014, -0.1011],\n",
|
||||
" [-0.7754, -0.1867],\n",
|
||||
" [-1.1632, -0.3303]],\n",
|
||||
"tensor([[[-0.4519, 0.2216],\n",
|
||||
" [-0.5874, 0.0058],\n",
|
||||
" [-0.6300, -0.0632],\n",
|
||||
" [-0.5675, -0.0843],\n",
|
||||
" [-0.5526, -0.0981],\n",
|
||||
" [-0.5299, -0.1081]],\n",
|
||||
"\n",
|
||||
" [[-0.0844, 0.0414],\n",
|
||||
" [-0.2264, -0.0039],\n",
|
||||
" [-0.4163, -0.0564],\n",
|
||||
" [-0.5014, -0.1011],\n",
|
||||
" [-0.7754, -0.1867],\n",
|
||||
" [-1.1632, -0.3303]]], grad_fn=<UnsafeViewBackward0>)\n",
|
||||
" [[-0.4519, 0.2216],\n",
|
||||
" [-0.5874, 0.0058],\n",
|
||||
" [-0.6300, -0.0632],\n",
|
||||
" [-0.5675, -0.0843],\n",
|
||||
" [-0.5526, -0.0981],\n",
|
||||
" [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)\n",
|
||||
"context_vecs.shape: torch.Size([2, 6, 2])\n"
|
||||
]
|
||||
}
|
||||
@@ -1427,7 +1427,7 @@
|
||||
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
|
||||
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
|
||||
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" attn_weights = self.dropout(attn_weights) # New\n",
|
||||
"\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
@@ -1496,19 +1496,19 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[[-0.0844, 0.0414, 0.0766, 0.0171],\n",
|
||||
" [-0.2264, -0.0039, 0.2143, 0.1185],\n",
|
||||
" [-0.4163, -0.0564, 0.3878, 0.2453],\n",
|
||||
" [-0.5014, -0.1011, 0.4992, 0.3401],\n",
|
||||
" [-0.7754, -0.1867, 0.7387, 0.4868],\n",
|
||||
" [-1.1632, -0.3303, 1.1224, 0.8460]],\n",
|
||||
"tensor([[[-0.4519, 0.2216, 0.4772, 0.1063],\n",
|
||||
" [-0.5874, 0.0058, 0.5891, 0.3257],\n",
|
||||
" [-0.6300, -0.0632, 0.6202, 0.3860],\n",
|
||||
" [-0.5675, -0.0843, 0.5478, 0.3589],\n",
|
||||
" [-0.5526, -0.0981, 0.5321, 0.3428],\n",
|
||||
" [-0.5299, -0.1081, 0.5077, 0.3493]],\n",
|
||||
"\n",
|
||||
" [[-0.0844, 0.0414, 0.0766, 0.0171],\n",
|
||||
" [-0.2264, -0.0039, 0.2143, 0.1185],\n",
|
||||
" [-0.4163, -0.0564, 0.3878, 0.2453],\n",
|
||||
" [-0.5014, -0.1011, 0.4992, 0.3401],\n",
|
||||
" [-0.7754, -0.1867, 0.7387, 0.4868],\n",
|
||||
" [-1.1632, -0.3303, 1.1224, 0.8460]]], grad_fn=<CatBackward0>)\n",
|
||||
" [[-0.4519, 0.2216, 0.4772, 0.1063],\n",
|
||||
" [-0.5874, 0.0058, 0.5891, 0.3257],\n",
|
||||
" [-0.6300, -0.0632, 0.6202, 0.3860],\n",
|
||||
" [-0.5675, -0.0843, 0.5478, 0.3589],\n",
|
||||
" [-0.5526, -0.0981, 0.5321, 0.3428],\n",
|
||||
" [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=<CatBackward0>)\n",
|
||||
"context_vecs.shape: torch.Size([2, 6, 4])\n"
|
||||
]
|
||||
}
|
||||
@@ -1559,19 +1559,19 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[[-9.1476e-02, 3.4164e-02],\n",
|
||||
" [-2.6796e-01, -1.3427e-03],\n",
|
||||
" [-4.8421e-01, -4.8909e-02],\n",
|
||||
" [-6.4808e-01, -1.0625e-01],\n",
|
||||
" [-8.8380e-01, -1.7140e-01],\n",
|
||||
" [-1.4744e+00, -3.4327e-01]],\n",
|
||||
"tensor([[[-0.5740, 0.2216],\n",
|
||||
" [-0.7320, 0.0155],\n",
|
||||
" [-0.7774, -0.0546],\n",
|
||||
" [-0.6979, -0.0817],\n",
|
||||
" [-0.6538, -0.0957],\n",
|
||||
" [-0.6424, -0.1065]],\n",
|
||||
"\n",
|
||||
" [[-9.1476e-02, 3.4164e-02],\n",
|
||||
" [-2.6796e-01, -1.3427e-03],\n",
|
||||
" [-4.8421e-01, -4.8909e-02],\n",
|
||||
" [-6.4808e-01, -1.0625e-01],\n",
|
||||
" [-8.8380e-01, -1.7140e-01],\n",
|
||||
" [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)\n",
|
||||
" [[-0.5740, 0.2216],\n",
|
||||
" [-0.7320, 0.0155],\n",
|
||||
" [-0.7774, -0.0546],\n",
|
||||
" [-0.6979, -0.0817],\n",
|
||||
" [-0.6538, -0.0957],\n",
|
||||
" [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)\n",
|
||||
"context_vecs.shape: torch.Size([2, 6, 2])\n"
|
||||
]
|
||||
}
|
||||
@@ -1608,7 +1608,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 37,
|
||||
"id": "110b0188-6e9e-4e56-a988-10523c6c8538",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1729,7 +1729,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 38,
|
||||
"id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1772,7 +1772,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 39,
|
||||
"id": "053760f1-1a02-42f0-b3bf-3d939e407039",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1804,7 +1804,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 40,
|
||||
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1814,7 +1814,7 @@
|
||||
"2360064"
|
||||
]
|
||||
},
|
||||
"execution_count": 42,
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -1865,7 +1865,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user