mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Fix flex attention in PyTorch 2.10 (#957)
This commit is contained in:
committed by
GitHub
parent
82010e2c77
commit
7b1f740f74
@@ -1017,9 +1017,10 @@
|
|||||||
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
|
||||||
" self.proj = nn.Linear(d_out, d_out)\n",
|
" self.proj = nn.Linear(d_out, d_out)\n",
|
||||||
" self.dropout = dropout\n",
|
" self.dropout = dropout\n",
|
||||||
" # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n",
|
"\n",
|
||||||
" # `create_block_mask` function does not support buffers, yet\n",
|
" # Since slicing the BlockMask is no longer supported in PyTorch 2.10 and newer\n",
|
||||||
" self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
|
" # we will create a new mask in the forward pass with the correct sequence length\n",
|
||||||
|
" # self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def forward(self, x):\n",
|
" def forward(self, x):\n",
|
||||||
@@ -1041,10 +1042,15 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
|
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
|
||||||
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
|
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
|
||||||
" if self.context_length >= num_tokens:\n",
|
" # if self.context_length >= num_tokens:\n",
|
||||||
" attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
|
" # attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
|
||||||
" else:\n",
|
" # else:\n",
|
||||||
" attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
|
" # attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
|
||||||
|
" #\n",
|
||||||
|
" #\n",
|
||||||
|
" # Commented out code lines above since slicing a BlockMask no longer works in PyTorch 3.10\n",
|
||||||
|
" # Instead, we create a fresh mask each time:\n",
|
||||||
|
" attn_mask = create_block_mask(causal, B=None, H=None, Q_LEN=num_tokens, KV_LEN=num_tokens, device=x.device)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
|
" context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
Reference in New Issue
Block a user