diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index e6236d4..ed2a450 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -519,7 +519,7 @@ " scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n", "\n", " # Apply mask\n", - " mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n", + " mask = self.mask[:n, :n]\n", " scores = scores.masked_fill(mask.bool(), -torch.inf)\n", "\n", " # Softmax and dropout\n",