diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
index 901700d..78555c0 100644
--- a/pkg/llms_from_scratch/tests/test_qwen3.py
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -457,19 +457,14 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):
         add_thinking=add_think,
     )
 
-    # Base models: compare raw encoding (no chat template)
-    if "Base" in repo_id:
-        ours = qt.encode(prompt)  # should use no chat template
-        ref = hf_tok.encode(prompt)  # raw encoding without chat template
-    else:
-        # Instruct models: compare with chat template
-        ours = qt.encode(prompt)  # will use chat template
-        ref = hf_tok.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=add_gen,
-            enable_thinking=add_think,
-        )
+    # Our encode vs HF template: compare our tokenizer's encode output against
+    # the HuggingFace reference produced via apply_chat_template. (This hunk
+    # previously added the identical ours/ref block twice — copy-paste error;
+    # the duplicate is removed and the new-side hunk count corrected to 14.)
+    ours = qt.encode(prompt)
+    ref = hf_tok.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=add_gen,
+        enable_thinking=add_think,
+    )
 
     if add_gen and not add_think:
         pass  # skip edge case as this is not something we use in practice