remove redundant unsqueeze in mask

This commit is contained in:
rasbt
2024-03-09 17:42:25 -06:00
parent 6ba97adaee
commit da33ce8054
7 changed files with 45 additions and 37 deletions

View File

@@ -148,7 +148,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
"metadata": {},
"outputs": [],
@@ -196,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 7,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {},
"outputs": [
@@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"id": "2773c09d-c136-4372-a2be-04b58d292842",
"metadata": {},
"outputs": [],
@@ -276,12 +276,12 @@
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" \n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # Unsqueeze the mask to match dimensions\n",
" mask_unsqueezed = mask_bool.unsqueeze(0)\n",
" # Use the unsqueezed mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
"\n",
" # Use the mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
" \n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
@@ -298,7 +298,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"id": "779fdd04-0152-4308-af08-840800a7f395",
"metadata": {},
"outputs": [
@@ -324,6 +324,14 @@
"\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ac01b16-8ac6-4487-a6f2-fd9cf33a9fe4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -342,7 +350,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.4"
}
},
"nbformat": 4,