Update mha-implementations.ipynb

Fix variable spelling in comments to keep consistent with code
This commit is contained in:
Rayed Bin Wahed
2024-03-06 23:03:57 +08:00
committed by GitHub
parent b6fe1a37b3
commit 496079c61e

View File

@@ -168,7 +168,7 @@
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
" queries, keys, values = qkv.unbind(0)\n",
"\n",
" # (b, num_head, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(-2, -1)\n",
" attn_scores = attn_scores.masked_fill(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
@@ -258,7 +258,7 @@
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" q, k, v = qkv.unbind(0)\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",