use block size variable in positional embedding layer

This commit is contained in:
rasbt
2023-12-28 19:05:06 +01:00
parent 10aa40ba6a
commit 4f161bd549
7 changed files with 125 additions and 40106 deletions

View File

@@ -1593,7 +1593,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
"metadata": {},
"outputs": [],
@@ -74,8 +74,11 @@
"\n",
"vocab_size = 50257\n",
"output_dim = 256\n",
"block_size = 1024\n",
"\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
@@ -83,7 +86,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
"metadata": {},
"outputs": [],
@@ -101,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
"metadata": {},
"outputs": [

View File

@@ -30,7 +30,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
"metadata": {},
"outputs": [
@@ -40,7 +40,7 @@
"[33901]"
]
},
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -51,7 +51,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
"metadata": {},
"outputs": [
@@ -61,7 +61,7 @@
"[86]"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -72,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "2773c09d-c136-4372-a2be-04b58d292842",
"metadata": {},
"outputs": [
@@ -82,7 +82,7 @@
"[343]"
]
},
"execution_count": 5,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -93,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "8a6abd32-1e0a-4038-9dd2-673f47bcdeb5",
"metadata": {},
"outputs": [
@@ -103,7 +103,7 @@
"[86]"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -114,7 +114,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "26ae940a-9841-4e27-a1df-b83fc8a488b3",
"metadata": {},
"outputs": [
@@ -124,7 +124,7 @@
"[220]"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -135,7 +135,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"id": "a606c39a-6747-4cd8-bb38-e3183f80908d",
"metadata": {},
"outputs": [
@@ -145,7 +145,7 @@
"[959]"
]
},
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -156,7 +156,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "47c7268d-8fdc-4957-bc68-5be6113f45a7",
"metadata": {},
"outputs": [
@@ -166,7 +166,7 @@
"'Akwirw ier'"
]
},
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -185,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 10,
"id": "4d50af16-937b-49e0-8ffd-42d30cbb41c9",
"metadata": {},
"outputs": [],
@@ -239,13 +239,16 @@
"\n",
"vocab_size = 50257\n",
"output_dim = 256\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"max_len = 4\n",
"block_size = max_len\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 11,
"id": "0128eefa-d7c8-4f76-9851-566dfa7c3745",
"metadata": {},
"outputs": [
@@ -258,7 +261,7 @@
" [ 402, 271]])"
]
},
"execution_count": 19,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@@ -275,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 12,
"id": "ff5c1e90-c6de-4a87-adf6-7e19f603291c",
"metadata": {},
"outputs": [
@@ -288,7 +291,7 @@
" [ 402, 271, 10899, 2138, 257, 7026, 15632, 438]])"
]
},
"execution_count": 20,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}