Qwen3 Coder Flash & MoE from Scratch (#760)

* Qwen3 Coder Flash & MoE from Scratch

* update

* refinements

* updates

* update

* update

* update
This commit is contained in:
Sebastian Raschka
2025-08-01 19:13:17 -05:00
committed by GitHub
parent 145322ded8
commit f92b40e4ab
13 changed files with 2972 additions and 271 deletions

View File

@@ -13,12 +13,14 @@ from llms_from_scratch.qwen3 import (
Qwen3Tokenizer
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.utils import KVCache
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
import importlib
import platform
import pytest
import torch
import torch.nn as nn
@@ -50,6 +52,92 @@ class Qwen3RMSNorm(nn.Module):
transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def dummy_input():
torch.manual_seed(123)
return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8
@pytest.fixture
def dummy_cfg_base():
return {
"vocab_size": 100,
"emb_dim": 32,
"hidden_dim": 64,
"n_layers": 2,
"n_heads": 4,
"head_dim": 8,
"n_kv_groups": 1,
"qk_norm": False,
"dtype": torch.float32,
"rope_base": 10000,
"context_length": 64,
"num_experts": 0,
}
@pytest.fixture
def dummy_cfg_moe(dummy_cfg_base):
cfg = dummy_cfg_base.copy()
cfg.update({
"num_experts": 4,
"num_experts_per_tok": 2,
"moe_intermediate_size": 64,
})
return cfg
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
torch.manual_seed(123)
model = Qwen3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
torch.manual_seed(123)
model = Qwen3Model(dummy_cfg_moe)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_moe["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
assert any(hasattr(block.ff, 'gate') for block in model.trf_blocks), \
"Expected MoEFeedForward in at least one transformer block"
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
def test_qwen3_kvcache_equivalence(cfg_name, request):
cfg = request.getfixturevalue(cfg_name)
if cfg["num_experts"] > 0 and platform.system() == "Linux":
pytest.skip("Skipping MoE KV equivalence test on Linux due to nondeterministic expert routing")
torch.manual_seed(123)
model_regular = Qwen3Model(cfg)
model_regular.eval()
model_kv = Qwen3ModelKV(cfg)
model_kv.eval()
model_kv.load_state_dict(model_regular.state_dict())
model_kv.reset_kv_cache()
cache = KVCache(n_layers=cfg["n_layers"])
torch.manual_seed(123)
input_ids = torch.randint(0, cfg["vocab_size"], (1, 6))
out_full = model_regular(input_ids)
logits_stepwise = []
for t in range(input_ids.size(1)):
input_token = input_ids[:, t:t + 1]
logits = model_kv(input_token, cache=cache)
logits_stepwise.append(logits)
out_kv = torch.cat(logits_stepwise, dim=1)
assert out_full.shape == out_kv.shape, f"Shape mismatch: {out_full.shape} vs {out_kv.shape}"
assert torch.allclose(out_full, out_kv, atol=1e-5, rtol=1e-3)
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():