Qwen3 and Llama3 equivalency teests with HF transformers (#768)

* Qwen3 and Llama3 equivalency teests with HF transformers

* update
This commit is contained in:
Sebastian Raschka
2025-08-14 18:36:07 -05:00
committed by GitHub
parent 2e3205f747
commit 07c3122b5c
6 changed files with 199 additions and 8 deletions

View File

@@ -5,11 +5,12 @@
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
compute_rope_params,
apply_rope,
LLAMA32_CONFIG_1B,
compute_rope_params,
GroupedQueryAttention,
GroupedQueryAttentionFast,
load_weights_into_llama,
LLAMA32_CONFIG_1B,
Llama3Model,
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
@@ -246,3 +247,61 @@ def test_rmsnorm_equivalence():
out2 = lit_norm(x)
torch.testing.assert_close(out1, out2, atol=1e-5, rtol=1e-5)
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_llama3_base_equivalence_with_transformers():
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
cfg = {
"vocab_size": 257,
"context_length": 8192,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"n_kv_groups": 2,
"rope_base": 500_000.0,
"rope_freq": {
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_context_length": 8192,
},
"dtype": torch.float32,
}
ours = Llama3Model(cfg)
hf_cfg = LlamaConfig(
vocab_size=cfg["vocab_size"],
hidden_size=cfg["emb_dim"],
num_attention_heads=cfg["n_heads"],
num_key_value_heads=cfg["n_kv_groups"],
num_hidden_layers=cfg["n_layers"],
intermediate_size=cfg["hidden_dim"],
max_position_embeddings=cfg["context_length"],
rms_norm_eps=1e-5,
attention_bias=False,
rope_theta=cfg["rope_base"],
tie_word_embeddings=False,
attn_implementation="eager",
torch_dtype=torch.float32,
rope_scaling={
"type": "llama3",
"factor": cfg["rope_freq"]["factor"],
"low_freq_factor": cfg["rope_freq"]["low_freq_factor"],
"high_freq_factor": cfg["rope_freq"]["high_freq_factor"],
"original_max_position_embeddings": cfg["rope_freq"]["original_context_length"],
},
)
theirs = LlamaForCausalLM(hf_cfg)
hf_state = theirs.state_dict()
load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
ours_logits = ours(x)
theirs_logits = theirs(x).logits.to(ours_logits.dtype)
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)

View File

@@ -5,12 +5,13 @@
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import (
compute_rope_params,
apply_rope,
compute_rope_params,
load_weights_into_qwen,
QWEN_CONFIG_06_B,
RMSNorm,
Qwen3Model,
Qwen3Tokenizer
Qwen3Tokenizer,
RMSNorm,
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.utils import KVCache
@@ -87,6 +88,7 @@ def dummy_cfg_moe(dummy_cfg_base):
return cfg
@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
torch.manual_seed(123)
model = Qwen3Model(dummy_cfg_base)
@@ -95,6 +97,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input):
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@torch.inference_mode()
def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
torch.manual_seed(123)
model = Qwen3Model(dummy_cfg_moe)
@@ -105,6 +108,7 @@ def test_dummy_qwen3_moe_forward(dummy_cfg_moe, dummy_input):
"Expected MoEFeedForward in at least one transformer block"
@torch.inference_mode()
@pytest.mark.parametrize("cfg_name", ["dummy_cfg_base", "dummy_cfg_moe"])
def test_qwen3_kvcache_equivalence(cfg_name, request):
cfg = request.getfixturevalue(cfg_name)
@@ -438,3 +442,51 @@ def test_tokenizer_equivalence():
expected_pad_token = "<|endoftext|>"
assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers():
from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM
# Tiny config so the test is fast
cfg = {
"vocab_size": 257,
"context_length": 8,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"head_dim": 8,
"qk_norm": True,
"n_kv_groups": 2,
"rope_base": 1_000_000.0,
"dtype": torch.float32,
}
model = Qwen3Model(cfg)
hf_cfg = Qwen3Config(
vocab_size=cfg["vocab_size"],
max_position_embeddings=cfg["context_length"],
hidden_size=cfg["emb_dim"],
num_attention_heads=cfg["n_heads"],
num_hidden_layers=cfg["n_layers"],
intermediate_size=cfg["hidden_dim"],
head_dim=cfg["head_dim"],
num_key_value_heads=cfg["n_kv_groups"],
rope_theta=cfg["rope_base"],
tie_word_embeddings=False,
attn_implementation="eager",
torch_dtype=torch.float32,
)
hf_model = Qwen3ForCausalLM(hf_cfg)
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
load_weights_into_qwen(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)
theirs_logits = hf_model(x).logits
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)