Qwen3 Coder Flash & MoE from Scratch (#760)

* Qwen3 Coder Flash & MoE from Scratch

* update

* refinements

* updates

* update

* update

* update
This commit is contained in:
Sebastian Raschka
2025-08-01 19:13:17 -05:00
committed by GitHub
parent 145322ded8
commit f92b40e4ab
13 changed files with 2972 additions and 271 deletions

View File

@@ -160,10 +160,16 @@ from llms_from_scratch.qwen3 import (
# KV cache drop-in replacements
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple
from llms_from_scratch.kv_cache.generate import (
generate_text_simple,
generate_text_simple_stream
)
# KV cache drop-in replacements with batched inference support
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple
from llms_from_scratch.kv_cache_batched.generate import (
generate_text_simple,
generate_text_simple_stream
)
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model
```

View File

@@ -28,3 +28,27 @@ def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cach
idx = torch.cat([idx, next_idx], dim=1)
return idx
def generate_text_simple_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):
    """Greedily generate tokens one at a time, yielding each token as it is produced.

    Args:
        model: Qwen3 model exposing ``cfg``, ``reset_kv_cache()``, and a
            ``cache=`` keyword argument on its forward pass.
        token_ids: (batch, prompt_len) tensor of token IDs used to prime the KV cache.
        max_new_tokens: Maximum number of new tokens to generate.
        eos_token_id: If given, stop as soon as every sequence in the batch
            produced this token (that token is not yielded).
        context_size: Unused; kept for signature compatibility with
            ``generate_text_simple``.

    Yields:
        (batch, 1) tensors, each containing the next greedily-selected token.
    """
    model.eval()
    with torch.no_grad():
        cache = KVCache(n_layers=model.cfg["n_layers"])
        model.reset_kv_cache()
        # Prime the cache with the initial context
        logits = model(token_ids, cache=cache)
        for _ in range(max_new_tokens):
            # Greedy decoding: highest-logit token at the last position
            next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)
            if eos_token_id is not None and torch.all(next_token == eos_token_id):
                break
            yield next_token
            # Feed only the new token to the model; the KV cache holds the
            # history, so there is no need to accumulate the full sequence
            # here (the previous version kept a dead `torch.cat` that was
            # never read again).
            logits = model(next_token, cache=cache)

View File

@@ -29,7 +29,7 @@ class Qwen3Model(nn.Module):
self.final_norm = RMSNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
# Reusuable utilities
# Reusable utilities
if cfg["head_dim"] is None:
head_dim = cfg["emb_dim"] // cfg["n_heads"]
else:
@@ -94,7 +94,10 @@ class TransformerBlock(nn.Module):
qk_norm=cfg["qk_norm"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
if "num_experts" in cfg and cfg["num_experts"] > 0:
self.ff = MoEFeedForward(cfg)
else:
self.ff = FeedForward(cfg)
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
@@ -128,6 +131,46 @@ class FeedForward(nn.Module):
return self.fc3(x)
class MoEFeedForward(nn.Module):
    """Sparse Mixture-of-Experts feed-forward block.

    A linear router (``gate``) scores all experts for every token; the
    softmax-normalized top-k scores weight the expert outputs. Each expert
    is a SwiGLU-style MLP (fc1 = gate_proj, fc2 = up_proj, fc3 = down_proj).
    NOTE: this reference implementation evaluates *every* expert on every
    token and masks afterwards — simple, but not compute-optimal.
    """
    def __init__(self, cfg):
        super().__init__()
        # k: how many experts contribute to each token's output
        self.num_experts_per_tok = cfg["num_experts_per_tok"]
        self.num_experts = cfg["num_experts"]
        # Router: one score per expert for each token
        self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])
        meta_device = torch.device("meta")  # to reduce memory pressure and only load them when used (trades compute for memory)
        self.fc1 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
                                  for _ in range(cfg["num_experts"])])
        self.fc2 = nn.ModuleList([nn.Linear(cfg["emb_dim"], cfg["moe_intermediate_size"], bias=False, dtype=cfg["dtype"], device=meta_device)
                                  for _ in range(cfg["num_experts"])])
        self.fc3 = nn.ModuleList([nn.Linear(cfg["moe_intermediate_size"], cfg["emb_dim"], bias=False, dtype=cfg["dtype"], device=meta_device)
                                  for _ in range(cfg["num_experts"])])

    def forward(self, x):
        # x: (b, seq_len, emb_dim)
        scores = self.gate(x)  # (b, seq_len, num_experts)
        topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        # Routing weights are normalized over the selected k experts only
        topk_probs = torch.softmax(topk_scores, dim=-1)
        expert_outputs = []
        for e in range(self.num_experts):
            # SwiGLU expert MLP: down_proj(silu(gate_proj(x)) * up_proj(x))
            hidden = torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x)
            out = self.fc3[e](hidden)
            expert_outputs.append(out.unsqueeze(-2))
        expert_outputs = torch.cat(expert_outputs, dim=-2)  # (b, t, num_experts, emb_dim)
        # Scatter top-k probabilities into a dense per-expert weight tensor;
        # experts that were not selected keep weight 0.
        gating_probs = torch.zeros_like(scores)
        for i in range(self.num_experts_per_tok):
            indices = topk_indices[..., i:i+1]
            prob = topk_probs[..., i:i+1]
            gating_probs.scatter_(dim=-1, index=indices, src=prob)
        gating_probs = gating_probs.unsqueeze(-1)  # (b, t, num_experts, 1)
        # Weighted sum over experts
        y = (gating_probs * expert_outputs).sum(dim=-2)
        return y
class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None

View File

@@ -102,6 +102,23 @@ QWEN3_CONFIG_32B = {
"dtype": torch.bfloat16,
}
# Mixture-of-Experts model configuration (Qwen3-30B-A3B; the "A3B" suffix
# presumably denotes ~3B active parameters per token — TODO confirm)
QWEN3_CONFIG_30B_A3B = {
    "vocab_size": 151_936,            # Vocabulary size
    "context_length": 262_144,        # Maximum context length
    "emb_dim": 2048,                  # Embedding / hidden dimension
    "n_heads": 32,                    # Number of attention heads
    "n_layers": 48,                   # Number of transformer blocks
    "head_dim": 128,                  # Per-head dimension (explicit, not emb_dim // n_heads)
    "qk_norm": True,                  # Apply RMSNorm to queries and keys
    "n_kv_groups": 4,                 # Key/value groups for grouped-query attention
    "rope_base": 10_000_000.0,        # RoPE theta base
    "dtype": torch.bfloat16,          # Parameter dtype
    "num_experts": 128,               # Total experts per MoE layer
    "num_experts_per_tok": 8,         # Experts activated per token (top-k routing)
    "moe_intermediate_size": 768,     # Hidden size of each expert's MLP
}
class Qwen3Model(nn.Module):
def __init__(self, cfg):
@@ -156,7 +173,10 @@ class TransformerBlock(nn.Module):
qk_norm=cfg["qk_norm"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
if "num_experts" in cfg and cfg["num_experts"] > 0:
self.ff = MoEFeedForward(cfg)
else:
self.ff = FeedForward(cfg)
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
@@ -190,6 +210,46 @@ class FeedForward(nn.Module):
return self.fc3(x)
class MoEFeedForward(nn.Module):
    """Mixture-of-Experts feed-forward layer.

    A linear router scores every expert per token; the top-k experts are
    blended with softmax-normalized routing weights. All experts are run
    densely and masked afterwards (simple, but trades compute for clarity).
    Each expert is a SwiGLU MLP (fc1 = gate_proj, fc2 = up_proj,
    fc3 = down_proj).
    """
    def __init__(self, cfg):
        super().__init__()
        self.num_experts_per_tok = cfg["num_experts_per_tok"]
        self.num_experts = cfg["num_experts"]
        self.gate = nn.Linear(cfg["emb_dim"], cfg["num_experts"], bias=False, dtype=cfg["dtype"])
        # Experts are allocated on the meta device (no real storage) to keep
        # memory pressure low; they are materialized when weights are assigned.
        meta_device = torch.device("meta")

        def expert_linear(in_dim, out_dim):
            # One expert projection, allocated without backing storage.
            return nn.Linear(in_dim, out_dim, bias=False, dtype=cfg["dtype"], device=meta_device)

        emb_dim = cfg["emb_dim"]
        hidden_dim = cfg["moe_intermediate_size"]
        self.fc1 = nn.ModuleList([expert_linear(emb_dim, hidden_dim) for _ in range(self.num_experts)])
        self.fc2 = nn.ModuleList([expert_linear(emb_dim, hidden_dim) for _ in range(self.num_experts)])
        self.fc3 = nn.ModuleList([expert_linear(hidden_dim, emb_dim) for _ in range(self.num_experts)])

    def forward(self, x):
        scores = self.gate(x)  # (b, seq_len, num_experts)
        topk_scores, topk_indices = torch.topk(scores, self.num_experts_per_tok, dim=-1)
        # Routing weights normalized over the selected k experts only
        topk_probs = torch.softmax(topk_scores, dim=-1)

        # Dense evaluation: every expert's SwiGLU MLP runs on every token.
        per_expert = [
            self.fc3[e](torch.nn.functional.silu(self.fc1[e](x)) * self.fc2[e](x))
            for e in range(self.num_experts)
        ]
        expert_outputs = torch.stack(per_expert, dim=-2)  # (b, t, num_experts, emb_dim)

        # Scatter the top-k probabilities into a dense (b, t, num_experts)
        # weight tensor; non-selected experts keep weight 0. A single
        # scatter_ suffices because topk indices are distinct per token.
        gating_probs = torch.zeros_like(scores)
        gating_probs.scatter_(dim=-1, index=topk_indices, src=topk_probs)

        # Weighted sum over the expert dimension.
        return (gating_probs.unsqueeze(-1) * expert_outputs).sum(dim=-2)
class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
@@ -381,21 +441,53 @@ def load_weights_into_qwen(model, param_config, params):
)
# Feedforward weights
block.ff.fc1.weight = assign(
block.ff.fc1.weight,
params[f"model.layers.{l}.mlp.gate_proj.weight"],
f"model.layers.{l}.mlp.gate_proj.weight"
)
block.ff.fc2.weight = assign(
block.ff.fc2.weight,
params[f"model.layers.{l}.mlp.up_proj.weight"],
f"model.layers.{l}.mlp.up_proj.weight"
)
block.ff.fc3.weight = assign(
block.ff.fc3.weight,
params[f"model.layers.{l}.mlp.down_proj.weight"],
f"model.layers.{l}.mlp.down_proj.weight"
)
if "num_experts" in param_config:
# Load router (gating) weights
block.ff.gate.weight = assign(
block.ff.gate.weight,
params[f"model.layers.{l}.mlp.gate.weight"],
f"model.layers.{l}.mlp.gate.weight"
)
# Load expert weights
for e in range(param_config["num_experts"]):
prefix = f"model.layers.{l}.mlp.experts.{e}"
block.ff.fc1[e].weight = assign(
block.ff.fc1[e].weight,
params[f"{prefix}.gate_proj.weight"],
f"{prefix}.gate_proj.weight"
)
block.ff.fc2[e].weight = assign(
block.ff.fc2[e].weight,
params[f"{prefix}.up_proj.weight"],
f"{prefix}.up_proj.weight"
)
block.ff.fc3[e].weight = assign(
block.ff.fc3[e].weight,
params[f"{prefix}.down_proj.weight"],
f"{prefix}.down_proj.weight"
)
# After assigning weights, move the expert layers from meta to CPU
block.ff.fc1[e] = block.ff.fc1[e].to("cpu")
block.ff.fc2[e] = block.ff.fc2[e].to("cpu")
block.ff.fc3[e] = block.ff.fc3[e].to("cpu")
else:
block.ff.fc1.weight = assign(
block.ff.fc1.weight,
params[f"model.layers.{l}.mlp.gate_proj.weight"],
f"model.layers.{l}.mlp.gate_proj.weight"
)
block.ff.fc2.weight = assign(
block.ff.fc2.weight,
params[f"model.layers.{l}.mlp.up_proj.weight"],
f"model.layers.{l}.mlp.up_proj.weight"
)
block.ff.fc3.weight = assign(
block.ff.fc3.weight,
params[f"model.layers.{l}.mlp.down_proj.weight"],
f"model.layers.{l}.mlp.down_proj.weight"
)
block.norm2.scale = assign(
block.norm2.scale,
params[f"model.layers.{l}.post_attention_layernorm.weight"],
@@ -405,8 +497,12 @@ def load_weights_into_qwen(model, param_config, params):
# Final normalization and output head
model.final_norm.scale = assign(model.final_norm.scale, params["model.norm.weight"], "model.norm.weight")
# Model uses weight tying, hence we reuse the embedding layer weights here
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
if "lm_head.weight" in params:
model.out_head.weight = assign(model.out_head.weight, params["lm_head.weight"], "lm_head.weight")
else:
# Model uses weight tying, hence we reuse the embedding layer weights here
print("Model uses weight tying.")
model.out_head.weight = assign(model.out_head.weight, params["model.embed_tokens.weight"], "model.embed_tokens.weight")
class Qwen3Tokenizer:

View File

@@ -13,12 +13,14 @@ from llms_from_scratch.qwen3 import (
Qwen3Tokenizer
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.utils import KVCache
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
import importlib
import platform
import pytest
import torch
import torch.nn as nn
@@ -50,6 +52,92 @@ class Qwen3RMSNorm(nn.Module):
transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def dummy_input():
    """Deterministic dummy token batch: batch size 1, sequence length 8."""
    torch.manual_seed(123)
    return torch.randint(low=0, high=100, size=(1, 8))
@pytest.fixture
def dummy_cfg_base():
    """Minimal dense (non-MoE) Qwen3 configuration for fast unit tests."""
    return dict(
        vocab_size=100,
        emb_dim=32,
        hidden_dim=64,
        n_layers=2,
        n_heads=4,
        head_dim=8,
        n_kv_groups=1,
        qk_norm=False,
        dtype=torch.float32,
        rope_base=10000,
        context_length=64,
        num_experts=0,
    )
@pytest.fixture
def dummy_cfg_moe(dummy_cfg_base):
    """Base config extended with a small Mixture-of-Experts setup."""
    moe_overrides = {
        "num_experts": 4,
        "num_experts_per_tok": 2,
        "moe_intermediate_size": 64,
    }
    return {**dummy_cfg_base, **moe_overrides}
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
    """The dense model maps (batch, seq) token IDs to (batch, seq, vocab) logits."""
    torch.manual_seed(123)
    model = Qwen3Model(dummy_cfg_base)
    logits = model(dummy_input)
    expected_shape = (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
    assert logits.shape == expected_shape, \
        f"Expected shape (1, seq_len, vocab_size), got {logits.shape}"
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
    """The MoE model keeps the logits shape and swaps in MoE feed-forward blocks."""
    torch.manual_seed(123)
    model = Qwen3Model(dummy_cfg_moe)
    logits = model(dummy_input)
    expected_shape = (1, dummy_input.size(1), dummy_cfg_moe["vocab_size"])
    assert logits.shape == expected_shape, \
        f"Expected shape (1, seq_len, vocab_size), got {logits.shape}"
    # MoE blocks are identified by the router attribute their ff module carries.
    has_moe_block = any(hasattr(block.ff, 'gate') for block in model.trf_blocks)
    assert has_moe_block, \
        "Expected MoEFeedForward in at least one transformer block"
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
def test_qwen3_kvcache_equivalence(cfg_name, request):
    """A step-by-step KV-cached forward pass must match one full forward pass.

    Runs the regular model on a 6-token sequence at once, then feeds the same
    tokens one at a time through the KV-cache model (with shared weights) and
    checks the concatenated logits agree within float tolerance.
    """
    cfg = request.getfixturevalue(cfg_name)
    if cfg["num_experts"] > 0 and platform.system() == "Linux":
        pytest.skip("Skipping MoE KV equivalence test on Linux due to nondeterministic expert routing")
    # Seed before model construction so both models would get identical
    # initialization; weights are synced explicitly below anyway.
    torch.manual_seed(123)
    model_regular = Qwen3Model(cfg)
    model_regular.eval()
    model_kv = Qwen3ModelKV(cfg)
    model_kv.eval()
    # Share weights so any logits difference comes from caching alone.
    model_kv.load_state_dict(model_regular.state_dict())
    model_kv.reset_kv_cache()
    cache = KVCache(n_layers=cfg["n_layers"])
    # Re-seed so the sampled input is reproducible regardless of init draws.
    torch.manual_seed(123)
    input_ids = torch.randint(0, cfg["vocab_size"], (1, 6))
    # Reference: full-sequence forward pass without caching.
    out_full = model_regular(input_ids)
    # Incremental: one token per step, relying on the KV cache for history.
    logits_stepwise = []
    for t in range(input_ids.size(1)):
        input_token = input_ids[:, t:t + 1]
        logits = model_kv(input_token, cache=cache)
        logits_stepwise.append(logits)
    out_kv = torch.cat(logits_stepwise, dim=1)
    assert out_full.shape == out_kv.shape, f"Shape mismatch: {out_full.shape} vs {out_kv.shape}"
    assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():