Change dim=1 to dim=-1 in the torch.softmax call that computes attn_weights

This commit is contained in:
rasbt
2024-03-04 18:54:43 -06:00
parent b50c42ffbb
commit d4754f1bdd
2 changed files with 46 additions and 54 deletions

View File

@@ -26,7 +26,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.1.0\n"
"torch version: 2.2.1\n"
]
}
],
@@ -1389,19 +1389,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[-0.0844, 0.0414],\n",
" [-0.2264, -0.0039],\n",
" [-0.4163, -0.0564],\n",
" [-0.5014, -0.1011],\n",
" [-0.7754, -0.1867],\n",
" [-1.1632, -0.3303]],\n",
"tensor([[[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
" [-0.6300, -0.0632],\n",
" [-0.5675, -0.0843],\n",
" [-0.5526, -0.0981],\n",
" [-0.5299, -0.1081]],\n",
"\n",
" [[-0.0844, 0.0414],\n",
" [-0.2264, -0.0039],\n",
" [-0.4163, -0.0564],\n",
" [-0.5014, -0.1011],\n",
" [-0.7754, -0.1867],\n",
" [-1.1632, -0.3303]]], grad_fn=<UnsafeViewBackward0>)\n",
" [[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
" [-0.6300, -0.0632],\n",
" [-0.5675, -0.0843],\n",
" [-0.5526, -0.0981],\n",
" [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)\n",
"context_vecs.shape: torch.Size([2, 6, 2])\n"
]
}
@@ -1427,7 +1427,7 @@
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights) # New\n",
"\n",
" context_vec = attn_weights @ values\n",
@@ -1496,19 +1496,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[-0.0844, 0.0414, 0.0766, 0.0171],\n",
" [-0.2264, -0.0039, 0.2143, 0.1185],\n",
" [-0.4163, -0.0564, 0.3878, 0.2453],\n",
" [-0.5014, -0.1011, 0.4992, 0.3401],\n",
" [-0.7754, -0.1867, 0.7387, 0.4868],\n",
" [-1.1632, -0.3303, 1.1224, 0.8460]],\n",
"tensor([[[-0.4519, 0.2216, 0.4772, 0.1063],\n",
" [-0.5874, 0.0058, 0.5891, 0.3257],\n",
" [-0.6300, -0.0632, 0.6202, 0.3860],\n",
" [-0.5675, -0.0843, 0.5478, 0.3589],\n",
" [-0.5526, -0.0981, 0.5321, 0.3428],\n",
" [-0.5299, -0.1081, 0.5077, 0.3493]],\n",
"\n",
" [[-0.0844, 0.0414, 0.0766, 0.0171],\n",
" [-0.2264, -0.0039, 0.2143, 0.1185],\n",
" [-0.4163, -0.0564, 0.3878, 0.2453],\n",
" [-0.5014, -0.1011, 0.4992, 0.3401],\n",
" [-0.7754, -0.1867, 0.7387, 0.4868],\n",
" [-1.1632, -0.3303, 1.1224, 0.8460]]], grad_fn=<CatBackward0>)\n",
" [[-0.4519, 0.2216, 0.4772, 0.1063],\n",
" [-0.5874, 0.0058, 0.5891, 0.3257],\n",
" [-0.6300, -0.0632, 0.6202, 0.3860],\n",
" [-0.5675, -0.0843, 0.5478, 0.3589],\n",
" [-0.5526, -0.0981, 0.5321, 0.3428],\n",
" [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=<CatBackward0>)\n",
"context_vecs.shape: torch.Size([2, 6, 4])\n"
]
}
@@ -1559,19 +1559,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[-9.1476e-02, 3.4164e-02],\n",
" [-2.6796e-01, -1.3427e-03],\n",
" [-4.8421e-01, -4.8909e-02],\n",
" [-6.4808e-01, -1.0625e-01],\n",
" [-8.8380e-01, -1.7140e-01],\n",
" [-1.4744e+00, -3.4327e-01]],\n",
"tensor([[[-0.5740, 0.2216],\n",
" [-0.7320, 0.0155],\n",
" [-0.7774, -0.0546],\n",
" [-0.6979, -0.0817],\n",
" [-0.6538, -0.0957],\n",
" [-0.6424, -0.1065]],\n",
"\n",
" [[-9.1476e-02, 3.4164e-02],\n",
" [-2.6796e-01, -1.3427e-03],\n",
" [-4.8421e-01, -4.8909e-02],\n",
" [-6.4808e-01, -1.0625e-01],\n",
" [-8.8380e-01, -1.7140e-01],\n",
" [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)\n",
" [[-0.5740, 0.2216],\n",
" [-0.7320, 0.0155],\n",
" [-0.7774, -0.0546],\n",
" [-0.6979, -0.0817],\n",
" [-0.6538, -0.0957],\n",
" [-0.6424, -0.1065]]], grad_fn=<CatBackward0>)\n",
"context_vecs.shape: torch.Size([2, 6, 2])\n"
]
}
@@ -1608,7 +1608,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 37,
"id": "110b0188-6e9e-4e56-a988-10523c6c8538",
"metadata": {},
"outputs": [
@@ -1729,7 +1729,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 38,
"id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
"metadata": {},
"outputs": [
@@ -1772,7 +1772,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 39,
"id": "053760f1-1a02-42f0-b3bf-3d939e407039",
"metadata": {},
"outputs": [
@@ -1804,7 +1804,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 40,
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
"metadata": {},
"outputs": [
@@ -1814,7 +1814,7 @@
"2360064"
]
},
"execution_count": 42,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
@@ -1865,7 +1865,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -173,7 +173,7 @@
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
" self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights) # New\n",
"\n",
" context_vec = attn_weights @ values\n",
@@ -324,14 +324,6 @@
"\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8d4be84-28bb-41d5-996c-4936acffd411",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -350,7 +342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.4"
}
},
"nbformat": 4,