Qwen3.5 from scratch (#969)

* Qwen3.5 from scratch

* update

* update
This commit is contained in:
Sebastian Raschka
2026-03-03 17:31:16 -05:00
committed by GitHub
parent 4612d20fa8
commit 7892ec9435
9 changed files with 4317 additions and 57 deletions

View File

@@ -82,9 +82,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.35.3\n",
"tokenizers version: 0.22.1\n",
"torch version: 2.8.0\n"
"huggingface_hub version: 1.5.0\n",
"tokenizers version: 0.22.2\n",
"torch version: 2.8.0+cu128\n"
]
}
],
@@ -659,16 +659,16 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.2256, -0.0164, -0.7070, ..., 0.4414, 0.1245, 1.0703],\n",
" [-0.6602, 0.5352, -0.0718, ..., -0.0737, 0.5391, 0.3086],\n",
" [-0.4785, -0.1562, 0.1045, ..., -0.2324, 0.2354, 0.6328]]],\n",
"tensor([[[-0.2334, -0.0134, -0.7031, ..., 0.4316, 0.1177, 1.0703],\n",
" [-0.6641, 0.5352, -0.0752, ..., -0.0698, 0.5430, 0.3203],\n",
" [-0.4785, -0.1748, 0.1074, ..., -0.2354, 0.2354, 0.6289]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
@@ -922,16 +922,7 @@
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"outputs": [],
"source": [
"import json\n",
"import os\n",
@@ -1182,16 +1173,26 @@
"<think>\n",
"Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n",
"\n",
"I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or content generation. Also, emphasize their adaptability and efficiency. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n",
"I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n",
"</think>\n",
"\n",
"Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and content creation. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for a wide range of applications."
"Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and answering questions. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for applications in various domains.\n",
"\n",
"Generation speed: 48.46 tokens/sec\n",
"GPU memory used: 1.50 GB\n"
]
}
],
"source": [
"import time\n",
"\n",
"input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
"\n",
"if torch.cuda.is_available():\n",
" torch.cuda.reset_peak_memory_stats()\n",
"\n",
"start_time = time.perf_counter()\n",
"generated_tokens = 0\n",
"\n",
"for token in generate_text_basic_stream(\n",
" model=model,\n",
@@ -1199,12 +1200,23 @@
" max_new_tokens=500,\n",
" eos_token_id=tokenizer.eos_token_id\n",
"):\n",
" generated_tokens += 1\n",
" token_id = token.squeeze(0).tolist()\n",
" print(\n",
" tokenizer.decode(token_id),\n",
" end=\"\",\n",
" flush=True\n",
" )"
" )\n",
"\n",
"elapsed = time.perf_counter() - start_time\n",
"tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n",
"print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
"\n",
" print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")\n"
]
},
{

View File

@@ -56,7 +56,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 21,
"id": "7c201adb-747e-437b-9a62-442802941e01",
"metadata": {},
"outputs": [],
@@ -66,7 +66,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 22,
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"metadata": {
"colab": {
@@ -80,9 +80,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.34.4\n",
"tokenizers version: 0.21.4\n",
"torch version: 2.8.0\n"
"huggingface_hub version: 1.5.0\n",
"tokenizers version: 0.22.2\n",
"torch version: 2.8.0+cu128\n"
]
}
],
@@ -113,7 +113,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 23,
"id": "70a90338-624a-4706-aa55-6b4358070194",
"metadata": {},
"outputs": [],
@@ -142,7 +142,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 24,
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
"metadata": {
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
@@ -169,7 +169,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 25,
"id": "56715760-37e1-433e-89da-04864c139a9e",
"metadata": {},
"outputs": [],
@@ -200,7 +200,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 26,
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
"metadata": {
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
@@ -252,7 +252,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 27,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
@@ -327,7 +327,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 28,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
@@ -367,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 29,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
@@ -431,7 +431,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 30,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
@@ -536,7 +536,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 31,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
@@ -549,7 +549,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 32,
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
"metadata": {},
"outputs": [
@@ -582,7 +582,7 @@
")"
]
},
"execution_count": 12,
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
@@ -601,20 +601,20 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 33,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.2256, -0.0164, -0.7070, ..., 0.4414, 0.1245, 1.0703],\n",
" [-0.6602, 0.5352, -0.0718, ..., -0.0737, 0.5391, 0.3086],\n",
" [-0.4785, -0.1562, 0.1045, ..., -0.2324, 0.2354, 0.6328]]],\n",
"tensor([[[-0.2334, -0.0134, -0.7031, ..., 0.4316, 0.1177, 1.0703],\n",
" [-0.6641, 0.5352, -0.0752, ..., -0.0698, 0.5430, 0.3203],\n",
" [-0.4785, -0.1748, 0.1074, ..., -0.2354, 0.2354, 0.6289]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 13,
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
@@ -625,7 +625,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 34,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
@@ -656,7 +656,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 35,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
@@ -706,7 +706,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 36,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
@@ -736,7 +736,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 37,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
@@ -841,7 +841,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 38,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
@@ -864,7 +864,22 @@
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b83bbaf857414e8b8842a6af8bfe3071",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/1.50G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import json\n",
"import os\n",
@@ -915,7 +930,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 39,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
@@ -996,10 +1011,25 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 40,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2bcba7591b04bfea6b518382e849767",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"tokenizer.json: 0%| | 0.00/11.4M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"if USE_REASONING_MODEL:\n",
" tokenizer_file_path = f\"Qwen3-{CHOOSE_MODEL}/tokenizer.json\"\n",
@@ -1033,7 +1063,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 41,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
@@ -1043,7 +1073,7 @@
"'<|im_start|>user\\nGive me a short introduction to large language models.<|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 21,
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
@@ -1069,7 +1099,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 42,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1095,7 +1125,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 43,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1108,16 +1138,26 @@
"<think>\n",
"Okay, the user wants a short introduction to large language models. Let me start by recalling what I know. Large language models are AI systems that can understand and generate human language. They're trained on massive datasets, so they can learn complex patterns and nuances.\n",
"\n",
"I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or content generation. Also, emphasize their adaptability and efficiency. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n",
"I should mention their ability to understand and generate text, not just specific tasks. Maybe include examples like chatbots or language assistants. Also, emphasize their adaptability and versatility. Oh, and maybe touch on their applications in various fields. Let me check if I'm covering all key points without being too technical. Keep it concise, around a sentence or two. Make sure it's clear and easy to understand.\n",
"</think>\n",
"\n",
"Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and content creation. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for applications in various domains."
"Large language models (LLMs) are AI systems designed to understand and generate human language, enabling tasks like text generation, translation, and answering questions. They are trained on vast datasets, allowing them to learn complex patterns and nuances, making them versatile for a wide range of applications.\n",
"\n",
"Generation speed: 32.84 tokens/sec\n",
"GPU memory used: 3.03 GB\n"
]
}
],
"source": [
"import time\n",
"\n",
"input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
"\n",
"if torch.cuda.is_available():\n",
" torch.cuda.reset_peak_memory_stats()\n",
"\n",
"start_time = time.perf_counter()\n",
"generated_tokens = 0\n",
"\n",
"for token in generate_text_basic_stream(\n",
" model=model,\n",
@@ -1125,12 +1165,23 @@
" max_new_tokens=500,\n",
" eos_token_id=tokenizer.eos_token_id\n",
"):\n",
" generated_tokens += 1\n",
" token_id = token.squeeze(0).tolist()\n",
" print(\n",
" tokenizer.decode(token_id),\n",
" end=\"\",\n",
" flush=True\n",
" )"
" )\n",
"\n",
"elapsed = time.perf_counter() - start_time\n",
"tokens_per_sec = generated_tokens / elapsed if elapsed > 0 else 0.0\n",
"print(f\"\\n\\nGeneration speed: {tokens_per_sec:.2f} tokens/sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
"\n",
" print(f\"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{

20
ch05/16_qwen3.5/README.md Normal file
View File

@@ -0,0 +1,20 @@
# Qwen3.5 0.8B From Scratch
This folder contains a from-scratch style implementation of [Qwen/Qwen3.5-0.8B](https://huggingface.co/Qwen/Qwen3.5-0.8B).
<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen3.5/01.webp" width="500px">
Qwen3.5 is based on the Qwen3-Next architecture, which I described in more detail in section [2. (Linear) Attention Hybrids](https://magazine.sebastianraschka.com/i/177848019/2-linear-attention-hybrids) of my [Beyond Standard LLMs](https://magazine.sebastianraschka.com/p/beyond-standard-llms) article.
<a href="https://magazine.sebastianraschka.com/p/beyond-standard-llms"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/qwen3.5/02.webp" width="500px"></a>
Note that Qwen3.5 alternates `linear_attention` and `full_attention` layers.
The notebooks keep the full model flow readable while reusing the linear-attention building blocks from [qwen3_5_transformers.py](qwen3_5_transformers.py), which contains the linear-attention code from Hugging Face, available under the Apache 2.0 open-source license.
&nbsp;
## Files
- [qwen3.5.ipynb](qwen3.5.ipynb): Main Qwen3.5 0.8B notebook implementation.
- [qwen3.5-plus-kv-cache.ipynb](qwen3.5-plus-kv-cache.ipynb): Same model with KV-cache decoding for efficiency.
- [qwen3_5_transformers.py](qwen3_5_transformers.py): Some helper components from Hugging Face Transformers used for Qwen3.5 linear attention.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,425 @@
"""Qwen3.5 helper blocks copied from Hugging Face Transformers
Source file:
transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
License: Apache License Version 2.0
License URL: https://github.com/huggingface/transformers/blob/main/LICENSE
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Notebook shims for optional fast kernels in transformers
# Each name below is an optional fused-kernel hook; None means "not installed",
# which makes Qwen3_5GatedDeltaNet fall back to the pure-torch implementations
# defined later in this file (via the `x or torch_x` pattern in __init__).
causal_conv1d_fn = None
causal_conv1d_update = None
chunk_gated_delta_rule = None
fused_recurrent_gated_delta_rule = None
FusedRMSNormGated = None
# Minimal activation registry; only SiLU ("silu") is needed by this notebook.
ACT2FN = {"silu": F.silu}
# Forces the slow-path warning (and torch fallbacks) in Qwen3_5GatedDeltaNet.
is_fast_path_available = False
class _NotebookLogger:
def __init__(self):
self._seen = set()
def warning_once(self, msg):
if msg in self._seen:
return
self._seen.add(msg)
print(msg)
logger = _NotebookLogger()
# Placeholder types for copied annotations
# These exist only so the type names referenced by the copied HF code resolve;
# they carry no behavior and are never instantiated in this notebook.
class Qwen3_5Config:
    pass
class Qwen3_5DynamicCache:
    pass
# Copied verbatim from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
class Qwen3_5RMSNormGated(nn.Module):
    """RMSNorm followed by a SiLU gate (norm-then-gate order).

    Used as the output normalization of the gated DeltaNet layer when the
    fused ``FusedRMSNormGated`` kernel is unavailable.
    """
    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        # **kwargs absorbs fused-kernel-only arguments (e.g. activation/device)
        # so both norm variants can be constructed with the same call.
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
    def forward(self, hidden_states, gate=None):
        # NOTE(review): despite the None default, `gate` is used unconditionally
        # below (`gate.to(...)`), so callers must always pass it.
        input_dtype = hidden_states.dtype
        # Compute the norm statistics in float32 for numerical stability.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        # Norm before gate
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
        return hidden_states.to(input_dtype)
# Adapted (restyled) from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
    """
    # attention_mask is a 2D (batch, seq) boolean/0-1 tensor.
    if attention_mask is None:
        return hidden_states
    # Only mask when there can actually be padding: more than one sample
    # in the batch AND more than one token in the sequence.
    if attention_mask.shape[0] > 1 and attention_mask.shape[1] > 1:
        original_dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(original_dtype)
    return hidden_states
# Copied verbatim from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
def torch_causal_conv1d_update(
    hidden_states,
    conv_state,
    weight,
    bias=None,
    activation=None,
):
    """Pure-torch fallback for the fused causal-conv1d decode-step kernel.

    Concatenates the rolling ``conv_state`` with the new ``hidden_states``,
    runs a depthwise (grouped) causal conv1d over it, and returns the
    SiLU-activated outputs for the last ``seq_len`` positions.

    Side effect: ``conv_state`` is updated in place (via ``copy_``) to hold
    the most recent ``state_len`` columns, ready for the next decode step.

    NOTE(review): ``activation`` is accepted only for API parity with the
    fused kernel — it is ignored here and SiLU is always applied.
    """
    _, hidden_size, seq_len = hidden_states.shape
    state_len = conv_state.shape[-1]
    # Prepend the cached context so the convolution sees the full window.
    hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype)
    # Keep only the newest `state_len` columns as the next conv state.
    conv_state.copy_(hidden_states_new[:, :, -state_len:])
    # groups=hidden_size -> depthwise convolution, one filter per channel.
    out = F.conv1d(hidden_states_new, weight.unsqueeze(1), bias, padding=0, groups=hidden_size)
    out = F.silu(out[:, :, -seq_len:])
    out = out.to(hidden_states.dtype)
    return out
# Adapted (restyled) from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
def l2norm(x, dim=-1, eps=1e-6):
    """This function is intended to align with the l2norm implementation in the FLA library."""
    # Scale by the reciprocal square root of (sum of squares + eps); the eps
    # sits inside the sqrt, matching the FLA kernel's formulation.
    squared_sum = (x * x).sum(dim=dim, keepdim=True)
    return x * torch.rsqrt(squared_sum + eps)
# Copied verbatim from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
def torch_chunk_gated_delta_rule(
    query,
    key,
    value,
    g,
    beta,
    chunk_size=64,
    initial_state=None,
    output_final_state=False,
    use_qk_l2norm_in_kernel=False,
):
    """Chunkwise gated delta-rule attention in pure torch (prefill path).

    Fallback for the fused FLA kernel: the sequence is padded to a multiple
    of ``chunk_size``, intra-chunk interactions are computed with matmuls,
    and a recurrent state is carried across chunks.

    NOTE(review): q/k/v/beta/g appear to come in as (batch, seq, heads, ...)
    and are transposed to (batch, heads, seq, ...) below — confirm at callers.

    Returns:
        (core_attn_out, last_recurrent_state): output in the input dtype and
        layout; the state is None unless ``output_final_state`` is True.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # All math is done in float32 for stability; cast back at the end.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    # Right-pad the sequence so it divides evenly into chunks.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = F.pad(query, (0, 0, 0, pad_size))
    key = F.pad(key, (0, 0, 0, pad_size))
    value = F.pad(value, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    total_sequence_length = sequence_length + pad_size
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    v_beta = value * beta.unsqueeze(-1)
    k_beta = key * beta.unsqueeze(-1)
    # reshape to chunks
    query, key, value, k_beta, v_beta = [
        x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta)
    ]
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    # Upper-triangular (incl. diagonal) mask used to zero future positions.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0)
    # chunk decay
    g = g.cumsum(dim=-1)
    decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril()
    attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
    # Row-by-row forward substitution: builds (I - A)^-1-style corrections
    # within each chunk. Order matters; do not vectorize casually.
    for i in range(1, chunk_size):
        row = attn[..., i, :i].clone()
        sub = attn[..., :i, :i].clone()
        attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
    value = attn @ v_beta
    k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    core_attn_out = torch.zeros_like(value)
    # Strictly upper-triangular mask for the intra-chunk attention pass.
    mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)
    # for each chunk
    for i in range(0, total_sequence_length // chunk_size):
        q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
        attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        # v_prime: contribution already accounted for by the carried state.
        v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        v_new = v_i - v_prime
        # Inter-chunk term (reads state) + intra-chunk term (attn @ v_new).
        attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        core_attn_out[:, :, i] = attn_inter + attn @ v_new
        # Decay the state to the end of the chunk, then add new key/value info.
        last_recurrent_state = (
            last_recurrent_state * g[:, :, i, -1, None, None].exp()
            + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new
        )
    if not output_final_state:
        last_recurrent_state = None
    # Un-chunk, drop the padding, and restore the input layout/dtype.
    core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1])
    core_attn_out = core_attn_out[:, :, :sequence_length]
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
# Copied verbatim from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
def torch_recurrent_gated_delta_rule(
    query, key, value, g, beta, initial_state, output_final_state, use_qk_l2norm_in_kernel=False
):
    """Token-by-token gated delta rule in pure torch (decode path).

    Sequential reference implementation of the same recurrence as
    ``torch_chunk_gated_delta_rule``; used for single-token steps with a
    cached recurrent state.

    Returns:
        (core_attn_out, last_recurrent_state): output in the input dtype and
        layout; the state is None unless ``output_final_state`` is True.
    """
    initial_dtype = query.dtype
    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1, eps=1e-6)
        key = l2norm(key, dim=-1, eps=1e-6)
    # Recurrence is computed in float32; cast back at the end.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    ]
    batch_size, num_heads, sequence_length, k_head_dim = key.shape
    v_head_dim = value.shape[-1]
    scale = 1 / (query.shape[-1] ** 0.5)
    query = query * scale
    core_attn_out = torch.zeros(batch_size, num_heads, sequence_length, v_head_dim).to(value)
    last_recurrent_state = (
        torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value)
        if initial_state is None
        else initial_state.to(value)
    )
    for i in range(sequence_length):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        # Per-head decay factor for this step.
        g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
        beta_t = beta[:, :, i].unsqueeze(-1)
        # Decay state, read the memory for k_t, write the beta-scaled delta.
        last_recurrent_state = last_recurrent_state * g_t
        kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        delta = (v_t - kv_mem) * beta_t
        last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
    if not output_final_state:
        last_recurrent_state = None
    core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_attn_out, last_recurrent_state
# Copied from:
# transformers-main/src/transformers/models/qwen3_5/modeling_qwen3_5.py
# Minimal change: enforce config dtype at the end to avoid bf16/fp32 matmul mismatch
# in a mixed notebook implementation
class Qwen3_5GatedDeltaNet(nn.Module):
    """Gated DeltaNet linear-attention layer (Qwen3.5 `linear_attention` blocks).

    Projects hidden states to q/k/v plus a gate (z), write strength (b -> beta)
    and decay input (a -> g), runs a short depthwise causal conv over the
    q/k/v channels, applies the gated delta rule (chunkwise for prefill,
    recurrent for cached single-token decode), gate-normalizes, and projects
    back to ``hidden_size``.
    """
    def __init__(self, config, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_v_heads = config.linear_num_value_heads
        self.num_k_heads = config.linear_num_key_heads
        self.head_k_dim = config.linear_key_head_dim
        self.head_v_dim = config.linear_value_head_dim
        self.key_dim = self.head_k_dim * self.num_k_heads
        self.value_dim = self.head_v_dim * self.num_v_heads
        self.conv_kernel_size = config.linear_conv_kernel_dim
        # layer_idx indexes this layer's slots in the cache (conv/recurrent states).
        self.layer_idx = layer_idx
        self.activation = config.hidden_act
        self.act = ACT2FN[config.hidden_act]
        self.layer_norm_epsilon = config.rms_norm_eps
        # QKV
        # Depthwise conv over the concatenated q/k/v channels (causal padding).
        self.conv_dim = self.key_dim * 2 + self.value_dim
        self.conv1d = nn.Conv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            bias=False,
            kernel_size=self.conv_kernel_size,
            groups=self.conv_dim,
            padding=self.conv_kernel_size - 1,
        )
        # time step projection (discretization)
        # instantiate once and copy inv_dt in init_weights of PretrainedModel
        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
        # Per-head log decay rates; A sampled uniform in (0, 16) as in HF.
        A = torch.empty(self.num_v_heads).uniform_(0, 16)
        self.A_log = nn.Parameter(torch.log(A))
        # Prefer the fused gated RMSNorm kernel when available; otherwise use
        # the pure-torch Qwen3_5RMSNormGated defined above.
        self.norm = (
            Qwen3_5RMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
            if FusedRMSNormGated is None
            else FusedRMSNormGated(
                self.head_v_dim,
                eps=self.layer_norm_epsilon,
                activation=self.activation,
                device=torch.cuda.current_device(),
                dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
            )
        )
        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        # Fall back to the torch implementations when the fused kernels
        # (module-level shims) are None.
        self.causal_conv1d_fn = causal_conv1d_fn
        self.causal_conv1d_update = causal_conv1d_update or torch_causal_conv1d_update
        self.chunk_gated_delta_rule = chunk_gated_delta_rule or torch_chunk_gated_delta_rule
        self.recurrent_gated_delta_rule = fused_recurrent_gated_delta_rule or torch_recurrent_gated_delta_rule
        if not is_fast_path_available:
            logger.warning_once(
                "The fast path is not available because one of the required library is not installed. Falling back to "
                "torch implementation. To install follow https://github.com/fla-org/flash-linear-attention#installation and"
                " https://github.com/Dao-AILab/causal-conv1d"
            )
        self.in_proj_qkv = nn.Linear(self.hidden_size, self.key_dim * 2 + self.value_dim, bias=False)
        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
        # Notebook adaptation for dtype consistency.
        if config.dtype is not None:
            self.to(dtype=config.dtype)
    def forward(
        self,
        hidden_states,
        cache_params=None,
        cache_position=None,
        attention_mask=None,
    ):
        """Run the gated DeltaNet layer.

        Args:
            hidden_states: (batch, seq, hidden_size) input activations.
            cache_params: optional cache holding per-layer conv_states and
                recurrent_states; updated in place when provided.
            cache_position: used only to detect the cached decode fast path.
            attention_mask: optional 2D padding mask (see
                apply_mask_to_padding_states).

        Returns:
            (batch, seq, hidden_size) output activations.
        """
        hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
        # Set up dimensions for reshapes later
        batch_size, seq_len, _ = hidden_states.shape
        # Single-token decode with an existing cache -> recurrent fast path.
        use_precomputed_states = (
            cache_params is not None
            and cache_params.has_previous_state
            and seq_len == 1
            and cache_position is not None
        )
        # getting projected states from cache if it exists
        if cache_params is not None:
            conv_state = cache_params.conv_states[self.layer_idx]
            recurrent_state = cache_params.recurrent_states[self.layer_idx]
        mixed_qkv = self.in_proj_qkv(hidden_states)
        mixed_qkv = mixed_qkv.transpose(1, 2)
        # z: gate for the output norm; b -> beta (write strength); a -> decay.
        z = self.in_proj_z(hidden_states)
        z = z.reshape(batch_size, seq_len, -1, self.head_v_dim)
        b = self.in_proj_b(hidden_states)
        a = self.in_proj_a(hidden_states)
        if use_precomputed_states:
            # 2. Convolution sequence transformation
            # NOTE: the conv state is updated in `causal_conv1d_update`
            mixed_qkv = self.causal_conv1d_update(
                mixed_qkv,
                conv_state,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.activation,
            )
        else:
            if cache_params is not None:
                # Left-pad to kernel size so the stored state has fixed width.
                conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
                cache_params.conv_states[self.layer_idx] = conv_state
            if self.causal_conv1d_fn is not None:
                mixed_qkv = self.causal_conv1d_fn(
                    x=mixed_qkv,
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                    seq_idx=None,
                )
            else:
                # Trim the causal padding tail back to seq_len.
                mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
        mixed_qkv = mixed_qkv.transpose(1, 2)
        query, key, value = torch.split(
            mixed_qkv,
            [
                self.key_dim,
                self.key_dim,
                self.value_dim,
            ],
            dim=-1,
        )
        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
        beta = b.sigmoid()
        # If the model is loaded in fp16, without the .float() here, A might be -inf
        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
        # Grouped heads: repeat q/k so every value head has a matching q/k head.
        if self.num_v_heads // self.num_k_heads > 1:
            query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
            key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2)
        if not use_precomputed_states:
            # Prefill: chunkwise parallel form of the gated delta rule.
            core_attn_out, last_recurrent_state = self.chunk_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=None,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )
        else:
            # Decode: sequential recurrence continuing from the cached state.
            core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule(
                query,
                key,
                value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=cache_params is not None,
                use_qk_l2norm_in_kernel=True,
            )
        # Update cache
        if cache_params is not None:
            cache_params.recurrent_states[self.layer_idx] = last_recurrent_state
        # reshape input data into 2D tensor
        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
        z = z.reshape(-1, self.head_v_dim)
        core_attn_out = self.norm(core_attn_out, z)
        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
        output = self.out_proj(core_attn_out)
        return output

View File

@@ -0,0 +1,275 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import sys
from pathlib import Path
import torch
from llms_from_scratch.utils import import_definitions_from_notebook
def _import_qwen3_5_classes():
    """Import the HF Qwen3.5 config/model classes, preferring the installed package.

    Falls back to a local ``transformers-main/src`` checkout (three directories
    above this file) when the installed ``transformers`` lacks the qwen3_5
    modules; re-raises the original error if no local checkout exists.

    Returns:
        (Qwen3_5TextConfig, Qwen3_5ForCausalLM)
    """
    try:
        from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
        return Qwen3_5TextConfig, Qwen3_5ForCausalLM
    except Exception:
        repo_root = Path(__file__).resolve().parents[3]
        local_src = repo_root / "transformers-main" / "src"
        if not local_src.exists():
            raise
        # Purge any partially imported transformers modules so the re-import
        # below resolves entirely against the local source tree.
        for name in list(sys.modules):
            if name == "transformers" or name.startswith("transformers."):
                del sys.modules[name]
        sys.path.insert(0, str(local_src))
        from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
        return Qwen3_5TextConfig, Qwen3_5ForCausalLM
# Resolve the HF classes eagerly; fall back to None so importing this module
# never fails — functions that need the classes raise ImportError at call time.
try:
    Qwen3_5TextConfig, Qwen3_5ForCausalLM = _import_qwen3_5_classes()
except Exception:
    Qwen3_5TextConfig = None
    Qwen3_5ForCausalLM = None
def tiny_debug_config():
    """Return a tiny Qwen3.5 config dict for fast layer-by-layer debugging.

    Two layers (one linear-attention, one full-attention), small dims, fp32 —
    cheap enough for CI-style parity checks against the HF reference model.
    """
    config = dict(
        vocab_size=257,
        context_length=8,
        emb_dim=32,
        n_heads=4,
        n_layers=2,
        hidden_dim=64,
        head_dim=8,
        qk_norm=True,
        n_kv_groups=2,
        rope_base=1_000_000.0,
        partial_rotary_factor=1.0,
        rms_norm_eps=1e-6,
        linear_conv_kernel_dim=2,
        linear_key_head_dim=8,
        linear_value_head_dim=8,
        linear_num_key_heads=2,
        linear_num_value_heads=2,
        layer_types=["linear_attention", "full_attention"],
        dtype=torch.float32,
    )
    return config
def _hf_config_from_dict(cfg):
    """Translate a notebook config dict into a Hugging Face Qwen3_5TextConfig.

    Raises:
        ImportError: when the Qwen3.5 classes could not be imported.
    """
    if Qwen3_5TextConfig is None:
        raise ImportError("Qwen3.5 classes are required for the layer debugger.")
    rope_parameters = {
        "rope_type": "default",
        "rope_theta": cfg["rope_base"],
        "partial_rotary_factor": cfg.get("partial_rotary_factor", 1.0),
        "mrope_interleaved": True,
        "mrope_section": [2, 1, 1],
    }
    # Map notebook key names onto the HF constructor's parameter names.
    hf_kwargs = dict(
        vocab_size=cfg["vocab_size"],
        max_position_embeddings=cfg["context_length"],
        hidden_size=cfg["emb_dim"],
        num_attention_heads=cfg["n_heads"],
        num_hidden_layers=cfg["n_layers"],
        intermediate_size=cfg["hidden_dim"],
        head_dim=cfg["head_dim"],
        num_key_value_heads=cfg["n_kv_groups"],
        layer_types=cfg["layer_types"],
        linear_conv_kernel_dim=cfg["linear_conv_kernel_dim"],
        linear_key_head_dim=cfg["linear_key_head_dim"],
        linear_value_head_dim=cfg["linear_value_head_dim"],
        linear_num_key_heads=cfg["linear_num_key_heads"],
        linear_num_value_heads=cfg["linear_num_value_heads"],
        tie_word_embeddings=False,
        use_cache=False,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=cfg.get("rms_norm_eps", 1e-6),
        rope_parameters=rope_parameters,
        torch_dtype=cfg.get("dtype", torch.float32),
    )
    hf_cfg = Qwen3_5TextConfig(**hf_kwargs)
    # Force the reference (eager) attention path for exact numeric comparison.
    hf_cfg._attn_implementation = "eager"
    return hf_cfg
def load_notebook_defs(nb_name="qwen3.5.ipynb"):
    """Import the model definitions from the sibling notebook *nb_name*."""
    notebook_dir = Path(__file__).resolve().parents[1]
    entry = str(notebook_dir)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return import_definitions_from_notebook(notebook_dir, nb_name)
def build_qwen3_5_pair(import_notebook_defs, cfg, hf_checkpoint=None):
    """Build (our model, HF reference model) with identical weights, in eval mode.

    When *hf_checkpoint* is given the reference weights come from that
    checkpoint; otherwise a randomly initialized reference is created from
    *cfg*.  Our model is always loaded from the reference state dict.
    """
    if Qwen3_5ForCausalLM is None:
        raise ImportError("Qwen3.5 classes are required for the layer debugger.")
    ours = import_notebook_defs.Qwen3_5Model(cfg)
    if hf_checkpoint:
        hf_model = Qwen3_5ForCausalLM.from_pretrained(
            hf_checkpoint,
            torch_dtype=cfg.get("dtype", torch.float32),
            attn_implementation="eager",
        )
    else:
        hf_model = Qwen3_5ForCausalLM(_hf_config_from_dict(cfg))
    param_cfg = {"n_layers": cfg["n_layers"], "layer_types": cfg["layer_types"]}
    import_notebook_defs.load_weights_into_qwen3_5(ours, param_cfg, hf_model.state_dict())
    hf_model.config.use_cache = False  # comparisons run without KV caching
    ours.eval()
    hf_model.eval()
    return ours, hf_model
def _attach_debug_hooks(model, is_hf):
traces = {}
handles = []
def hook(name):
def _record(_, __, output):
if isinstance(output, tuple):
output = output[0]
traces[name] = output.detach().to(torch.float32).cpu()
return _record
if is_hf:
core = model.model
handles.append(core.embed_tokens.register_forward_hook(hook("embedding")))
for idx, layer in enumerate(core.layers):
handles.append(layer.register_forward_hook(hook(f"block_{idx}")))
handles.append(core.norm.register_forward_hook(hook("final_norm")))
handles.append(model.lm_head.register_forward_hook(hook("logits")))
else:
handles.append(model.tok_emb.register_forward_hook(hook("embedding")))
blocks = getattr(model, "trf_blocks", None)
if blocks is None:
blocks = getattr(model, "blocks", None)
if blocks is None:
raise AttributeError("Could not locate Qwen3.5 blocks on the local model.")
for idx, block in enumerate(blocks):
handles.append(block.register_forward_hook(hook(f"block_{idx}")))
handles.append(model.final_norm.register_forward_hook(hook("final_norm")))
handles.append(model.out_head.register_forward_hook(hook("logits")))
return traces, handles
def _layer_sort_key(name):
if name == "embedding":
return (0, 0)
if name.startswith("block_"):
idx = int(name.split("_")[1])
return (1, idx)
if name == "final_norm":
return (2, 0)
if name == "logits":
return (3, 0)
return (4, name)
def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5):
    """Run both models on *input_ids* and compare the captured layer outputs.

    Returns a list of dicts (name, status, shapes, diff statistics), one per
    captured layer, ordered by ``_layer_sort_key``.  Status is one of
    ``"ok"``, ``"mismatch"``, ``"shape_mismatch"``, or ``"missing"``.
    """
    ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False)
    hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True)
    try:
        with torch.inference_mode():
            ours(input_ids)
            hf_model(input_ids, use_cache=False)
    finally:
        # Always detach the hooks, even if a forward pass raises.
        for handle in ours_handles + hf_handles:
            handle.remove()

    def make_entry(name, status, ours_t, hf_t, max_diff=None, mean_diff=None):
        return {
            "name": name,
            "status": status,
            "ours_shape": None if ours_t is None else tuple(ours_t.shape),
            "hf_shape": None if hf_t is None else tuple(hf_t.shape),
            "max_diff": max_diff,
            "mean_abs_diff": mean_diff,
        }

    results = []
    for name in sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key):
        ours_t = ours_traces.get(name)
        hf_t = hf_traces.get(name)
        if ours_t is None or hf_t is None:
            results.append(make_entry(name, "missing", ours_t, hf_t))
        elif ours_t.shape != hf_t.shape:
            results.append(make_entry(name, "shape_mismatch", ours_t, hf_t))
        else:
            abs_diff = (ours_t - hf_t).abs()
            close = torch.allclose(ours_t, hf_t, rtol=rtol, atol=atol)
            results.append(
                make_entry(
                    name,
                    "ok" if close else "mismatch",
                    ours_t,
                    hf_t,
                    max_diff=float(abs_diff.max().item()),
                    mean_diff=float(abs_diff.mean().item()),
                )
            )
    return results
def format_report(differences):
    """Render ``layerwise_differences`` output as one human-readable line per layer."""
    lines = []
    for entry in sorted(differences, key=lambda d: _layer_sort_key(d["name"])):
        status = entry["status"]
        if status in ("ok", "mismatch"):
            tag = "[OK]" if status == "ok" else "[DIFF]"
            lines.append(
                f"{tag} {entry['name']}: max={entry['max_diff']:.2e}, mean={entry['mean_abs_diff']:.2e}"
            )
        elif status == "shape_mismatch":
            lines.append(f"[SHAPE] {entry['name']}: ours={entry['ours_shape']}, hf={entry['hf_shape']}")
        else:
            lines.append(f"[MISSING] {entry['name']}: ours={entry['ours_shape']}, hf={entry['hf_shape']}")
    return "\n".join(lines)
# Manual entry point: build the tiny paired models and print a per-layer
# numerical comparison between the from-scratch model and the HF reference.
if __name__ == "__main__":
    if Qwen3_5ForCausalLM is None:
        # The import at module load fell back to None; nothing to compare against.
        raise SystemExit(
            "Qwen3.5 classes are unavailable. Install a recent transformers version or use local transformers-main."
        )
    import_notebook_defs = load_notebook_defs()
    cfg = tiny_debug_config()
    ours_model, hf_model = build_qwen3_5_pair(import_notebook_defs, cfg)
    torch.manual_seed(0)  # deterministic token ids for reproducible reports
    input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
    diffs = layerwise_differences(ours_model, hf_model, input_ids)
    print(format_report(diffs))

View File

@@ -0,0 +1,166 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import importlib
import sys
from pathlib import Path
import pytest
import torch
from llms_from_scratch.utils import import_definitions_from_notebook
def _import_qwen3_5_classes():
    """Import and return (Qwen3_5TextConfig, Qwen3_5ForCausalLM).

    Tries the installed transformers package first; on failure falls back to a
    local ``transformers-main/src`` checkout at the repo root, purging any
    previously imported ``transformers`` modules so the local copy wins.
    Re-raises the original error when no local checkout exists.
    """
    try:
        from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
    except Exception:
        local_src = Path(__file__).resolve().parents[3] / "transformers-main" / "src"
        if not local_src.exists():
            raise
        stale = [m for m in sys.modules if m == "transformers" or m.startswith("transformers.")]
        for module_name in stale:
            del sys.modules[module_name]
        sys.path.insert(0, str(local_src))
        from transformers.models.qwen3_5.configuration_qwen3_5 import Qwen3_5TextConfig
        from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5ForCausalLM
    return Qwen3_5TextConfig, Qwen3_5ForCausalLM
# Detect whether a transformers installation with Qwen3.5 support is present.
# NOTE: `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before calling `find_spec`.
import importlib.util

transformers_installed = importlib.util.find_spec("transformers") is not None
if transformers_installed:
    try:
        Qwen3_5TextConfig, Qwen3_5ForCausalLM = _import_qwen3_5_classes()
    except Exception:
        # transformers is installed but lacks Qwen3.5 (and no usable local
        # checkout exists): mark it unavailable so dependent tests are skipped.
        transformers_installed = False
        Qwen3_5TextConfig, Qwen3_5ForCausalLM = None, None
else:
    Qwen3_5TextConfig, Qwen3_5ForCausalLM = None, None
@pytest.fixture
def import_notebook_defs():
    """Fixture: import the model definitions from the qwen3.5 notebook."""
    notebook_dir = Path(__file__).resolve().parents[1]
    entry = str(notebook_dir)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return import_definitions_from_notebook(notebook_dir, "qwen3.5.ipynb")
@pytest.fixture
def dummy_input():
    """Fixture: a deterministic (1, 8) batch of token ids in [0, 100)."""
    torch.manual_seed(123)
    token_ids = torch.randint(0, 100, (1, 8))
    return token_ids
@pytest.fixture
def dummy_cfg_base():
    """Fixture: a minimal Qwen3.5 config with one linear- and one full-attention layer."""
    model_dims = {
        "vocab_size": 100,
        "emb_dim": 32,
        "hidden_dim": 64,
        "n_layers": 2,
        "n_heads": 4,
        "head_dim": 8,
        "n_kv_groups": 1,
        "qk_norm": False,
        "dtype": torch.float32,
        "rope_base": 10_000.0,
        "context_length": 64,
        "partial_rotary_factor": 1.0,
        "rms_norm_eps": 1e-6,
    }
    # Settings specific to the linear-attention layers.
    linear_attn = {
        "linear_conv_kernel_dim": 2,
        "linear_key_head_dim": 8,
        "linear_value_head_dim": 8,
        "linear_num_key_heads": 2,
        "linear_num_value_heads": 2,
    }
    return {
        **model_dims,
        **linear_attn,
        "layer_types": ["linear_attention", "full_attention"],
    }
@torch.inference_mode()
def test_dummy_qwen3_5_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
    """A forward pass must map (1, seq_len) ids to (1, seq_len, vocab_size) logits."""
    torch.manual_seed(123)
    model = import_notebook_defs.Qwen3_5Model(dummy_cfg_base)
    logits = model(dummy_input)
    expected = (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
    assert logits.shape == expected, (
        f"Expected shape (1, seq_len, vocab_size), got {logits.shape}"
    )
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_5_base_equivalence_with_transformers(import_notebook_defs):
    """Our implementation must produce the same logits as the HF reference model."""
    cfg = {
        "vocab_size": 257,
        "context_length": 8,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "hidden_dim": 64,
        "head_dim": 8,
        "qk_norm": True,
        "n_kv_groups": 2,
        "rope_base": 1_000_000.0,
        "partial_rotary_factor": 1.0,
        "rms_norm_eps": 1e-6,
        "linear_conv_kernel_dim": 2,
        "linear_key_head_dim": 8,
        "linear_value_head_dim": 8,
        "linear_num_key_heads": 2,
        "linear_num_value_heads": 2,
        "layer_types": ["linear_attention", "full_attention"],
        "dtype": torch.float32,
    }
    model = import_notebook_defs.Qwen3_5Model(cfg)

    rope_parameters = {
        "rope_type": "default",
        "rope_theta": cfg["rope_base"],
        "partial_rotary_factor": cfg["partial_rotary_factor"],
        "mrope_interleaved": True,
        "mrope_section": [2, 1, 1],
    }
    hf_cfg = Qwen3_5TextConfig(
        vocab_size=cfg["vocab_size"],
        max_position_embeddings=cfg["context_length"],
        hidden_size=cfg["emb_dim"],
        num_attention_heads=cfg["n_heads"],
        num_hidden_layers=cfg["n_layers"],
        intermediate_size=cfg["hidden_dim"],
        head_dim=cfg["head_dim"],
        num_key_value_heads=cfg["n_kv_groups"],
        layer_types=cfg["layer_types"],
        linear_conv_kernel_dim=cfg["linear_conv_kernel_dim"],
        linear_key_head_dim=cfg["linear_key_head_dim"],
        linear_value_head_dim=cfg["linear_value_head_dim"],
        linear_num_key_heads=cfg["linear_num_key_heads"],
        linear_num_value_heads=cfg["linear_num_value_heads"],
        tie_word_embeddings=False,
        use_cache=False,
        attention_bias=False,
        attention_dropout=0.0,
        rms_norm_eps=cfg["rms_norm_eps"],
        rope_parameters=rope_parameters,
        torch_dtype=torch.float32,
    )
    hf_cfg._attn_implementation = "eager"  # keep attention math comparable
    hf_model = Qwen3_5ForCausalLM(hf_cfg)

    # Copy the randomly initialized HF weights into our model.
    param_config = {"n_layers": cfg["n_layers"], "layer_types": cfg["layer_types"]}
    import_notebook_defs.load_weights_into_qwen3_5(model, param_config, hf_model.state_dict())

    x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
    ours_logits = model(x)
    theirs_logits = hf_model(x, use_cache=False).logits
    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)