mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Qwen3Tokenizer fix for Qwen3 Base models and generation mismatch with HF (#828)
* prevent `self.apply_chat_template` being applied for base Qwen models
* - added no chat template comparison in `test_chat_wrap_and_equivalence` - removed duplicate comparison
* Revert "- added no chat template comparison in `test_chat_wrap_and_equivalence`" This reverts commit 3a5ee8cfa1.
* Revert "prevent `self.apply_chat_template` being applied for base Qwen models" This reverts commit df504397a8.
* copied `download_file` in `utils` from https://github.com/rasbt/reasoning-from-scratch/blob/main/reasoning_from_scratch/utils.py
* added copy of test `def test_tokenizer_equivalence()` from `reasoning-from-scratch` in `test_qwen3.py`
* removed duplicate code fragment in `test_chat_wrap_and_equivalence`
* use apply_chat_template
* add toggle for instruct model
* Update tokenizer usage
---------
Co-authored-by: rasbt <mail@sebastianraschka.com>
This commit is contained in:
@@ -20,7 +20,12 @@ from llms_from_scratch.kv_cache.generate import generate_text_simple as generate
|
||||
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
|
||||
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
|
||||
|
||||
from llms_from_scratch.utils import download_file
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import platform
|
||||
import pytest
|
||||
import torch
|
||||
@@ -465,13 +470,6 @@ def test_chat_wrap_and_equivalence(add_gen, add_think):
|
||||
add_generation_prompt=add_gen,
|
||||
enable_thinking=add_think,
|
||||
)
|
||||
ours = qt.encode(prompt)
|
||||
ref = hf_tok.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=add_gen,
|
||||
enable_thinking=add_think,
|
||||
)
|
||||
|
||||
if add_gen and not add_think:
|
||||
pass # skip edge case as this is not something we use in practice
|
||||
@@ -534,6 +532,72 @@ def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
|
||||
assert ours_dec == ref_dec
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_tokenizer_equivalence():
    """Compare Qwen3Tokenizer with the Hugging Face tokenizer for Qwen3-0.6B.

    Covers the "-Base" and the non-base checkpoint, with and without the chat
    template, and with generation prompt / thinking both on or both off.
    Also checks special-token encoding and the EOS / PAD token choice.

    NOTE(review): the loop and try/except nesting below was reconstructed from
    a listing without indentation -- confirm against the upstream file.
    """
    from transformers import AutoTokenizer

    prompt = "Give me a short introduction to large language models."
    messages = [
        {"role": "user", "content": prompt},
    ]

    for apply_chat_template in (True, False):
        for s in ("-Base", ""):
            repo_id = f"Qwen/Qwen3-0.6B{s}"
            tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
            tokenizer_url = f"https://huggingface.co/Qwen/Qwen3-0.6B{s}/resolve/main/tokenizer.json"
            # Saved under the URL's file name, i.e. "tokenizer.json" in the
            # current directory for both checkpoints.
            download_file(tokenizer_url, out_dir=".")

            old_name = "tokenizer.json"

            # Give each checkpoint's tokenizer file its own name; the word
            # "base" in the name is also used further down to pick the
            # expected EOS token.
            if not s:
                new_name = "tokenizer-reasoning.json"
            else:
                new_name = "tokenizer-base.json"

            try:
                shutil.move(old_name, new_name)
            except Exception:
                # Fallback when the move fails: copy through a temporary file
                # created in the same directory, swap it into place, then
                # delete the original download.
                with tempfile.NamedTemporaryFile(delete=False, dir=".") as tmp_file:
                    shutil.copyfile(old_name, tmp_file.name)
                    os.replace(tmp_file.name, new_name)
                os.remove(old_name)

            # states = (add_generation_prompt, add_thinking)
            for states in ((True, True), (False, False)):
                tokenizer = Qwen3Tokenizer(
                    tokenizer_file_path=new_name,
                    repo_id=repo_id,
                    apply_chat_template=apply_chat_template,
                    add_generation_prompt=states[0],
                    add_thinking=states[1]
                )
                input_token_ids = tokenizer.encode(prompt)

                if apply_chat_template:
                    input_token_ids_ref = tokenizer_ref.apply_chat_template(
                        messages,
                        tokenize=True,
                        add_generation_prompt=states[0],
                        enable_thinking=states[1],
                    )
                else:
                    # No reference IDs are computed in this mode, so the ID
                    # assertion below is trivially true; only the decode
                    # comparison involves the reference tokenizer here.
                    input_token_ids_ref = input_token_ids

                assert input_token_ids == input_token_ids_ref, states

                output_text = tokenizer.decode(input_token_ids)
                out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
                assert output_text == out_text_ref, states

                # Each special-token string must encode to exactly one ID.
                assert tokenizer.encode("<|endoftext|>") == [tokenizer._special_to_id["<|endoftext|>"]]
                assert tokenizer.encode("<|im_end|>") == [tokenizer._special_to_id["<|im_end|>"]]

                # EOS is expected to be <|endoftext|> for the base tokenizer
                # file and <|im_end|> otherwise; PAD is <|endoftext|> for both.
                expected_eos_token = "<|im_end|>" if "base" not in new_name else "<|endoftext|>"
                expected_pad_token = "<|endoftext|>"
                assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
                assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
@pytest.mark.parametrize("repo_id, tok_file", [
|
||||
("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),
|
||||
|
||||
@@ -9,6 +9,8 @@ import ast
|
||||
import re
|
||||
import types
|
||||
from pathlib import Path
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
|
||||
import nbformat
|
||||
|
||||
@@ -122,3 +124,22 @@ def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extr
|
||||
|
||||
exec(src, mod.__dict__)
|
||||
return mod
|
||||
|
||||
def download_file(url, out_dir="."):
    """Download ``url`` into ``out_dir`` and return the local path.

    The file name is the last component of the URL path. If a file with that
    name already exists in ``out_dir`` it is returned unchanged and nothing is
    fetched.

    Args:
        url: Location to fetch; anything ``urllib.request.urlopen`` accepts.
        out_dir: Target directory; created (including parents) if missing.

    Returns:
        ``pathlib.Path`` of the downloaded (or already present) file.

    Raises:
        ValueError: If the URL path has no file-name component.
        RuntimeError: If fetching or writing the file fails.
    """
    import shutil
    from pathlib import Path

    filename = Path(urllib.parse.urlparse(url).path).name
    if not filename:
        # Without a name, ``out_dir / filename`` would be the directory itself
        # and the "already exists" shortcut below would return it.
        raise ValueError(f"Cannot derive a file name from URL: {url}")

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    dest = out_dir / filename

    if dest.exists():
        return dest

    # Stream into a sibling ".part" file and rename only on success, so an
    # interrupted download never leaves a truncated ``dest`` that a later call
    # would mistake for a complete, cached file.
    partial = dest.with_name(dest.name + ".part")
    try:
        with urllib.request.urlopen(url) as response, open(partial, "wb") as f:
            shutil.copyfileobj(response, f)
        partial.replace(dest)
    except Exception as e:
        # Best-effort cleanup; the download error is the one worth reporting.
        try:
            partial.unlink()
        except OSError:
            pass
        raise RuntimeError(f"Failed to download {url}: {e}") from e

    return dest
|
||||
|
||||
Reference in New Issue
Block a user