diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py index 78555c0..fccaee9 100644 --- a/pkg/llms_from_scratch/tests/test_qwen3.py +++ b/pkg/llms_from_scratch/tests/test_qwen3.py @@ -20,7 +20,12 @@ from llms_from_scratch.kv_cache.generate import generate_text_simple as generate from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched +from llms_from_scratch.utils import download_file + import importlib +import os +import shutil +import tempfile import platform import pytest import torch @@ -534,6 +539,72 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think): assert ours_dec == ref_dec +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_tokenizer_equivalence(): + from transformers import AutoTokenizer + + prompt = "Give me a short introduction to large language models." 
+    messages = [
+        {"role": "user", "content": prompt},
+    ]
+
+    for apply_chat_template in (True, False):
+        for s in ("-Base", ""):
+            repo_id = f"Qwen/Qwen3-0.6B{s}"
+            tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
+            tokenizer_url = f"https://huggingface.co/Qwen/Qwen3-0.6B{s}/resolve/main/tokenizer.json"
+            download_file(tokenizer_url, out_dir=".")
+
+            old_name = "tokenizer.json"
+
+            if not s:
+                new_name = "tokenizer-reasoning.json"
+            else:
+                new_name = "tokenizer-base.json"
+
+            try:
+                shutil.move(old_name, new_name)
+            except Exception:
+                # Fallback for Windows, where shutil.move raises if the
+                # destination already exists; os.replace overwrites atomically
+                # on all platforms (a NamedTemporaryFile cannot be reopened by
+                # name while open on Windows, so copying into one would fail).
+                os.replace(old_name, new_name)
+
+            for states in ((True, True), (False, False)):
+                tokenizer = Qwen3Tokenizer(
+                    tokenizer_file_path=new_name,
+                    repo_id=repo_id,
+                    apply_chat_template=apply_chat_template,
+                    add_generation_prompt=states[0],
+                    add_thinking=states[1]
+                )
+                input_token_ids = tokenizer.encode(prompt)
+
+                if apply_chat_template:
+                    input_token_ids_ref = tokenizer_ref.apply_chat_template(
+                        messages,
+                        tokenize=True,
+                        add_generation_prompt=states[0],
+                        enable_thinking=states[1],
+                    )
+                else:
+                    input_token_ids_ref = input_token_ids
+
+                assert input_token_ids == input_token_ids_ref, states
+
+                output_text = tokenizer.decode(input_token_ids)
+                out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
+                assert output_text == out_text_ref, states
+
+    assert tokenizer.encode("<|endoftext|>") == [tokenizer._special_to_id["<|endoftext|>"]]
+    assert tokenizer.encode("<|im_end|>") == [tokenizer._special_to_id["<|im_end|>"]]
+
+    expected_eos_token = "<|im_end|>" if "base" not in new_name else "<|endoftext|>"
+    expected_pad_token = "<|endoftext|>"
+    assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
+    assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+
+
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
@pytest.mark.parametrize("repo_id, tok_file", [ ("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),