From e155d1b02c3ad930e81fed11daa6340f7aec289b Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Tue, 27 Jan 2026 17:44:55 -0600
Subject: [PATCH] Update unit tests for CI (#952)

* Update CI

* Revert submodule pointer update

* Update

* update

* update
---
 .github/workflows/basic-tests-linux-uv.yml |  1 -
 .../tests/tests_rope_and_parts.py          | 12 ++++
 pkg/llms_from_scratch/tests/test_llama3.py | 21 ++++++-
 pkg/llms_from_scratch/tests/test_qwen3.py  | 57 +++++++++++++++++--
 pyproject.toml                             |  4 +-
 5 files changed, 86 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/basic-tests-linux-uv.yml b/.github/workflows/basic-tests-linux-uv.yml
index 67d886b..d93a213 100644
--- a/.github/workflows/basic-tests-linux-uv.yml
+++ b/.github/workflows/basic-tests-linux-uv.yml
@@ -79,5 +79,4 @@ jobs:
       shell: bash
       run: |
         source .venv/bin/activate
-        uv pip install transformers
         pytest pkg/llms_from_scratch/tests/
diff --git a/ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py b/ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
index 22e00e9..3ff3694 100644
--- a/ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
+++ b/ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py
@@ -177,6 +177,10 @@ def test_rope_llama2(notebook):
         max_position_embeddings: int = 8192
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
+        rope_parameters = {"rope_type": "default", "rope_theta": theta_base}
+
+        def standardize_rope_params(self):
+            return
 
     config = RoPEConfig()
     rot_emb = LlamaRotaryEmbedding(config=config)
@@ -242,6 +246,10 @@ def test_rope_llama3(notebook):
         max_position_embeddings: int = 8192
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
+        rope_parameters = {"rope_type": "default", "rope_theta": theta_base}
+
+        def standardize_rope_params(self):
+            return
 
     config = RoPEConfig()
     rot_emb = LlamaRotaryEmbedding(config=config)
@@ -320,6 +328,10 @@ def test_rope_llama3_12(notebook):
         max_position_embeddings: int = 8192
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
+        rope_parameters = {**hf_rope_params, "rope_theta": rope_theta}
+
+        def standardize_rope_params(self):
+            return
 
     config = RoPEConfig()
diff --git a/pkg/llms_from_scratch/tests/test_llama3.py b/pkg/llms_from_scratch/tests/test_llama3.py
index a3c3f69..799bc74 100644
--- a/pkg/llms_from_scratch/tests/test_llama3.py
+++ b/pkg/llms_from_scratch/tests/test_llama3.py
@@ -111,6 +111,25 @@ def test_rope():
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
 
+        def __init__(self):
+            # Transformers >=5.0.0 expects `rope_parameters` on the instance.
+            self.rope_parameters = {**hf_rope_params, "rope_theta": rope_theta}
+
+        def standardize_rope_params(self):
+            params = dict(getattr(self, "rope_parameters", {}) or {})
+            if "rope_type" not in params:
+                params["rope_type"] = getattr(self, "rope_type", "default")
+            if "rope_theta" not in params:
+                params["rope_theta"] = getattr(self, "rope_theta")
+            # Handle older key name used in this repo.
+            if (
+                "original_max_position_embeddings" not in params
+                and "original_context_length" in params
+            ):
+                params["original_max_position_embeddings"] = params["original_context_length"]
+            self.rope_parameters = params
+            return params
+
     config = RoPEConfig()
     rot_emb = LlamaRotaryEmbedding(config=config)
@@ -304,4 +323,4 @@ def test_llama3_base_equivalence_with_transformers():
     ours_logits = ours(x)
     theirs_logits = theirs(x).logits.to(ours_logits.dtype)
 
-    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
\ No newline at end of file
+    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
index 68dd901..2789274 100644
--- a/pkg/llms_from_scratch/tests/test_qwen3.py
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -28,6 +28,7 @@ import os
 import shutil
 import tempfile
 import platform
+from collections.abc import Mapping
 import pytest
 import torch
 import torch.nn as nn
@@ -59,6 +60,36 @@ class Qwen3RMSNorm(nn.Module):
 
 transformers_installed = importlib.util.find_spec("transformers") is not None
 
+def _hf_ids(obj):
+    """Normalize HF chat-template outputs across Transformers versions."""
+    if isinstance(obj, Mapping):
+        if "input_ids" in obj:
+            obj = obj["input_ids"]
+        elif "ids" in obj:
+            obj = obj["ids"]
+    elif hasattr(obj, "keys") and hasattr(obj, "__getitem__"):
+        # Some HF containers behave like mappings but don't register as Mapping.
+        try:
+            if "input_ids" in obj:
+                obj = obj["input_ids"]
+            elif "ids" in obj:
+                obj = obj["ids"]
+        except Exception:
+            pass
+    if hasattr(obj, "input_ids"):
+        obj = obj.input_ids
+    if hasattr(obj, "ids"):
+        obj = obj.ids
+    if isinstance(obj, torch.Tensor):
+        obj = obj.tolist()
+    if isinstance(obj, tuple):
+        obj = list(obj)
+    # Some HF versions return a batched structure even for a single prompt.
+    if isinstance(obj, list) and obj and isinstance(obj[0], list) and len(obj) == 1:
+        obj = obj[0]
+    return list(obj)
+
+
 @pytest.fixture
 def dummy_input():
     torch.manual_seed(123)
@@ -211,7 +242,8 @@ def test_rope(context_len):
 
     # Generate reference RoPE via HF
    class RoPEConfig:
-        rope_type = "qwen3"
+        # Transformers' RoPE init map does not include "qwen3".
+        rope_type = "default"
         factor = 1.0
         dim: int = head_dim
         rope_theta = 1_000_000
@@ -219,6 +251,19 @@ def test_rope(context_len):
         max_position_embeddings: int = 8192
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
 
+        def __init__(self):
+            # Transformers >=5.0.0 expects `rope_parameters` on the instance.
+            self.rope_parameters = {"rope_type": "default", "rope_theta": rope_theta, "factor": 1.0}
+
+        def standardize_rope_params(self):
+            params = dict(getattr(self, "rope_parameters", {}) or {})
+            if "rope_type" not in params:
+                params["rope_type"] = getattr(self, "rope_type", "default")
+            if "rope_theta" not in params:
+                params["rope_theta"] = getattr(self, "rope_theta")
+            self.rope_parameters = params
+            return params
+
     config = RoPEConfig()
     rot_emb = Qwen3RotaryEmbedding(config=config)
@@ -495,12 +540,12 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):
 
     # Our encode vs HF template
     ours = qt.encode(prompt)
