mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
remove redundant double-unsqueeze
This commit is contained in:
@@ -91,8 +91,8 @@ class MultiHeadAttention(nn.Module):
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
# Original mask truncated to the number of tokens and converted to boolean
|
||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||
# Unsqueeze the mask twice to match dimensions
|
||||
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
|
||||
# Unsqueeze the mask to match dimensions
|
||||
mask_unsqueezed = mask_bool.unsqueeze(0)
|
||||
# Use the unsqueezed mask to fill attention scores
|
||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
||||
|
||||
|
||||
@@ -80,8 +80,8 @@ class MultiHeadAttention(nn.Module):
|
||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||
# Original mask truncated to the number of tokens and converted to boolean
|
||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||
# Unsqueeze the mask twice to match dimensions
|
||||
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
|
||||
# Unsqueeze the mask to match dimensions
|
||||
mask_unsqueezed = mask_bool.unsqueeze(0)
|
||||
# Use the unsqueezed mask to fill attention scores
|
||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user