Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-10 12:33:42 +00:00
Rename variable to context_length to make it easier on readers (#106)
* rename to context_length
* fix spacing
committed by GitHub
parent a940373a14
commit 2de60d1bfb
@@ -105,7 +105,7 @@
 "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim//12,\n",
-"    block_size=context_len,\n",
+"    context_length=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",
 "    qkv_bias=False\n",
@@ -154,7 +154,7 @@
 "mha_ch03 = Ch03_MHA(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim,\n",
-"    block_size=context_len,\n",
+"    context_length=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",
 "    qkv_bias=False\n",
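Note on the call sites (the two hunks above and the analogous ones below): only the keyword changes from `block_size=` to `context_length=`; the value passed in, `context_len`, stays the same. A minimal, hedged sketch of what the renamed constructor signature looks like to a caller (the `StandInMHA` class and the concrete numbers are illustrative assumptions, not the notebook's actual classes or settings):

```python
import torch.nn as nn

class StandInMHA(nn.Module):
    # Stand-in that only mirrors the renamed constructor signature
    def __init__(self, d_in, d_out, context_length, dropout=0.0, num_heads=12, qkv_bias=False):
        super().__init__()
        self.context_length = context_length  # formerly self.block_size

embed_dim = 768       # assumed value for illustration
context_len = 1024    # assumed value for illustration

mha = StandInMHA(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,  # was: block_size=context_len
    dropout=0.0,
    num_heads=12,
    qkv_bias=False,
)
```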
@@ -220,13 +220,13 @@
 "\n",
 "\n",
 "class MultiHeadAttentionCombinedQKV(nn.Module):\n",
-"    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
+"    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
 "        super().__init__()\n",
 "\n",
 "        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
 "\n",
 "        self.num_heads = num_heads\n",
-"        self.block_size = block_size\n",
+"        self.context_length = context_length\n",
 "        self.head_dim = d_out // num_heads\n",
 "\n",
 "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
@@ -234,7 +234,7 @@
 "        self.dropout = nn.Dropout(dropout)\n",
 "\n",
 "        self.register_buffer(\n",
-"            \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
+"            \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
 "        )\n",
 "\n",
 "    def forward(self, x):\n",
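The two hunks above rename the constructor argument and the corresponding attribute inside `MultiHeadAttentionCombinedQKV`, where `context_length` sizes the registered causal mask. As a self-contained illustration (a simplified sketch in the spirit of that class, not the notebook's exact code), a fused QKV projection with a `context_length`-sized upper-triangular mask looks like this:

```python
import torch
import torch.nn as nn

class CombinedQKVCausalAttention(nn.Module):
    # Simplified sketch; assumes d_out is divisible by num_heads
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)  # fused Q, K, V projection
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Causal mask sized by the maximum supported context length
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape
        # (b, T, 3*d_out) -> (3, b, num_heads, T, head_dim)
        qkv = self.qkv(x).view(b, num_tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = qkv.permute(2, 0, 3, 1, 4)
        scores = queries @ keys.transpose(-2, -1) / self.head_dim ** 0.5
        scores = scores.masked_fill(self.mask[:num_tokens, :num_tokens].bool(), float("-inf"))
        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = (weights @ values).transpose(1, 2).reshape(b, num_tokens, -1)
        return self.proj(context)

attn = CombinedQKVCausalAttention(d_in=768, d_out=768, num_heads=12, context_length=1024)
print(attn(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])
```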
@@ -278,7 +278,7 @@
 "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim,\n",
-"    block_size=context_len,\n",
+"    context_length=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",
 "    qkv_bias=False\n",
@@ -321,13 +321,13 @@
 "outputs": [],
 "source": [
 "class MHAPyTorchScaledDotProduct(nn.Module):\n",
-"    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
+"    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
 "        super().__init__()\n",
 "\n",
 "        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
 "\n",
 "        self.num_heads = num_heads\n",
-"        self.block_size = block_size\n",
+"        self.context_length = context_length\n",
 "        self.head_dim = d_out // num_heads\n",
 "        self.d_out = d_out\n",
 "\n",
@@ -336,7 +336,7 @@
 "        self.dropout = dropout\n",
 "\n",
 "        self.register_buffer(\n",
-"            \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
+"            \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
 "        )\n",
 "\n",
 "    def forward(self, x):\n",
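`MHAPyTorchScaledDotProduct` gets the same rename for its mask buffer. For reference (a hedged sketch, not the notebook's exact class), PyTorch's `torch.nn.functional.scaled_dot_product_attention` can also apply the causal mask internally via `is_causal=True`, in which case `context_length` is only kept as bookkeeping:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPACausalAttention(nn.Module):
    # Hedged sketch: fused QKV plus F.scaled_dot_product_attention with is_causal=True
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.context_length = context_length  # kept for parity with the other classes
        self.dropout = dropout
        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)

    def forward(self, x):
        b, num_tokens, _ = x.shape
        qkv = self.qkv(x).view(b, num_tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = qkv.permute(2, 0, 3, 1, 4)
        context = F.scaled_dot_product_attention(
            queries, keys, values,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True,  # causal mask is applied internally
        )
        return self.proj(context.transpose(1, 2).reshape(b, num_tokens, -1))

attn = SDPACausalAttention(d_in=768, d_out=768, num_heads=12, context_length=1024)
print(attn(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])
```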
@@ -388,7 +388,7 @@
 "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim,\n",
-"    block_size=context_len,\n",
+"    context_length=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",
 "    qkv_bias=False\n",
@@ -446,10 +446,10 @@
 "\n",
 "\n",
 "class MHAPyTorchClass(nn.Module):\n",
-"    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n",
+"    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
 "        super().__init__()\n",
 "\n",
-"        self.block_size = block_size\n",
+"        self.context_length = context_length\n",
 "        self.multihead_attn = nn.MultiheadAttention(\n",
 "            embed_dim=d_out,\n",
 "            num_heads=num_heads,\n",
@@ -461,17 +461,17 @@
 "\n",
 "        self.need_weights = need_weights\n",
 "        self.proj = nn.Linear(d_out, d_out)\n",
-"        self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n",
+"        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
 "\n",
 "    def forward(self, x):\n",
 "        batch_size, num_tokens, _ = x.shape\n",
 "\n",
 "        # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
 "        # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
-"        if self.block_size >= num_tokens:\n",
+"        if self.context_length >= num_tokens:\n",
 "            attn_mask = self.mask[:num_tokens, :num_tokens]\n",
 "        else:\n",
-"            attn_mask = self.mask[:self.block_size, :self.block_size]\n",
+"            attn_mask = self.mask[:self.context_length, :self.context_length]\n",
 "\n",
 "        # attn_mask broadcasting will handle batch_size dimension implicitly\n",
 "        attn_output, _ = self.multihead_attn(\n",
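In `MHAPyTorchClass`, the renamed `context_length` sizes a boolean mask buffer, and the forward pass slices it down to the current sequence length before handing it to `nn.MultiheadAttention` (for a boolean `attn_mask`, `True` marks positions that may not be attended to). A self-contained sketch of that usage pattern, with illustrative sizes:

```python
import torch
import torch.nn as nn

context_length, embed_dim, num_heads = 1024, 768, 12
mha = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True)

# Boolean causal mask: True above the diagonal = not allowed to attend
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()

x = torch.randn(2, 16, embed_dim)           # (batch, num_tokens, embed_dim)
num_tokens = x.shape[1]
attn_mask = mask[:num_tokens, :num_tokens]  # slice to the current sequence length

attn_output, attn_weights = mha(x, x, x, attn_mask=attn_mask)
print(attn_output.shape)   # torch.Size([2, 16, 768])
print(attn_weights.shape)  # torch.Size([2, 16, 16])
```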
@@ -486,7 +486,7 @@
 "mha_pytorch_class_default = MHAPyTorchClass(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim,\n",
-"    block_size=context_len,\n",
+"    context_length=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",
 "    qkv_bias=False\n",
@@ -548,7 +548,7 @@
 "mha_pytorch_class_noweights = MHAPyTorchClass(\n",
 "    d_in=embed_dim,\n",
 "    d_out=embed_dim,\n",
-"    block_size=context_len,\n",
+"    context_length=context_len,\n",
 "    dropout=0.0,\n",
 "    num_heads=12,\n",
 "    qkv_bias=False,\n",
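Judging from its name and the `need_weights` parameter in the class signature shown earlier, the `_noweights` variant presumably passes `need_weights=False` through to `nn.MultiheadAttention` (the remaining arguments are cut off by this hunk). With `need_weights=False`, the attention-weight matrix is not materialized and the second return value is `None`, as in this small hedged sketch:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.randn(2, 16, 768)
attn_mask = torch.triu(torch.ones(16, 16), diagonal=1).bool()

# need_weights=False skips computing the averaged attention weights
attn_output, attn_weights = mha(x, x, x, attn_mask=attn_mask, need_weights=False)
print(attn_output.shape)  # torch.Size([2, 16, 768])
print(attn_weights)       # None
```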
@@ -1031,7 +1031,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.10.6"
 }
 },
 "nbformat": 4,