Update unit tests for CI (#952)

* Update CI

* Revert submodule pointer update

* Update

* update

* update
Sebastian Raschka
2026-01-27 17:44:55 -06:00
committed by GitHub
parent 59d9262047
commit e155d1b02c
5 changed files with 86 additions and 9 deletions

View File

@@ -79,5 +79,4 @@ jobs:
shell: bash
run: |
source .venv/bin/activate
uv pip install transformers
pytest pkg/llms_from_scratch/tests/

View File

@@ -177,6 +177,10 @@ def test_rope_llama2(notebook):
max_position_embeddings: int = 8192
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
rope_parameters = {"rope_type": "default", "rope_theta": theta_base}
def standardize_rope_params(self):
    return
config = RoPEConfig()
rot_emb = LlamaRotaryEmbedding(config=config)
@@ -242,6 +246,10 @@ def test_rope_llama3(notebook):
max_position_embeddings: int = 8192
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
rope_parameters = {"rope_type": "default", "rope_theta": theta_base}
def standardize_rope_params(self):
    return
config = RoPEConfig()
rot_emb = LlamaRotaryEmbedding(config=config)
@@ -320,6 +328,10 @@ def test_rope_llama3_12(notebook):
max_position_embeddings: int = 8192
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
rope_parameters = {**hf_rope_params, "rope_theta": rope_theta}
def standardize_rope_params(self):
    return
config = RoPEConfig()
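
The three hunks above apply the same pattern: the duck-typed RoPEConfig built inside each test now carries a `rope_parameters` dict plus a no-op `standardize_rope_params()`, so that `LlamaRotaryEmbedding` still accepts it after the Transformers 5.x RoPE refactor. Below is a minimal standalone sketch of that pattern; the head_dim/num_heads/theta values are illustrative, and whether the constructor accepts this shim as-is depends on the installed Transformers version.

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, num_heads, theta_base = 64, 32, 10_000.0  # illustrative values, not taken from the tests

class RoPEConfig:
    rope_type = "default"
    rope_theta = theta_base
    max_position_embeddings: int = 8192
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads
    # Transformers >=5.0.0 reads RoPE settings from `rope_parameters`.
    rope_parameters = {"rope_type": "default", "rope_theta": theta_base}

    def standardize_rope_params(self):
        # The dict above is already in standardized form, so nothing to do.
        return

rot_emb = LlamaRotaryEmbedding(config=RoPEConfig())
position_ids = torch.arange(8, dtype=torch.long).unsqueeze(0)
# The first argument is only used for device/dtype in current Transformers releases.
cos, sin = rot_emb(torch.zeros(1, 8, head_dim), position_ids)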

View File

@@ -111,6 +111,25 @@ def test_rope():
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
def __init__(self):
    # Transformers >=5.0.0 expects `rope_parameters` on the instance.
    self.rope_parameters = {**hf_rope_params, "rope_theta": rope_theta}
def standardize_rope_params(self):
    params = dict(getattr(self, "rope_parameters", {}) or {})
    if "rope_type" not in params:
        params["rope_type"] = getattr(self, "rope_type", "default")
    if "rope_theta" not in params:
        params["rope_theta"] = getattr(self, "rope_theta")
    # Handle older key name used in this repo.
    if (
        "original_max_position_embeddings" not in params
        and "original_context_length" in params
    ):
        params["original_max_position_embeddings"] = params["original_context_length"]
    self.rope_parameters = params
    return params
config = RoPEConfig()
rot_emb = LlamaRotaryEmbedding(config=config)
@@ -304,4 +323,4 @@ def test_llama3_base_equivalence_with_transformers():
ours_logits = ours(x)
theirs_logits = theirs(x).logits.to(ours_logits.dtype)
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)

View File

@@ -28,6 +28,7 @@ import os
import shutil
import tempfile
import platform
from collections.abc import Mapping
import pytest
import torch
import torch.nn as nn
@@ -59,6 +60,36 @@ class Qwen3RMSNorm(nn.Module):
transformers_installed = importlib.util.find_spec("transformers") is not None
def _hf_ids(obj):
    """Normalize HF chat-template outputs across Transformers versions."""
    if isinstance(obj, Mapping):
        if "input_ids" in obj:
            obj = obj["input_ids"]
        elif "ids" in obj:
            obj = obj["ids"]
    elif hasattr(obj, "keys") and hasattr(obj, "__getitem__"):
        # Some HF containers behave like mappings but don't register as Mapping.
        try:
            if "input_ids" in obj:
                obj = obj["input_ids"]
            elif "ids" in obj:
                obj = obj["ids"]
        except Exception:
            pass
    if hasattr(obj, "input_ids"):
        obj = obj.input_ids
    if hasattr(obj, "ids"):
        obj = obj.ids
    if isinstance(obj, torch.Tensor):
        obj = obj.tolist()
    if isinstance(obj, tuple):
        obj = list(obj)
    # Some HF versions return a batched structure even for a single prompt.
    if isinstance(obj, list) and obj and isinstance(obj[0], list) and len(obj) == 1:
        obj = obj[0]
    return list(obj)
@pytest.fixture
def dummy_input():
torch.manual_seed(123)
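
For reference, a few hedged examples of the return shapes the `_hf_ids` helper above is meant to collapse; given its definition, all three calls reduce to the same flat list of ids (the literal ids here are made up).

import torch

assert _hf_ids({"input_ids": [[1, 2, 3]]}) == [1, 2, 3]  # BatchEncoding-style mapping
assert _hf_ids(torch.tensor([[1, 2, 3]])) == [1, 2, 3]   # tensor output, e.g. return_tensors="pt"
assert _hf_ids((1, 2, 3)) == [1, 2, 3]                   # plain tuple of ids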
@@ -211,7 +242,8 @@ def test_rope(context_len):
# Generate reference RoPE via HF
class RoPEConfig:
rope_type = "qwen3"
# Transformers' RoPE init map does not include "qwen3".
rope_type = "default"
factor = 1.0
dim: int = head_dim
rope_theta = 1_000_000
@@ -219,6 +251,19 @@ def test_rope(context_len):
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
def __init__(self):
    # Transformers >=5.0.0 expects `rope_parameters` on the instance.
    self.rope_parameters = {"rope_type": "default", "rope_theta": rope_theta, "factor": 1.0}
def standardize_rope_params(self):
    params = dict(getattr(self, "rope_parameters", {}) or {})
    if "rope_type" not in params:
        params["rope_type"] = getattr(self, "rope_type", "default")
    if "rope_theta" not in params:
        params["rope_theta"] = getattr(self, "rope_theta")
    self.rope_parameters = params
    return params
config = RoPEConfig()
rot_emb = Qwen3RotaryEmbedding(config=config)
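
The rope_type switch in this hunk exists because Transformers' shared RoPE initialization table keys on generic scaling strategies rather than model names. A hedged sketch for inspecting it follows; the import path is the one used in recent 4.x releases and may move in 5.x.

from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

# Expected keys are generic strategies such as "default", "linear", "dynamic",
# "yarn", "longrope", and "llama3" -- there is no "qwen3" entry, which is why
# the test config above falls back to rope_type = "default".
print(sorted(ROPE_INIT_FUNCTIONS))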
@@ -495,12 +540,12 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):
# Our encode vs HF template
ours = qt.encode(prompt)
ref = hf_tok.apply_chat_template(
ref = _hf_ids(hf_tok.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=add_gen,
enable_thinking=add_think,
)
))
if add_gen and not add_think:
pass # skip edge case as this is not something we use in practice
@@ -508,7 +553,8 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):
assert ours == ref, (repo_id, add_gen, add_think)
# Round-trip decode equality
assert qt.decode(ours) == hf_tok.decode(ref)
if not (add_gen and not add_think):
assert qt.decode(ours) == hf_tok.decode(ref)
# EOS/PAD parity
assert qt.eos_token_id == hf_tok.eos_token_id
@@ -547,6 +593,7 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
messages, tokenize=True,
add_generation_prompt=add_gen, enable_thinking=add_think
)
ref_ids = _hf_ids(ref_ids)
ref_text = hf_tok.apply_chat_template(
messages, tokenize=False,
add_generation_prompt=add_gen, enable_thinking=add_think
@@ -611,6 +658,7 @@ def test_tokenizer_equivalence():
add_generation_prompt=states[0],
enable_thinking=states[1],
)
input_token_ids_ref = _hf_ids(input_token_ids_ref)
else:
input_token_ids_ref = input_token_ids
@@ -665,6 +713,7 @@ def test_multiturn_prefix_stability(repo_id, tok_file, add_gen, add_think):
running, tokenize=True,
add_generation_prompt=add_gen, enable_thinking=add_think
)
ref_ids = _hf_ids(ref_ids)
ref_text = hf_tok.apply_chat_template(
running, tokenize=False,
add_generation_prompt=add_gen, enable_thinking=add_think

View File

@@ -14,12 +14,10 @@ dependencies = [
"torch>=2.2.2; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version <= '3.12'",
"torch>=2.2.2; sys_platform == 'linux' and python_version <= '3.12'",
"torch>=2.2.2; sys_platform == 'win32' and python_version <= '3.12'",
"tensorflow>=2.16.2; sys_platform == 'darwin' and platform_machine == 'x86_64'",
"tensorflow>=2.18.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
"tensorflow>=2.18.0; sys_platform == 'linux'",
"tensorflow>=2.18.0; sys_platform == 'win32'",
"jupyterlab>=4.0",
"tiktoken>=0.5.1",
"matplotlib>=3.7.1",
@@ -53,7 +51,7 @@ bonus = [
"sentencepiece>=0.1.99",
"thop",
"tokenizers>=0.21.1",
"transformers>=4.33.2",
"transformers>=5.0.0",
"tqdm>=4.65.0",
]
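
With the bonus extra now requiring transformers>=5.0.0 while the test helpers above still tolerate older releases, one hypothetical way to keep both paths testable is a simple version gate; this is illustrative only and not part of this change.

import transformers

# Hypothetical gate, not part of this PR: branch tests on the installed major version.
TRANSFORMERS_MAJOR = int(transformers.__version__.split(".")[0])
RUNS_ON_V5 = TRANSFORMERS_MAJOR >= 5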