remove redundant double-unsequeeze

This commit is contained in:
rasbt
2024-02-29 08:31:07 -06:00
parent d89aaf319d
commit b827bf4eea
4 changed files with 13 additions and 21 deletions

View File

@@ -91,8 +91,8 @@ class MultiHeadAttention(nn.Module):
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Unsqueeze the mask twice to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
# Unsqueeze the mask to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0)
# Use the unsqueezed mask to fill attention scores
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)

View File

@@ -80,8 +80,8 @@ class MultiHeadAttention(nn.Module):
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Unsqueeze the mask twice to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
# Unsqueeze the mask to match dimensions
mask_unsqueezed = mask_bool.unsqueeze(0)
# Use the unsqueezed mask to fill attention scores
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)