added copy of test def test_tokenizer_equivalence() from reasoning-from-scratch in test_qwen3.py

This commit is contained in:
casinca
2025-09-16 11:12:29 +02:00
parent 4ea2fb4a76
commit 16f30a0395

View File

@@ -20,7 +20,12 @@ from llms_from_scratch.kv_cache.generate import generate_text_simple as generate
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
from llms_from_scratch.utils import download_file
import importlib
import os
import shutil
import tempfile
import platform
import pytest
import torch
@@ -534,6 +539,72 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
assert ours_dec == ref_dec
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_tokenizer_equivalence():
from transformers import AutoTokenizer
prompt = "Give me a short introduction to large language models."
messages = [
{"role": "user", "content": prompt},
]
for apply_chat_template in (True, False):
for s in ("-Base", ""):
repo_id = f"Qwen/Qwen3-0.6B{s}"
tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
tokenizer_url = f"https://huggingface.co/Qwen/Qwen3-0.6B{s}/resolve/main/tokenizer.json"
download_file(tokenizer_url, out_dir=".")
old_name = "tokenizer.json"
if not s:
new_name = "tokenizer-reasoning.json"
else:
new_name = "tokenizer-base.json"
try:
shutil.move(old_name, new_name)
except Exception:
with tempfile.NamedTemporaryFile(delete=False, dir=".") as tmp_file:
shutil.copyfile(old_name, tmp_file.name)
os.replace(tmp_file.name, new_name)
os.remove(old_name)
for states in ((True, True), (False, False)):
tokenizer = Qwen3Tokenizer(
tokenizer_file_path=new_name,
repo_id=repo_id,
apply_chat_template=apply_chat_template,
add_generation_prompt=states[0],
add_thinking=states[1]
)
input_token_ids = tokenizer.encode(prompt)
if apply_chat_template:
input_token_ids_ref = tokenizer_ref.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=states[0],
enable_thinking=states[1],
)
else:
input_token_ids_ref = input_token_ids
assert input_token_ids == input_token_ids_ref, states
output_text = tokenizer.decode(input_token_ids)
out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
assert output_text == out_text_ref, states
assert tokenizer.encode("<|endoftext|>") == [tokenizer._special_to_id["<|endoftext|>"]]
assert tokenizer.encode("<|im_end|>") == [tokenizer._special_to_id["<|im_end|>"]]
expected_eos_token = "<|im_end|>" if "base" not in new_name else "<|endoftext|>"
expected_pad_token = "<|endoftext|>"
assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
@pytest.mark.parametrize("repo_id, tok_file", [
("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),