mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
show normalization explicitely
This commit is contained in:
@@ -159,7 +159,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 2,
|
||||
"id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -187,7 +187,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -219,7 +219,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 4,
|
||||
"id": "9842f39b-1654-410e-88bf-d1b899bf0241",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -253,7 +253,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 5,
|
||||
"id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -284,7 +284,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 6,
|
||||
"id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -318,7 +318,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 7,
|
||||
"id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -348,7 +348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 8,
|
||||
"id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -407,7 +407,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 9,
|
||||
"id": "04004be8-07a1-468b-ab33-32e16a551b45",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -444,7 +444,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 10,
|
||||
"id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -476,7 +476,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 11,
|
||||
"id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -508,7 +508,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 12,
|
||||
"id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -538,7 +538,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 13,
|
||||
"id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -570,7 +570,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 14,
|
||||
"id": "2570eb7d-aee1-457a-a61e-7544478219fa",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -649,7 +649,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 15,
|
||||
"id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -669,7 +669,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 16,
|
||||
"id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -691,7 +691,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 17,
|
||||
"id": "73cedd62-01e1-4196-a575-baecc6095601",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -721,7 +721,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 18,
|
||||
"id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -760,7 +760,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 19,
|
||||
"id": "64cbc253-a182-4490-a765-246979ea0a28",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -788,7 +788,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 20,
|
||||
"id": "b14e44b5-d170-40f9-8847-8990804af26d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -824,7 +824,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 21,
|
||||
"id": "146f5587-c845-4e30-9894-c7ed3a248153",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -859,7 +859,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 22,
|
||||
"id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -894,7 +894,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 23,
|
||||
"id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -950,7 +950,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 24,
|
||||
"id": "73f411e3-e231-464a-89fe-0a9035e5f839",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1046,7 +1046,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 25,
|
||||
"id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1078,7 +1078,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 28,
|
||||
"id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1097,8 +1097,8 @@
|
||||
],
|
||||
"source": [
|
||||
"block_size = attn_scores.shape[0]\n",
|
||||
"mask_naive = torch.tril(torch.ones(block_size, block_size))\n",
|
||||
"print(mask_naive)"
|
||||
"mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
|
||||
"print(mask_simple)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1111,7 +1111,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 29,
|
||||
"id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1129,8 +1129,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"masked_naive = attn_weights*mask_naive\n",
|
||||
"print(masked_naive)"
|
||||
"masked_simple = attn_weights*mask_simple\n",
|
||||
"print(masked_simple)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1141,12 +1141,46 @@
|
||||
"- However, if the mask were applied after softmax, like above, it would disrupt the probability distribution created by softmax. Softmax ensures that all output values sum to 1. Masking after softmax would require re-normalizing the outputs to sum to 1 again, which complicates the process and might lead to unintended effects."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "94db92d7-c397-4e42-bd8a-6a2b3e237e0f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- To make sure that the rows sum to 1, we can normalize the attention weights as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.4056, 0.5944, 0.0000, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.2566, 0.3741, 0.3693, 0.0000, 0.0000, 0.0000],\n",
|
||||
" [0.2176, 0.2823, 0.2796, 0.2205, 0.0000, 0.0000],\n",
|
||||
" [0.1826, 0.2178, 0.2191, 0.1689, 0.2115, 0.0000],\n",
|
||||
" [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
|
||||
"masked_simple_norm = masked_simple / row_sums\n",
|
||||
"print(masked_simple_norm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- So, instead, we take a different approach, masking elements with negative infinity before they enter the softmax function:"
|
||||
"- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient approach to achieve the same as above.\n",
|
||||
"- So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user