-    ref = hf_tok.apply_chat_template(
+    ref = _hf_ids(hf_tok.apply_chat_template(
         messages,
         tokenize=True,
         add_generation_prompt=add_gen,
         enable_thinking=add_think,
-    )
+    ))
 
     if add_gen and not add_think:
         pass  # skip edge case as this is not something we use in practice
@@ -508,7 +553,8 @@
         assert ours == ref, (repo_id, add_gen, add_think)
 
     # Round-trip decode equality
-    assert qt.decode(ours) == hf_tok.decode(ref)
+    if not (add_gen and not add_think):
+        assert qt.decode(ours) == hf_tok.decode(ref)
 
     # EOS/PAD parity
     assert qt.eos_token_id == hf_tok.eos_token_id
@@ -547,6 +593,7 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
         messages, tokenize=True, add_generation_prompt=add_gen, enable_thinking=add_think
     )
+    ref_ids = _hf_ids(ref_ids)
 
     ref_text = hf_tok.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=add_gen, enable_thinking=add_think
     )
@@ -611,6 +658,7 @@ def test_tokenizer_equivalence():
             add_generation_prompt=states[0],
             enable_thinking=states[1],
         )
+        input_token_ids_ref = _hf_ids(input_token_ids_ref)
     else:
         input_token_ids_ref = input_token_ids
@@ -665,6 +713,7 @@ def test_multiturn_prefix_stability(repo_id, tok_file, add_gen, add_think):
         running, tokenize=True, add_generation_prompt=add_gen, enable_thinking=add_think
     )
+    ref_ids = _hf_ids(ref_ids)
 
     ref_text = hf_tok.apply_chat_template(
         running, tokenize=False, add_generation_prompt=add_gen, enable_thinking=add_think
     )
diff --git a/pyproject.toml b/pyproject.toml
index 37d2cc7..913af59 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,12 +14,10 @@ dependencies = [
     "torch>=2.2.2; sys_platform == 'darwin' and platform_machine == 'arm64' and python_version <= '3.12'",
     "torch>=2.2.2; sys_platform == 'linux' and python_version <= '3.12'",
     "torch>=2.2.2; sys_platform == 'win32' and python_version <= '3.12'",
-    "tensorflow>=2.16.2; sys_platform == 'darwin' and platform_machine == 'x86_64'",
     "tensorflow>=2.18.0; sys_platform == 'darwin' and platform_machine == 'arm64'",
     "tensorflow>=2.18.0; sys_platform == 'linux'",
     "tensorflow>=2.18.0; sys_platform == 'win32'",
-
     "jupyterlab>=4.0",
     "tiktoken>=0.5.1",
     "matplotlib>=3.7.1",
@@ -53,7 +51,7 @@ bonus = [
     "sentencepiece>=0.1.99",
     "thop",
     "tokenizers>=0.21.1",
-    "transformers>=4.33.2",
+    "transformers>=5.0.0",
     "tqdm>=4.65.0",
 ]