mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Rename variable to context_length to make it easier on readers (#106)
* rename to context length * fix spacing
This commit is contained in:
committed by
GitHub
parent
a940373a14
commit
2de60d1bfb
@@ -1275,8 +1275,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"block_size = attn_scores.shape[0]\n",
|
||||
"mask_simple = torch.tril(torch.ones(block_size, block_size))\n",
|
||||
"context_length = attn_scores.shape[0]\n",
|
||||
"mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
|
||||
"print(mask_simple)"
|
||||
]
|
||||
},
|
||||
@@ -1395,7 +1395,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"mask = torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
|
||||
"mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
|
||||
"masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
|
||||
"print(masked)"
|
||||
]
|
||||
@@ -1598,14 +1598,14 @@
|
||||
"source": [
|
||||
"class CausalAttention(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):\n",
|
||||
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.dropout = nn.Dropout(dropout) # New\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" b, num_tokens, d_in = x.shape # New batch dimension b\n",
|
||||
@@ -1624,8 +1624,8 @@
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"block_size = batch.shape[1]\n",
|
||||
"ca = CausalAttention(d_in, d_out, block_size, 0.0)\n",
|
||||
"context_length = batch.shape[1]\n",
|
||||
"ca = CausalAttention(d_in, d_out, context_length, 0.0)\n",
|
||||
"\n",
|
||||
"context_vecs = ca(batch)\n",
|
||||
"\n",
|
||||
@@ -1713,10 +1713,10 @@
|
||||
"source": [
|
||||
"class MultiHeadAttentionWrapper(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" self.heads = nn.ModuleList(\n",
|
||||
" [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) \n",
|
||||
" [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
|
||||
" for _ in range(num_heads)]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@@ -1726,9 +1726,9 @@
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"block_size = batch.shape[1] # This is the number of tokens\n",
|
||||
"context_length = batch.shape[1] # This is the number of tokens\n",
|
||||
"d_in, d_out = 3, 2\n",
|
||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
|
||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
|
||||
"\n",
|
||||
"context_vecs = mha(batch)\n",
|
||||
"\n",
|
||||
@@ -1792,7 +1792,7 @@
|
||||
],
|
||||
"source": [
|
||||
"class MultiHeadAttention(nn.Module):\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):\n",
|
||||
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
|
||||
" super().__init__()\n",
|
||||
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
|
||||
"\n",
|
||||
@@ -1805,7 +1805,7 @@
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
|
||||
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
|
||||
" self.dropout = nn.Dropout(dropout)\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" b, num_tokens, d_in = x.shape\n",
|
||||
@@ -1848,9 +1848,9 @@
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"batch_size, block_size, d_in = batch.shape\n",
|
||||
"batch_size, context_length, d_in = batch.shape\n",
|
||||
"d_out = 2\n",
|
||||
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
|
||||
"mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
|
||||
"\n",
|
||||
"context_vecs = mha(batch)\n",
|
||||
"\n",
|
||||
|
||||
Reference in New Issue
Block a user