diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index d2bfba2..7cc1058 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -168,7 +168,7 @@ " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", " queries, keys, values = qkv.unbind(0)\n", "\n", - " # (b, num_head, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", " attn_scores = queries @ keys.transpose(-2, -1)\n", " attn_scores = attn_scores.masked_fill(\n", " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", @@ -258,7 +258,7 @@ " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", " qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", " q, k, v = qkv.unbind(0)\n", "\n", " use_dropout = 0. if not self.training else self.dropout\n",