Improve MHA einsum (#781)

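Efficiency update for the einsum-based MHA, as mentioned in #772: pass the sliced 2D mask to masked_fill and rely on broadcasting instead of expanding it to (b, num_heads, n, n).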
Jestine Paul
2025-08-23 04:12:26 +08:00
committed by GitHub
parent 670f7a4dd0
commit a3a62c509a

@@ -519,7 +519,7 @@
" scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n",
"\n",
" # Apply mask\n",
-"        mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n",
+"        mask = self.mask[:n, :n]\n",
" scores = scores.masked_fill(mask.bool(), -torch.inf)\n",
"\n",
" # Softmax and dropout\n",
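
The change above relies on `masked_fill` broadcasting a 2D `(n, n)` mask across the batch and head dimensions of the `(b, num_heads, n, n)` score tensor. The following is a minimal standalone sketch, not the notebook's code, that checks the two variants give identical results. The tensor sizes, the `full_mask` name, and the assumption that `self.mask` is an upper-triangular causal buffer are all hypothetical, chosen only for illustration.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for this illustration
b, num_heads, n, head_dim, context_length = 2, 4, 5, 8, 16

Q = torch.randn(b, num_heads, n, head_dim)
K = torch.randn(b, num_heads, n, head_dim)

# Assumption: the module's mask buffer is an upper-triangular causal mask
full_mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

# Scaled attention scores, same einsum pattern as in the diff
scores = torch.einsum("bhnd,bhmd->bhnm", Q, K) / (head_dim ** 0.5)

# Old variant: explicitly expand the mask to (b, num_heads, n, n)
mask_expanded = full_mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, num_heads, n, n)
scores_old = scores.masked_fill(mask_expanded.bool(), -torch.inf)

# New variant: pass the (n, n) slice and let masked_fill broadcast it
mask_2d = full_mask[:n, :n]
scores_new = scores.masked_fill(mask_2d.bool(), -torch.inf)

# Both variants mask exactly the same positions
same = torch.equal(scores_old, scores_new)
print(same)
```

Since `expand` returns a view rather than copying data, the saving here is likely in fewer tensor operations and simpler code rather than in memory.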