Rename variable to context_length to make it easier on readers (#106)

* rename to context length

* fix spacing
This commit is contained in:
Sebastian Raschka
2024-04-04 07:27:41 -05:00
committed by GitHub
parent 684562733a
commit ccd7cebbb3
25 changed files with 242 additions and 242 deletions

View File

@@ -1275,8 +1275,8 @@
}
],
"source": [
"block_size = attn_scores.shape[0]\n",
"mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
"context_length = attn_scores.shape[0]\n",
"mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
"print(mask_simple)"
]
},
@@ -1395,7 +1395,7 @@
}
],
"source": [
"mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
"mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
"masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
"print(masked)"
]
@@ -1598,14 +1598,14 @@
"source": [
"class CausalAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape # New batch dimension b\n",
@@ -1624,8 +1624,8 @@
"\n",
"torch.manual_seed(123)\n",
"\n",
"block_size = batch.shape[1]\n",
"ca = CausalAttention(d_in, d_out, block_size, 0.0)\n",
"context_length = batch.shape[1]\n",
"ca = CausalAttention(d_in, d_out, context_length, 0.0)\n",
"\n",
"context_vecs = ca(batch)\n",
"\n",
@@ -1713,10 +1713,10 @@
"source": [
"class MultiHeadAttentionWrapper(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
"\n",
@@ -1726,9 +1726,9 @@
"\n",
"torch.manual_seed(123)\n",
"\n",
"block_size = batch.shape[1] # This is the number of tokens\n",
"context_length = batch.shape[1] # This is the number of tokens\n",
"d_in, d_out = 3, 2\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"context_vecs = mha(batch)\n",
"\n",
@@ -1792,7 +1792,7 @@
],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
@@ -1805,7 +1805,7 @@
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
@@ -1848,9 +1848,9 @@
"\n",
"torch.manual_seed(123)\n",
"\n",
"batch_size, block_size, d_in = batch.shape\n",
"batch_size, context_length, d_in = batch.shape\n",
"d_out = 2\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"context_vecs = mha(batch)\n",
"\n",

View File

@@ -201,7 +201,7 @@
"torch.manual_seed(123)\n",
"\n",
"d_out = 1\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"context_vecs = mha(batch)\n",
"\n",
@@ -247,11 +247,11 @@
"metadata": {},
"source": [
"```python\n",
"block_size = 1024\n",
"context_length = 1024\n",
"d_in, d_out = 768, 768\n",
"num_heads = 12\n",
"\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads)\n",
"```"
]
},

View File

@@ -116,11 +116,11 @@
"vocab_size = 50257\n",
"output_dim = 256\n",
"max_len = 1024\n",
"block_size = max_len\n",
"context_length = max_len\n",
"\n",
"\n",
"token_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=max_length)"
@@ -187,14 +187,14 @@
"source": [
"class CausalSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
" b, n_tokens, d_in = x.shape # New batch dimension b\n",
@@ -213,10 +213,10 @@
"\n",
"\n",
"class MultiHeadAttentionWrapper(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalSelfAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
" [CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
@@ -243,13 +243,13 @@
"source": [
"torch.manual_seed(123)\n",
"\n",
"block_size = max_length\n",
"context_length = max_length\n",
"d_in = output_dim\n",
"\n",
"num_heads=2\n",
"d_out = d_in // num_heads\n",
"\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads)\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads)\n",
"\n",
"batch = input_embeddings\n",
"context_vecs = mha(batch)\n",
@@ -273,7 +273,7 @@
"outputs": [],
"source": [
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
@@ -286,7 +286,7 @@
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
@@ -345,11 +345,11 @@
"source": [
"torch.manual_seed(123)\n",
"\n",
"block_size = max_length\n",
"context_length = max_length\n",
"d_in = output_dim\n",
"d_out = d_in\n",
"\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
"\n",
"batch = input_embeddings\n",
"context_vecs = mha(batch)\n",
@@ -374,7 +374,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.6"
}
},
"nbformat": 4,