Remove redundant unsqueeze calls on the causal attention mask

This commit is contained in:
rasbt
2024-03-09 17:42:25 -06:00
parent 6ba97adaee
commit da33ce8054
7 changed files with 45 additions and 37 deletions

View File

@@ -89,12 +89,12 @@ class MultiHeadAttention(nn.Module):
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Unsqueeze the mask twice to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
# Use the unsqueezed mask to fill attention scores
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)