use smaller number of tokens to emphasize next token prediction goal
@@ -113,7 +113,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 60,
|
||||
"id": "619c2eed-f8ea-4ff5-92c3-feda0f29b227",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -180,7 +180,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 61,
|
||||
"id": "794b6b6c-d36f-411e-a7db-8ac566a87fee",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -188,8 +188,8 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([[ 6109, 3626, 6100, 345, 2651, 13],\n",
|
||||
" [ 6109, 1110, 6622, 257, 11483, 13]])\n"
|
||||
"tensor([[6109, 3626, 6100, 345],\n",
|
||||
" [6109, 1110, 6622, 257]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -201,8 +201,8 @@
|
||||
"\n",
|
||||
"batch = []\n",
|
||||
"\n",
|
||||
"txt1 = \"Every effort moves you forward.\"\n",
|
||||
"txt2 = \"Every day holds a lesson.\"\n",
|
||||
"txt1 = \"Every effort moves you\"\n",
|
||||
"txt2 = \"Every day holds a\"\n",
|
||||
"\n",
|
||||
"batch.append(torch.tensor(tokenizer.encode(txt1)))\n",
|
||||
"batch.append(torch.tensor(tokenizer.encode(txt2)))\n",
|
||||
@@ -212,7 +212,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 62,
|
||||
"id": "009238cd-0160-4834-979c-309710986bb0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -220,20 +220,16 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output shape: torch.Size([2, 6, 50257])\n",
|
||||
"Output shape: torch.Size([2, 4, 50257])\n",
|
||||
"tensor([[[-1.2034, 0.3201, -0.7130, ..., -1.5548, -0.2390, -0.4667],\n",
|
||||
" [-0.1192, 0.4539, -0.4432, ..., 0.2392, 1.3469, 1.2430],\n",
|
||||
" [ 0.5307, 1.6720, -0.4695, ..., 1.1966, 0.0111, 0.5835],\n",
|
||||
" [ 0.0139, 1.6755, -0.3388, ..., 1.1586, -0.0435, -1.0400],\n",
|
||||
" [ 0.0106, -1.6711, 0.7797, ..., 0.3561, -0.0867, -0.5452],\n",
|
||||
" [ 0.1821, 1.1189, 0.1641, ..., 1.9012, 1.2240, 0.8853]],\n",
|
||||
" [ 0.0139, 1.6755, -0.3388, ..., 1.1586, -0.0435, -1.0400]],\n",
|
||||
"\n",
|
||||
" [[-1.0341, 0.2765, -1.1252, ..., -0.8381, 0.0773, 0.1147],\n",
|
||||
" [-0.2632, 0.5427, -0.2828, ..., 0.1357, 0.3707, 1.3615],\n",
|
||||
" [ 0.9695, 1.2466, -0.3515, ..., -0.0171, -0.3478, 0.2616],\n",
|
||||
" [-0.0237, -0.7329, 0.3184, ..., 1.5946, -0.1334, -0.2981],\n",
|
||||
" [-0.1876, -0.7909, 0.8811, ..., 1.1121, -0.3781, -1.4438],\n",
|
||||
" [ 0.0405, 1.2000, 0.0702, ..., 1.4740, 1.1567, 1.2077]]],\n",
|
||||
" [[-1.0908, 0.1798, -0.9484, ..., -1.6047, 0.2439, -0.4530],\n",
|
||||
" [-0.7860, 0.5581, -0.0610, ..., 0.4835, -0.0077, 1.6621],\n",
|
||||
" [ 0.3567, 1.2698, -0.6398, ..., -0.0162, -0.1296, 0.3717],\n",
|
||||
" [-0.2407, -0.7349, -0.5102, ..., 2.0057, -0.3694, 0.1814]]],\n",
|
||||
" grad_fn=<UnsafeViewBackward0>)\n"
|
||||
]
|
||||
}
|
||||
@@ -283,7 +279,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 25,
|
||||
"id": "79e1b463-dc3f-44ac-9cdb-9d5b6f64eb9d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -318,7 +314,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 26,
|
||||
"id": "9888f79e-8e69-44aa-8a19-cd34292adbf5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -369,7 +365,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 27,
|
||||
"id": "9a1d1bb9-3341-4c9a-bc2a-d2489bf89cda",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -382,8 +378,8 @@
|
||||
" [-0.0189, 0.1121, -1.0876, 1.5173, 0.5647, -1.0876]],\n",
|
||||
" grad_fn=<DivBackward0>)\n",
|
||||
"Mean:\n",
|
||||
" tensor([[2.9802e-08],\n",
|
||||
" [3.9736e-08]], grad_fn=<MeanBackward1>)\n",
|
||||
" tensor([[ 0.0000],\n",
|
||||
" [ 0.0000]], grad_fn=<MeanBackward1>)\n",
|
||||
"Variance:\n",
|
||||
" tensor([[1.],\n",
|
||||
" [1.]], grad_fn=<VarBackward0>)\n"
|
||||
@@ -410,7 +406,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 28,
|
||||
"id": "3e06c34b-c68a-4b36-afbe-b30eda4eca39",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -444,7 +440,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 29,
|
||||
"id": "3333a305-aa3d-460a-bcce-b80662d464d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -486,7 +482,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 30,
|
||||
"id": "23b1000a-e613-4b43-bd90-e54deed8d292",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -497,7 +493,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 31,
|
||||
"id": "94c12de2-1cab-46e0-a099-e2e470353bff",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -562,7 +558,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 32,
|
||||
"id": "f84694b7-95f3-4323-b6d6-0a73df278e82",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -580,7 +576,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 33,
|
||||
"id": "fc5487d2-2576-4118-80a7-56c4caac2e71",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -630,7 +626,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 34,
|
||||
"id": "9275c879-b148-4579-a107-86827ca14d4d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -651,7 +647,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 35,
|
||||
"id": "7c4976e2-0261-418e-b042-c5be98c2ccaf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -677,7 +673,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 36,
|
||||
"id": "928e7f7c-d0b1-499f-8d07-4cadb428a6f9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -738,7 +734,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 37,
|
||||
"id": "05473938-799c-49fd-86d4-8ed65f94fee6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -796,7 +792,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 38,
|
||||
"id": "c75f43cc-6923-4018-b980-26023086572c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -834,7 +830,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 39,
|
||||
"id": "11b7c0c2-f9dd-4dd5-b096-a05c48c5f6d6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -887,7 +883,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 40,
|
||||
"id": "0e1e8176-e5e3-4152-b1aa-0bbd7891dfd9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -947,7 +943,33 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 64,
|
||||
"id": "3fb45a63-b1f3-4b08-b525-dafbc8228405",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Input shape: torch.Size([2, 4, 768])\n",
|
||||
"Output shape: torch.Size([2, 4, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"x = torch.rand(2, 4, 768) # Shape: [batch_size, num_tokens, emb_dim]\n",
|
||||
"block = TransformerBlock(GPT_CONFIG_124M)\n",
|
||||
"output = block(x)\n",
|
||||
"\n",
|
||||
"print(\"Input shape:\", x.shape)\n",
|
||||
"print(\"Output shape:\", output.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"id": "01e737a6-fc99-42bb-9f7e-4da899168811",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -955,15 +977,15 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Input shape: torch.Size([2, 6, 768])\n",
|
||||
"Output shape: torch.Size([2, 6, 768])\n"
|
||||
"Input shape: torch.Size([2, 4, 768])\n",
|
||||
"Output shape: torch.Size([2, 4, 768])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"x = torch.rand(2, 6, 768) # Shape: [batch_size, num_tokens, emb_dim]\n",
|
||||
"x = torch.rand(2, 4, 768) # Shape: [batch_size, num_tokens, emb_dim]\n",
|
||||
"block = TransformerBlock(GPT_CONFIG_124M)\n",
|
||||
"output = block(x)\n",
|
||||
"\n",
|
||||
@@ -1014,7 +1036,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 43,
|
||||
"id": "c61de39c-d03c-4a32-8b57-f49ac3834857",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1055,7 +1077,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 44,
|
||||
"id": "252b78c2-4404-483b-84fe-a412e55c16fc",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1064,23 +1086,19 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Input batch:\n",
|
||||
" tensor([[ 6109, 3626, 6100, 345, 2651, 13],\n",
|
||||
" [ 6109, 1110, 6622, 257, 11483, 13]])\n",
|
||||
" tensor([[6109, 3626, 6100, 345],\n",
|
||||
" [6109, 1110, 6622, 257]])\n",
|
||||
"\n",
|
||||
"Output shape: torch.Size([2, 6, 50257])\n",
|
||||
"tensor([[[ 0.2237, 0.1153, 0.1121, ..., 0.1412, -0.0542, -0.3782],\n",
|
||||
" [ 0.5285, -0.0155, -0.5074, ..., -0.3225, 0.4875, -0.0612],\n",
|
||||
" [ 0.8632, -0.1178, 0.0481, ..., 0.2388, 0.0922, -0.2874],\n",
|
||||
" [-1.1907, 0.1292, -0.3071, ..., 1.0674, 0.4159, -0.5619],\n",
|
||||
" [ 1.2322, 0.5499, -0.0272, ..., -0.6428, 0.1301, -0.0295],\n",
|
||||
" [-0.4615, 0.1153, 0.2789, ..., -0.3424, 0.8622, -0.9750]],\n",
|
||||
"Output shape: torch.Size([2, 4, 50257])\n",
|
||||
"tensor([[[-0.0055, 0.3224, 0.2185, ..., 0.2539, 0.4578, -0.4747],\n",
|
||||
" [ 0.2663, -0.2975, -0.5040, ..., -0.3903, 0.5328, -0.4224],\n",
|
||||
" [ 1.1146, -0.0923, 0.1303, ..., 0.1521, -0.4494, 0.0276],\n",
|
||||
" [-0.8239, 0.1174, -0.2566, ..., 1.1197, 0.1036, -0.3993]],\n",
|
||||
"\n",
|
||||
" [[-0.0461, -0.0814, -0.2738, ..., 0.2012, 0.0063, -0.5720],\n",
|
||||
" [ 0.1694, -0.2302, 0.0034, ..., 0.8972, 0.2430, -0.0116],\n",
|
||||
" [ 0.9396, 0.9071, -0.2360, ..., 0.7185, 0.5044, -0.1897],\n",
|
||||
" [-0.3008, -0.1149, 0.4390, ..., 1.2587, -0.1521, 0.1293],\n",
|
||||
" [ 1.2862, 0.8138, -0.2298, ..., 0.4084, -0.3298, -0.6869],\n",
|
||||
" [-0.5629, 0.4579, 0.1874, ..., -0.1453, 0.8834, -0.7628]]],\n",
|
||||
" [[-0.1027, 0.1752, -0.1048, ..., 0.2258, 0.1559, -0.8747],\n",
|
||||
" [ 0.2230, 0.1246, 0.0492, ..., 0.8573, -0.2933, 0.3036],\n",
|
||||
" [ 0.9409, 1.3068, -0.1610, ..., 0.8244, 0.1763, 0.0811],\n",
|
||||
" [ 0.4395, 0.2753, 0.1540, ..., 1.3410, -0.3709, 0.1643]]],\n",
|
||||
" grad_fn=<UnsafeViewBackward0>)\n"
|
||||
]
|
||||
}
|
||||
@@ -1106,7 +1124,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 45,
|
||||
"id": "84fb8be4-9d3b-402b-b3da-86b663aac33a",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1138,7 +1156,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 46,
|
||||
"id": "e3b43233-e9b8-4f5a-b72b-a263ec686982",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1167,7 +1185,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 47,
|
||||
"id": "95a22e02-50d3-48b3-a4e0-d9863343c164",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1196,7 +1214,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 48,
|
||||
"id": "5131a752-fab8-4d70-a600-e29870b33528",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1291,7 +1309,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 49,
|
||||
"id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1345,7 +1363,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 50,
|
||||
"id": "bb3ffc8e-f95f-4a24-a978-939b8953ea3e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1359,11 +1377,11 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([ 0.0000, 0.0012, 0.0000, ..., 0.0001, 0.0000,\n",
|
||||
"tensor([ 0.0000, 0.0012, 0.0000, ..., 0.0000, 0.0000,\n",
|
||||
" 0.0000], grad_fn=<SoftmaxBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 27,
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -1380,7 +1398,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 51,
|
||||
"id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1405,7 +1423,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 52,
|
||||
"id": "a72a9b60-de66-44cf-b2f9-1e638934ada4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1442,7 +1460,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 53,
|
||||
"id": "053d99f6-5710-4446-8d52-117fb34ea9f6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
||||
|
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 21 KiB |
|
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 30 KiB |
|
Before Width: | Height: | Size: 19 KiB After Width: | Height: | Size: 15 KiB |
|
Before Width: | Height: | Size: 28 KiB After Width: | Height: | Size: 26 KiB |