Fix flex attention in PyTorch 2.10 (#957)

Sebastian Raschka
2026-02-09 15:12:40 -05:00
committed by GitHub
parent 82010e2c77
commit 7b1f740f74


@@ -1017,9 +1017,10 @@
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n", " self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n", " self.dropout = dropout\n",
" # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n", "\n",
" # `create_block_mask` function does not support buffers, yet\n", " # Since slicing the BlockMask is no longer supported in PyTorch 2.10 and newer\n",
" self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n", " # we will create a new mask in the forward pass with the correct sequence length\n",
" # self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
"\n", "\n",
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x):\n",
@@ -1041,10 +1042,15 @@
"\n", "\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n", " # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n", " # if self.context_length >= num_tokens:\n",
" attn_mask = self.block_mask[:num_tokens, :num_tokens]\n", " # attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
" else:\n", " # else:\n",
" attn_mask = self.block_mask[:self.context_length, :self.context_length]\n", " # attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
" #\n",
" #\n",
" # Commented out code lines above since slicing a BlockMask no longer works in PyTorch 3.10\n",
" # Instead, we create a fresh mask each time:\n",
" attn_mask = create_block_mask(causal, B=None, H=None, Q_LEN=num_tokens, KV_LEN=num_tokens, device=x.device)\n",
"\n", "\n",
" context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n", " context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
"\n", "\n",