From a3a62c509af4b24578aa292b68999b4e84ea24de Mon Sep 17 00:00:00 2001 From: Jestine Paul Date: Sat, 23 Aug 2025 04:12:26 +0800 Subject: [PATCH] Improve MHA einsum (#781) Efficiency update for einsum as mentioned in #772 --- .../mha-implementations.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index e6236d4..ed2a450 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -519,7 +519,7 @@ " scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n", "\n", " # Apply mask\n", - " mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n", + " mask = self.mask[:n, :n]\n", " scores = scores.masked_fill(mask.bool(), -torch.inf)\n", "\n", " # Softmax and dropout\n",