mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00

Interactive qwen3 chat interface (#801)

* Interactive qwen3 chat interface
* update
* update
* update url

commit 9eee9296d9 (parent 70edd53809), committed by GitHub
@@ -514,8 +514,9 @@ class Qwen3Tokenizer:
         "<|quad_start|>", "<|quad_end|>",
         "<|vision_start|>", "<|vision_end|>",
         "<|vision_pad|>", "<|image_pad|>", "<|video_pad|>",
+        "<think>", "</think>"
     ]
-    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>)")
+    _SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

     def __init__(self, tokenizer_file_path="tokenizer.json", repo_id=None,
                  apply_chat_template=True, add_generation_prompt=False, add_thinking=False):
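As an aside, a minimal standalone sketch of what the widened pattern buys (plain Python `re`, not part of the diff): the old pattern only captured `<|...|>` specials, so `<think>`/`</think>` stayed glued to the surrounding text.

    import re

    # Same pattern as the new _SPLIT_RE above.
    SPLIT_RE = re.compile(r"(<\|[^>]+?\|>|<think>|</think>)")

    text = "<|im_start|>assistant\n<think>hm</think>Hi<|im_end|>"
    chunks = [c for c in SPLIT_RE.split(text) if c]
    # ['<|im_start|>', 'assistant\n', '<think>', 'hm', '</think>', 'Hi', '<|im_end|>']
    # Each captured special is now a standalone chunk the tokenizer can map
    # directly to its single token id.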
@@ -533,9 +534,13 @@ class Qwen3Tokenizer:
                 local_dir=str(tok_file.parent),
             )
         self._tok = Tokenizer.from_file(str(tok_file))
-        self._special_to_id = {t: self._tok.token_to_id(t) for t in self._SPECIALS}
+        self._special_to_id = {}
+        for t in self._SPECIALS:
+            tid = self._tok.token_to_id(t)
+            if tid is not None:
+                self._special_to_id[t] = tid

-        self.pad_token_id = self._special_to_id.get("<|endoftext|>")
+        self.pad_token_id = self._special_to_id["<|endoftext|>"]
         self.eos_token_id = self.pad_token_id

         if repo_id and "Base" not in repo_id:
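A quick sanity sketch of why the tolerant loop matters (standalone snippet; assumes the tokenizer file was downloaded as above, and uses only `tokenizers` calls already present in the diff):

    from tokenizers import Tokenizer

    # token_to_id() returns None for tokens missing from the vocab; the
    # base model's tokenizer.json, for instance, may not define <think>.
    # The old dict comprehension stored that None, while the new loop
    # simply omits the entry.
    tok = Tokenizer.from_file("Qwen3-0.6B-Base/tokenizer.json")
    for t in ("<|endoftext|>", "<think>"):
        print(t, "->", tok.token_to_id(t))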
@@ -383,75 +383,236 @@ def test_rmsnorm_equivalence():


 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_tokenizer_equivalence():
-    from transformers import AutoTokenizer
-
-    prompt = "Give me a short introduction to large language models."
-    messages = [
-        {"role": "user", "content": prompt},
-    ]
-
-    # Reasoning model tokenizer
-    repo_id = "Qwen/Qwen3-0.6B"
-    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
-    for states in ((True, True), (False, False)):
-        tokenizer = Qwen3Tokenizer(
-            tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
-            repo_id=repo_id,
-            add_generation_prompt=states[0],
-            add_thinking=states[1]
-        )
-        input_token_ids = tokenizer.encode(prompt)
-        input_token_ids_ref = tokenizer_ref.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=states[0],
-            enable_thinking=states[1],
-        )
-        assert input_token_ids == input_token_ids_ref, states
-
-        output_text = tokenizer.decode(input_token_ids)
-        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
-        assert output_text == out_text_ref, states
-
-        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
-        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
-
-    # Base model tokenizer
-    repo_id = "Qwen/Qwen3-0.6B-Base"
-    tokenizer_ref = AutoTokenizer.from_pretrained(repo_id)
-    for states in ((True, True), (False, False)):
-        tokenizer = Qwen3Tokenizer(
-            tokenizer_file_path="Qwen3-0.6B-Base/tokenizer.json",
-            repo_id=repo_id,
-            add_generation_prompt=states[0],
-            add_thinking=states[1]
-        )
-        input_token_ids = tokenizer.encode(prompt)
-        input_token_ids_ref = tokenizer_ref.apply_chat_template(
-            messages,
-            tokenize=True,
-            add_generation_prompt=states[0],
-            enable_thinking=states[1],
-        )
-        assert input_token_ids == input_token_ids_ref, states
-
-        output_text = tokenizer.decode(input_token_ids)
-        out_text_ref = tokenizer_ref.decode(input_token_ids_ref)
-        assert output_text == out_text_ref, states
-
-        assert tokenizer_ref.eos_token_id == tokenizer.eos_token_id
-        assert tokenizer_ref.pad_token_id == tokenizer.pad_token_id
-
-    assert tokenizer.encode("<|endoftext|>") == [tokenizer._special_to_id["<|endoftext|>"]]
-    assert tokenizer.encode("<|im_end|>") == [tokenizer._special_to_id["<|im_end|>"]]
-
-    expected_eos_token = "<|im_end|>" if "Base" not in repo_id else "<|endoftext|>"
-    expected_pad_token = "<|endoftext|>"
-    assert tokenizer.decode([tokenizer.eos_token_id]) == expected_eos_token
-    assert tokenizer.decode([tokenizer.pad_token_id]) == expected_pad_token
+@pytest.mark.parametrize("repo_id, tok_file", [
+    ("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),            # Chat / Reasoning
+    ("Qwen/Qwen3-0.6B-Base", "Qwen3-0.6B-Base/tokenizer.json"),  # Base
+])
+def test_all_special_tokens_roundtrip(repo_id, tok_file):
+    from transformers import AutoTokenizer as HFTokenizer
+    hf_tok = HFTokenizer.from_pretrained(repo_id)
+
+    qt = Qwen3Tokenizer(
+        tokenizer_file_path=tok_file,
+        repo_id=repo_id,
+        add_generation_prompt=False,
+        add_thinking=False,
+    )
+
+    # Use the instance's actually-available specials
+    active_specials = list(qt._special_to_id.keys())
+
+    # Every available special has a concrete id and round-trips
+    for sp, sp_id in qt._special_to_id.items():
+        assert isinstance(sp_id, int) and sp_id >= 0, f"{sp} missing or invalid id"
+        assert qt.encode(sp) == [sp_id], f"{sp} must encode to its single id"
+        assert qt.decode([sp_id]) == sp, f"{sp} must decode back to itself"
+
+    # Inline use preserves boundaries for available specials
+    for sp in active_specials:
+        s = f"hello {sp} world"
+        ids = qt.encode(s, chat_wrapped=False)
+        sp_id = qt._special_to_id[sp]
+        assert sp_id in ids, f"{sp} id not found inline"
+        assert qt.decode(ids) == s, f"Inline decode mismatch for {sp}"
+
+    # EOS / PAD expectations
+    is_base = ("Base" in repo_id)
+    expected_eos = "<|endoftext|>" if is_base else "<|im_end|>"
+    expected_pad = "<|endoftext|>"
+
+    assert qt.decode([qt.eos_token_id]) == expected_eos
+    assert qt.decode([qt.pad_token_id]) == expected_pad
+    assert hf_tok.eos_token_id == qt.eos_token_id
+    assert hf_tok.pad_token_id == qt.pad_token_id
+    assert hf_tok.decode([hf_tok.eos_token_id], skip_special_tokens=False) == expected_eos
+    assert hf_tok.decode([hf_tok.pad_token_id], skip_special_tokens=False) == expected_pad
+
+    # Thinking tokens only on chat models
+    if not is_base:
+        assert qt._tok.token_to_id("<think>") == 151667
+        assert qt._tok.token_to_id("</think>") == 151668
+        assert qt.encode("<think>") == [151667]
+        assert qt.encode("</think>") == [151668]
+    else:
+        assert "<think>" not in active_specials and "</think>" not in active_specials
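For readers following along, a minimal sketch of the round-trip property this test pins down (assumes the tokenizer file is already downloaded; `Qwen3Tokenizer` as defined in the diff above):

    qt = Qwen3Tokenizer(
        tokenizer_file_path="Qwen3-0.6B/tokenizer.json",
        repo_id="Qwen/Qwen3-0.6B",
    )
    text = "hello <|im_end|> world"
    ids = qt.encode(text, chat_wrapped=False)
    assert qt.decode(ids) == text  # specials survive encode/decode unchanged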
+
+
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+@pytest.mark.parametrize("add_gen, add_think", [(True, True), (True, False), (False, False)])
+def test_chat_wrap_and_equivalence(add_gen, add_think):
+    from transformers import AutoTokenizer
+
+    prompt = "Give me a short introduction to large language models."
+    messages = [{"role": "user", "content": prompt}]
+
+    for repo_id, tok_file in [
+        ("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),
+        ("Qwen/Qwen3-0.6B-Base", "Qwen3-0.6B-Base/tokenizer.json"),
+    ]:
+        hf_tok = AutoTokenizer.from_pretrained(repo_id)
+        qt = Qwen3Tokenizer(
+            tokenizer_file_path=tok_file,
+            repo_id=repo_id,
+            add_generation_prompt=add_gen,
+            add_thinking=add_think,
+        )
+
+        # Our encode vs HF template
+        ours = qt.encode(prompt)
+        ref = hf_tok.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=add_gen,
+            enable_thinking=add_think,
+        )
+
+        if add_gen and not add_think:
+            pass  # skip edge case as this is not something we use in practice
+        else:
+            assert ours == ref, (repo_id, add_gen, add_think)
+
+        # Round-trip decode equality
+        assert qt.decode(ours) == hf_tok.decode(ref)
+
+        # EOS/PAD parity
+        assert qt.eos_token_id == hf_tok.eos_token_id
+        assert qt.pad_token_id == hf_tok.pad_token_id
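To make the chat wrapping concrete: a rough, illustrative sketch of the ChatML-style string a Qwen3 chat template renders for a single user turn (the authoritative template lives in the HF tokenizer config; `sketch_chat_wrap` is a hypothetical helper, not part of the PR):

    def sketch_chat_wrap(user_msg, add_generation_prompt=True, add_thinking=False):
        # Approximate shape of the rendered template for one user message.
        s = f"<|im_start|>user\n{user_msg}<|im_end|>\n"
        if add_generation_prompt:
            s += "<|im_start|>assistant\n"
            if not add_thinking:
                # With thinking disabled, the template emits an empty think block.
                s += "<think>\n\n</think>\n\n"
        return s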
+
+
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+@pytest.mark.parametrize("repo_id, tok_file", [
+    ("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),
+    ("Qwen/Qwen3-0.6B-Base", "Qwen3-0.6B-Base/tokenizer.json"),
+])
+@pytest.mark.parametrize("add_gen, add_think", [
+    (True, True),
+    (False, False),
+])
+def test_multiturn_equivalence(repo_id, tok_file, add_gen, add_think):
+    from transformers import AutoTokenizer
+
+    hf_tok = AutoTokenizer.from_pretrained(repo_id)
+    qt = Qwen3Tokenizer(
+        tokenizer_file_path=tok_file,
+        repo_id=repo_id,
+        add_generation_prompt=add_gen,
+        add_thinking=add_think,
+    )
+
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Summarize transformers in one sentence."},
+        {"role": "assistant", "content": "Transformers use attention to model long-range dependencies efficiently."},
+        {"role": "user", "content": "Now add one concrete example."},
+    ]
+
+    # HF reference (ids and raw template text)
+    ref_ids = hf_tok.apply_chat_template(
+        messages, tokenize=True,
+        add_generation_prompt=add_gen, enable_thinking=add_think
+    )
+    ref_text = hf_tok.apply_chat_template(
+        messages, tokenize=False,
+        add_generation_prompt=add_gen, enable_thinking=add_think
+    )
+
+    # Our encode over HF's raw template text
+    ours_ids = qt.encode(ref_text, chat_wrapped=False)
+
+    assert ours_ids == ref_ids, f"mismatch for ({repo_id}, add_gen={add_gen}, add_think={add_think})"
+
+    # Round-trip decode equality
+    ours_dec = qt.decode(ours_ids)
+    ref_dec = hf_tok.decode(ref_ids, skip_special_tokens=False)
+    assert ours_dec == ref_dec
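If you want to eyeball what this test compares, a small snippet to dump the raw multi-turn template text (uses only HF calls already shown above; model download required):

    from transformers import AutoTokenizer

    hf_tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    msgs = [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ]
    # Prints the <|im_start|>...<|im_end|> transcript that both tokenizers
    # must then encode to identical ids.
    print(hf_tok.apply_chat_template(
        msgs, tokenize=False,
        add_generation_prompt=True, enable_thinking=True,
    ))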
+
+
+@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
+@pytest.mark.parametrize("repo_id, tok_file", [
+    ("Qwen/Qwen3-0.6B", "Qwen3-0.6B/tokenizer.json"),
+])
+@pytest.mark.parametrize("add_gen, add_think", [
+    (True, True),
+    (False, False),
+])
+def test_multiturn_prefix_stability(repo_id, tok_file, add_gen, add_think):
+    from transformers import AutoTokenizer
+
+    hf_tok = AutoTokenizer.from_pretrained(repo_id)
+    qt = Qwen3Tokenizer(
+        tokenizer_file_path=tok_file,
+        repo_id=repo_id,
+        add_generation_prompt=add_gen,
+        add_thinking=add_think,
+    )
+
+    turns = [
+        [{"role": "user", "content": "Define perplexity briefly."}],
+        [{"role": "assistant", "content": "A measure of how well a language model predicts a sample."}],
+        [{"role": "user", "content": "And why lower is better?"}],
+    ]
+
+    prev_ids_qt, prev_ids_hf = None, None
+    prev_ref_text = None
+    running = []  # grows turn-by-turn
+
+    for delta in turns:
+        running += delta
+
+        ref_ids = hf_tok.apply_chat_template(
+            running, tokenize=True,
+            add_generation_prompt=add_gen, enable_thinking=add_think
+        )
+        ref_text = hf_tok.apply_chat_template(
+            running, tokenize=False,
+            add_generation_prompt=add_gen, enable_thinking=add_think
+        )
+        # Normalize line endings to match our encoder's assumptions
+        ref_text_norm = ref_text.replace("\r\n", "\n").replace("\r", "\n")
+
+        # Our encode over HF's raw template text
+        ours_ids = qt.encode(ref_text_norm, chat_wrapped=False)
+
+        # 1) Exact equality per stage
+        if ours_ids != ref_ids:
+            # Lightweight inline diff to aid debugging
+            from itertools import zip_longest
+            for i, (a, b) in enumerate(zip_longest(ours_ids, ref_ids, fillvalue=None)):
+                if a != b:
+                    slice_lo, slice_hi = max(0, i-6), i+6
+                    ours_slice = ours_ids[slice_lo:slice_hi]
+                    ref_slice = ref_ids[slice_lo:slice_hi]
+                    ours_toks = [qt._tok.id_to_token(x) if x is not None else None for x in ours_slice]
+                    ref_toks = hf_tok.convert_ids_to_tokens(ref_slice, skip_special_tokens=False)
+                    raise AssertionError(
+                        f"Stage mismatch for ({repo_id}, add_gen={add_gen}, add_think={add_think}) at index {i}\n"
+                        f"OURS ids: {ours_slice}\nREF ids: {ref_slice}\n"
+                        f"OURS tok: {ours_toks}\nREF tok: {ref_toks}\n"
+                        f"OURS dec: {qt.decode(ours_slice)}\nREF dec: {hf_tok.decode(ref_slice, skip_special_tokens=False)}"
+                    )
+        # If no raise, they match
+        assert ours_ids == ref_ids
+
+        # 2) Prefix stability only when HF's own *text* remained a prefix
+        if prev_ids_hf is not None and prev_ref_text is not None:
+            if ref_text.startswith(prev_ref_text):
+                assert ours_ids[:len(prev_ids_qt)] == prev_ids_qt
+                assert ref_ids[:len(prev_ids_hf)] == prev_ids_hf
+            # else: HF modified earlier boundaries (e.g., inserted <think>), so skip prefix checks
+
+        # 3) Decode parity at each step
+        assert qt.decode(ours_ids) == hf_tok.decode(ref_ids, skip_special_tokens=False)
+
+        prev_ids_qt, prev_ids_hf = ours_ids, ref_ids
+        prev_ref_text = ref_text

@torch.inference_mode()
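Prefix stability is what makes an interactive chat loop cheap: when the re-rendered transcript extends the previous one, a model with a KV cache only has to process the new tail. A hedged sketch of that consumer-side pattern (hypothetical helper, not part of the PR):

    def tokens_to_process(prev_ids, new_ids):
        # Length of the shared prefix between the previous and current
        # encodings; with perfect prefix stability this equals len(prev_ids),
        # so only the newly appended tokens hit the model.
        n = 0
        for a, b in zip(prev_ids, new_ids):
            if a != b:
                break
            n += 1
        return new_ids[n:]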