show normalization explicitly

This commit is contained in:
rasbt
2024-01-06 19:24:01 -05:00
parent ea4b6c4e5f
commit e113075a16

View File

@@ -159,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
"metadata": {},
"outputs": [],
@@ -187,7 +187,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
"metadata": {},
"outputs": [
@@ -219,7 +219,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "9842f39b-1654-410e-88bf-d1b899bf0241",
"metadata": {},
"outputs": [
@@ -253,7 +253,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
"metadata": {},
"outputs": [
@@ -284,7 +284,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
"metadata": {},
"outputs": [
@@ -318,7 +318,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
"metadata": {},
"outputs": [
@@ -348,7 +348,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
"metadata": {},
"outputs": [
@@ -407,7 +407,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"id": "04004be8-07a1-468b-ab33-32e16a551b45",
"metadata": {},
"outputs": [
@@ -444,7 +444,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
"metadata": {},
"outputs": [
@@ -476,7 +476,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
"metadata": {},
"outputs": [
@@ -508,7 +508,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
"metadata": {},
"outputs": [
@@ -538,7 +538,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
"metadata": {},
"outputs": [
@@ -570,7 +570,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"id": "2570eb7d-aee1-457a-a61e-7544478219fa",
"metadata": {},
"outputs": [
@@ -649,7 +649,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
"metadata": {},
"outputs": [],
@@ -669,7 +669,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
"metadata": {},
"outputs": [],
@@ -691,7 +691,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"id": "73cedd62-01e1-4196-a575-baecc6095601",
"metadata": {},
"outputs": [
@@ -721,7 +721,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
"metadata": {},
"outputs": [
@@ -760,7 +760,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"id": "64cbc253-a182-4490-a765-246979ea0a28",
"metadata": {},
"outputs": [
@@ -788,7 +788,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"id": "b14e44b5-d170-40f9-8847-8990804af26d",
"metadata": {},
"outputs": [
@@ -824,7 +824,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"id": "146f5587-c845-4e30-9894-c7ed3a248153",
"metadata": {},
"outputs": [
@@ -859,7 +859,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
"metadata": {},
"outputs": [
@@ -894,7 +894,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 23,
"id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
"metadata": {},
"outputs": [
@@ -950,7 +950,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"id": "73f411e3-e231-464a-89fe-0a9035e5f839",
"metadata": {},
"outputs": [
@@ -1046,7 +1046,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
"metadata": {},
"outputs": [
@@ -1078,7 +1078,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 28,
"id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
"metadata": {},
"outputs": [
@@ -1097,8 +1097,8 @@
],
"source": [
"block_size = attn_scores.shape[0]\n",
"mask_naive = torch.tril(torch.ones(block_size, block_size))\n",
"print(mask_naive)"
"mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
"print(mask_simple)"
]
},
{
@@ -1111,7 +1111,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 29,
"id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
"metadata": {},
"outputs": [
@@ -1129,8 +1129,8 @@
}
],
"source": [
"masked_naive = attn_weights*mask_naive\n",
"print(masked_naive)"
"masked_simple = attn_weights*mask_simple\n",
"print(masked_simple)"
]
},
{
@@ -1141,12 +1141,46 @@
"- However, if the mask were applied after softmax, like above, it would disrupt the probability distribution created by softmax. Softmax ensures that all output values sum to 1. Masking after softmax would require re-normalizing the outputs to sum to 1 again, which complicates the process and might lead to unintended effects."
]
},
{
"cell_type": "markdown",
"id": "94db92d7-c397-4e42-bd8a-6a2b3e237e0f",
"metadata": {},
"source": [
"- To make sure that the rows sum to 1, we can normalize the attention weights as follows:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.4056, 0.5944, 0.0000, 0.0000, 0.0000, 0.0000],\n",
" [0.2566, 0.3741, 0.3693, 0.0000, 0.0000, 0.0000],\n",
" [0.2176, 0.2823, 0.2796, 0.2205, 0.0000, 0.0000],\n",
" [0.1826, 0.2178, 0.2191, 0.1689, 0.2115, 0.0000],\n",
" [0.1473, 0.2033, 0.1996, 0.1500, 0.1160, 0.1839]])\n"
]
}
],
"source": [
"row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
"masked_simple_norm = masked_simple / row_sums\n",
"print(masked_simple_norm)"
]
},
{
"cell_type": "markdown",
"id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
"metadata": {},
"source": [
"- So, instead, we take a different approach, masking elements with negative infinity before they enter the softmax function:"
"- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient approach to achieve the same as above.\n",
"- So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:"
]
},
{