mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
added copy of test def test_tokenizer_equivalence() from reasoning-from-scratch in test_qwen3.py
This commit is contained in:
@@ -20,7 +20,12 @@ from llms_from_scratch.kv_cache.generate import generate_text_simple as generate
|
||||
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
|
||||
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
|
||||
|
||||
from llms_from_scratch.utils import download_file
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import platform
|
||||
import pytest
|
||||
import torch
|
||||
@@ -534,6 +539,72 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
|
||||
assert ours_dec == ref_dec
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_tokenizer_equivalence():
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
prompt = "Give me a short introduction to large language models."
|
||||
messages = [
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
for apply_chat_template in (True, False):
|
||||
for s in ("-Base", ""):
|
||||
repo_id = f"Qwen/Qwen3-0.6B{s}"
|
||||
tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
|
||||
tokenizer_url = f"https://huggingface.co/Qwen/Qwen3-0.6B{s}/resolve/main/tokenizer.json"
|
||||
download_file(tokenizer_url, out_dir=".")
|
||||
|
||||
old_name = "tokenizer.json"
|
||||
|
||||
if not s:
|
||||
new_name = "tokenizer-reasoning.json"
|
||||
else:
|
||||
new_name = "tokenizer-base.json"
|
||||
|
||||
try:
|
||||
shutil.move(old_name, new_name)
|
||||
except Exception:
|
||||
with tempfile.NamedTemporaryFile(delete=False, dir=".") as tmp_file:
|
||||
shutil.copyfile(old_name, tmp_file.name)
|
||||
os.replace(tmp_file.name, new_name)
|
||||
os.remove(old_name)
|
||||
|
||||
for states in ((True, True), (False, False)):
|
||||
tokenizer = Qwen3Tokenizer(
|
||||
tokenizer_file_path=new_name,
|
||||
repo_id=repo_id,
|
||||
apply_chat_template=apply_chat_template,
|
||||
add_generation_prompt=states[0],
|
||||
add_thinking=states[1]
|
||||
)
|
||||
input_token_ids = tokenizer.encode(prompt)
|
||||
|
||||
if apply_chat_template:
|
||||
input_token_ids_ref = tokenizer_ref.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=states[0],
|
||||
enable_thinking=states[1],
|
||||
)
|
||||
else:
|
||||
input_token_ids_ref = input_token_ids
|
||||
|
||||
assert input_token_ids == input_token_ids_ref, states
|
||||
|
||||
output_text = tokenizer.decode(input_token_ids)
|
||||
out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
|
||||
assert output_text == out_text_ref, states
|
||||
|
||||
assert tokenizer.encode("<|endoftext|>") == [tokenizer._special_to_id["<|endoftext|>"]]
|
||||
assert tokenizer.encode("<|im_end|>") == [tokenizer._special_to_id["<|im_end|>"]]
|
||||
|
||||
expected_eos_token = "<|im_end|>" if "base" not in new_name else "<|endoftext|>"
|
||||
expected_pad_token = "<|endoftext|>"
|
||||
assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
|
||||
assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
@pytest.mark.parametrize("repo_id, tok_file", [
|
||||
("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),
|
||||
|
||||
Reference in New Issue
Block a user