readability improvements

This commit is contained in:
rasbt
2024-01-15 07:36:19 -06:00
parent a7b4880179
commit 9e85f13ba9
2 changed files with 53 additions and 32 deletions

View File

@@ -1658,11 +1658,18 @@
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # Unsqueeze the mask twice to match dimensions\n",
" mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n",
" # Use the unsqueezed mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
" \n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
" # Shape: (b, num_tokens, num_heads, head_dim)\n",
" context_vec = (attn_weights @ values).transpose(1, 2) \n",
" \n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",