mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
remove redundant unsqueeze in mask
This commit is contained in:
@@ -1608,7 +1608,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 37,
|
"execution_count": 42,
|
||||||
"id": "110b0188-6e9e-4e56-a988-10523c6c8538",
|
"id": "110b0188-6e9e-4e56-a988-10523c6c8538",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -1670,12 +1670,12 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
||||||
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
||||||
|
"\n",
|
||||||
" # Original mask truncated to the number of tokens and converted to boolean\n",
|
" # Original mask truncated to the number of tokens and converted to boolean\n",
|
||||||
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
||||||
" # Unsqueeze the mask to match dimensions\n",
|
"\n",
|
||||||
" mask_unsqueezed = mask_bool.unsqueeze(0)\n",
|
" # Use the mask to fill attention scores\n",
|
||||||
" # Use the unsqueezed mask to fill attention scores\n",
|
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
|
||||||
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
|
|
||||||
" \n",
|
" \n",
|
||||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||||
" attn_weights = self.dropout(attn_weights)\n",
|
" attn_weights = self.dropout(attn_weights)\n",
|
||||||
@@ -1865,7 +1865,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@@ -148,7 +148,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 6,
|
||||||
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
|
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -196,7 +196,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -235,7 +235,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 10,
|
||||||
"id": "2773c09d-c136-4372-a2be-04b58d292842",
|
"id": "2773c09d-c136-4372-a2be-04b58d292842",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -276,12 +276,12 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
||||||
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
||||||
|
" \n",
|
||||||
" # Original mask truncated to the number of tokens and converted to boolean\n",
|
" # Original mask truncated to the number of tokens and converted to boolean\n",
|
||||||
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
||||||
" # Unsqueeze the mask to match dimensions\n",
|
"\n",
|
||||||
" mask_unsqueezed = mask_bool.unsqueeze(0)\n",
|
" # Use the mask to fill attention scores\n",
|
||||||
" # Use the unsqueezed mask to fill attention scores\n",
|
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
|
||||||
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
|
|
||||||
" \n",
|
" \n",
|
||||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||||
" attn_weights = self.dropout(attn_weights)\n",
|
" attn_weights = self.dropout(attn_weights)\n",
|
||||||
@@ -298,7 +298,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 11,
|
||||||
"id": "779fdd04-0152-4308-af08-840800a7f395",
|
"id": "779fdd04-0152-4308-af08-840800a7f395",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -324,6 +324,14 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"print(\"context_vecs.shape:\", context_vecs.shape)"
|
"print(\"context_vecs.shape:\", context_vecs.shape)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "3ac01b16-8ac6-4487-a6f2-fd9cf33a9fe4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@@ -342,7 +350,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.12"
|
"version": "3.11.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@@ -79,12 +79,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||||
|
|
||||||
# Original mask truncated to the number of tokens and converted to boolean
|
# Original mask truncated to the number of tokens and converted to boolean
|
||||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||||
# Unsqueeze the mask to match dimensions
|
|
||||||
mask_unsqueezed = mask_bool.unsqueeze(0)
|
# Use the mask to fill attention scores
|
||||||
# Use the unsqueezed mask to fill attention scores
|
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
|
||||||
|
|
||||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||||
attn_weights = self.dropout(attn_weights)
|
attn_weights = self.dropout(attn_weights)
|
||||||
|
|||||||
@@ -544,7 +544,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"914 ms ± 50.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"1.15 s ± 86.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -569,7 +569,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"252 ms ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"273 ms ± 3.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -594,7 +594,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"300 ms ± 8.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"324 ms ± 17.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -619,7 +619,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"94.2 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
"106 ms ± 598 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -644,7 +644,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"297 ms ± 2.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"351 ms ± 7.88 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -665,7 +665,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"274 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
"333 ms ± 14.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -89,12 +89,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||||
|
|
||||||
# Original mask truncated to the number of tokens and converted to boolean
|
# Original mask truncated to the number of tokens and converted to boolean
|
||||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||||
# Unsqueeze the mask to match dimensions
|
|
||||||
mask_unsqueezed = mask_bool.unsqueeze(0)
|
# Use the mask to fill attention scores
|
||||||
# Use the unsqueezed mask to fill attention scores
|
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
|
||||||
|
|
||||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||||
attn_weights = self.dropout(attn_weights)
|
attn_weights = self.dropout(attn_weights)
|
||||||
|
|||||||
@@ -78,12 +78,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||||
|
|
||||||
# Original mask truncated to the number of tokens and converted to boolean
|
# Original mask truncated to the number of tokens and converted to boolean
|
||||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||||
# Unsqueeze the mask to match dimensions
|
|
||||||
mask_unsqueezed = mask_bool.unsqueeze(0)
|
# Use the mask to fill attention scores
|
||||||
# Use the unsqueezed mask to fill attention scores
|
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
|
||||||
|
|
||||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||||
attn_weights = self.dropout(attn_weights)
|
attn_weights = self.dropout(attn_weights)
|
||||||
|
|||||||
@@ -89,12 +89,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
# Compute scaled dot-product attention (aka self-attention) with a causal mask
|
||||||
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
|
||||||
|
|
||||||
# Original mask truncated to the number of tokens and converted to boolean
|
# Original mask truncated to the number of tokens and converted to boolean
|
||||||
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
|
||||||
# Unsqueeze the mask twice to match dimensions
|
|
||||||
mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)
|
# Use the mask to fill attention scores
|
||||||
# Use the unsqueezed mask to fill attention scores
|
attn_scores.masked_fill_(mask_bool, -torch.inf)
|
||||||
attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
|
|
||||||
|
|
||||||
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
|
||||||
attn_weights = self.dropout(attn_weights)
|
attn_weights = self.dropout(attn_weights)
|
||||||
|
|||||||
Reference in New Issue
Block a user