diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py index 78555c0..901700d 100644 --- a/pkg/llms_from_scratch/tests/test_qwen3.py +++ b/pkg/llms_from_scratch/tests/test_qwen3.py @@ -457,21 +457,19 @@ def test_chat_wrap_and_equivalence(add_gen, add_think): add_thinking=add_think, ) - # Our encode vs HF template - ours = qt.encode(prompt) - ref = hf_tok.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=add_gen, - enable_thinking=add_think, - ) - ours = qt.encode(prompt) - ref = hf_tok.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=add_gen, - enable_thinking=add_think, - ) + # Base models: compare raw encoding (no chat template) + if "Base" in repo_id: + ours = qt.encode(prompt) # should use no chat template + ref = hf_tok.encode(prompt) # raw encoding without chat template + else: + # Instruct models: compare with chat template + ours = qt.encode(prompt) # will use chat template + ref = hf_tok.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=add_gen, + enable_thinking=add_think, + ) if add_gen and not add_think: pass # skip edge case as this is not something we use in practice