mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Improve MHA einsum (#781)
Efficiency update for einsum as mentioned in #772
This commit is contained in:
@@ -519,7 +519,7 @@
|
||||
" scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n",
|
||||
"\n",
|
||||
" # Apply mask\n",
|
||||
" mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n",
|
||||
" mask = self.mask[:n, :n]\n",
|
||||
" scores = scores.masked_fill(mask.bool(), -torch.inf)\n",
|
||||
"\n",
|
||||
" # Softmax and dropout\n",
|
||||
|
||||
Reference in New Issue
Block a user