From 7b1f740f74cbeb9e1c1c24ee19ab6e1729209240 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Mon, 9 Feb 2026 15:12:40 -0500 Subject: [PATCH] Fix flex attention in PyTorch 2.10 (#957) --- .../mha-implementations.ipynb | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index 8606d9d..1df1f1e 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -1017,9 +1017,10 @@ " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", " self.proj = nn.Linear(d_out, d_out)\n", " self.dropout = dropout\n", - " # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n", - " # `create_block_mask` function does not support buffers, yet\n", - " self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n", + "\n", + " # Since slicing the BlockMask is no longer supported in PyTorch 2.10 and newer\n", + " # we will create a new mask in the forward pass with the correct sequence length\n", + " # self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n", "\n", "\n", " def forward(self, x):\n", @@ -1041,10 +1042,15 @@ "\n", " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", - " if self.context_length >= num_tokens:\n", - " attn_mask = self.block_mask[:num_tokens, :num_tokens]\n", - " else:\n", - " attn_mask = self.block_mask[:self.context_length, :self.context_length]\n", + " # if self.context_length >= num_tokens:\n", + " # attn_mask = self.block_mask[:num_tokens, :num_tokens]\n", + " # else:\n", + " # attn_mask = 
self.block_mask[:self.context_length, :self.context_length]\n", + " #\n", + " #\n", + " # Commented out code lines above since slicing a BlockMask no longer works in PyTorch 2.10\n", + " # Instead, we create a fresh mask each time:\n", + " attn_mask = create_block_mask(causal, B=None, H=None, Q_LEN=num_tokens, KV_LEN=num_tokens, device=x.device)\n", + "\n", + " context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n", + "\n",