mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Update mha-implementations.ipynb
Fix variable spelling in comments to keep consistent with code
This commit is contained in:
@@ -168,7 +168,7 @@
|
||||
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
||||
" queries, keys, values = qkv.unbind(0)\n",
|
||||
"\n",
|
||||
" # (b, num_head, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
|
||||
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
|
||||
" attn_scores = queries @ keys.transpose(-2, -1)\n",
|
||||
" attn_scores = attn_scores.masked_fill(\n",
|
||||
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
|
||||
@@ -258,7 +258,7 @@
|
||||
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
|
||||
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
|
||||
"\n",
|
||||
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
|
||||
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
|
||||
" q, k, v = qkv.unbind(0)\n",
|
||||
"\n",
|
||||
" use_dropout = 0. if not self.training else self.dropout\n",
|
||||
|
||||
Reference in New Issue
Block a user