Some Gemma 3 improvements (#1000)

* Some Gemma 3 improvements

* Update URL
Sebastian Raschka
2026-04-05 22:05:05 -04:00
committed by GitHub
parent afc6a3da07
commit 8447d70b18
7 changed files with 825 additions and 48 deletions

View File

@@ -1137,20 +1137,54 @@
" def __init__(self, tokenizer_file_path: str):\n",
" tok_file = Path(tokenizer_file_path)\n",
" self._tok = Tokenizer.from_file(str(tok_file))\n",
" # Attempt to identify EOS and padding tokens\n",
" eos_token = \"<end_of_turn>\"\n",
" self.pad_token_id = eos_token\n",
" self.eos_token_id = eos_token\n",
"\n",
" def encode(self, text: str) -> list[int]:\n",
" return self._tok.encode(text).ids\n",
" self.bos_token = \"<bos>\"\n",
" self.eos_token = \"<eos>\"\n",
" self.pad_token = \"<pad>\"\n",
" self.start_of_turn_token = \"<start_of_turn>\"\n",
" self.end_of_turn_token = \"<end_of_turn>\"\n",
"\n",
" def decode(self, ids: list[int]) -> str:\n",
" return self._tok.decode(ids, skip_special_tokens=False)\n",
" self.bos_token_id = self._tok.token_to_id(self.bos_token)\n",
" self.eos_token_id = self._tok.token_to_id(self.eos_token)\n",
" self.pad_token_id = self._tok.token_to_id(self.pad_token)\n",
" self.start_of_turn_token_id = self._tok.token_to_id(self.start_of_turn_token)\n",
" self.end_of_turn_token_id = self._tok.token_to_id(self.end_of_turn_token)\n",
"\n",
" self.add_bos_token = True\n",
" self.add_eos_token = False\n",
" self.clean_up_tokenization_spaces = False\n",
"\n",
" def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:\n",
" return self._tok.encode(text, add_special_tokens=add_special_tokens).ids\n",
"\n",
" def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:\n",
" if isinstance(ids, int):\n",
" ids = [ids]\n",
" return self._tok.decode(ids, skip_special_tokens=skip_special_tokens)\n",
"\n",
" def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):\n",
" text = \"\"\n",
" for message in messages:\n",
" role = message[\"role\"]\n",
" if role == \"assistant\":\n",
" role = \"model\"\n",
" content = message[\"content\"]\n",
" text += f\"{self.start_of_turn_token}{role}\\n{content}{self.end_of_turn_token}\\n\"\n",
"\n",
" if add_generation_prompt:\n",
" text += f\"{self.start_of_turn_token}model\\n\"\n",
"\n",
" if tokenize:\n",
" return self.encode(text)\n",
" return text\n",
"\n",
"\n",
"def apply_chat_template(user_text):\n",
" return f\"<start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n\""
" return tokenizer.apply_chat_template(\n",
" [{\"role\": \"user\", \"content\": user_text}],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" )"
]
},
{
@@ -1205,7 +1239,11 @@
],
"source": [
"prompt = \"Give me a short introduction to large language models.\"\n",
"prompt = apply_chat_template(\"Give me a short introduction to large language models.\")\n",
"prompt = tokenizer.apply_chat_template(\n",
" [{\"role\": \"user\", \"content\": prompt}],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
@@ -1297,7 +1335,7 @@
" model=model,\n",
" token_ids=input_token_ids_tensor,\n",
" max_new_tokens=500,\n",
" eos_token_id=tokenizer.encode(\"<end_of_turn>\")[-1]\n",
" eos_token_id=tokenizer.end_of_turn_token_id\n",
"):\n",
" token_id = token.squeeze(0).tolist()\n",
" print(\n",

View File

@@ -77,9 +77,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.35.0\n",
"tokenizers version: 0.22.1\n",
"torch version: 2.9.0+cu130\n"
"huggingface_hub version: 1.3.4\n",
"tokenizers version: 0.22.2\n",
"torch version: 2.10.0\n"
]
}
],
@@ -627,9 +627,9 @@
{
"data": {
"text/plain": [
"tensor([[[ 0.7500, 0.1011, 0.4863, ..., 0.9414, 0.3984, -0.2285],\n",
" [-0.3398, -0.0564, 0.9023, ..., -0.2480, 0.4551, 0.8203],\n",
" [-0.2695, -0.3242, 0.4121, ..., 0.8672, -0.9688, 0.9844]]],\n",
"tensor([[[ 0.7500, 0.1060, 0.4844, ..., 0.9414, 0.3984, -0.2324],\n",
" [-0.3438, -0.0549, 0.8984, ..., -0.2402, 0.4570, 0.8242],\n",
" [-0.2676, -0.3281, 0.4121, ..., 0.8711, -0.9648, 0.9844]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
@@ -730,20 +730,7 @@
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n",
" Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n",
" Minimum and Maximum cuda capability supported by this version of PyTorch is\n",
" (8.0) - (12.0)\n",
" \n",
" warnings.warn(\n"
]
}
],
"outputs": [],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
@@ -1022,20 +1009,54 @@
" def __init__(self, tokenizer_file_path: str):\n",
" tok_file = Path(tokenizer_file_path)\n",
" self._tok = Tokenizer.from_file(str(tok_file))\n",
" # Attempt to identify EOS and padding tokens\n",
" eos_token = \"<end_of_turn>\"\n",
" self.pad_token_id = eos_token\n",
" self.eos_token_id = eos_token\n",
"\n",
" def encode(self, text: str) -> list[int]:\n",
" return self._tok.encode(text).ids\n",
" self.bos_token = \"<bos>\"\n",
" self.eos_token = \"<eos>\"\n",
" self.pad_token = \"<pad>\"\n",
" self.start_of_turn_token = \"<start_of_turn>\"\n",
" self.end_of_turn_token = \"<end_of_turn>\"\n",
"\n",
" def decode(self, ids: list[int]) -> str:\n",
" return self._tok.decode(ids, skip_special_tokens=False)\n",
" self.bos_token_id = self._tok.token_to_id(self.bos_token)\n",
" self.eos_token_id = self._tok.token_to_id(self.eos_token)\n",
" self.pad_token_id = self._tok.token_to_id(self.pad_token)\n",
" self.start_of_turn_token_id = self._tok.token_to_id(self.start_of_turn_token)\n",
" self.end_of_turn_token_id = self._tok.token_to_id(self.end_of_turn_token)\n",
"\n",
" self.add_bos_token = True\n",
" self.add_eos_token = False\n",
" self.clean_up_tokenization_spaces = False\n",
"\n",
" def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:\n",
" return self._tok.encode(text, add_special_tokens=add_special_tokens).ids\n",
"\n",
" def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:\n",
" if isinstance(ids, int):\n",
" ids = [ids]\n",
" return self._tok.decode(ids, skip_special_tokens=skip_special_tokens)\n",
"\n",
" def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):\n",
" text = \"\"\n",
" for message in messages:\n",
" role = message[\"role\"]\n",
" if role == \"assistant\":\n",
" role = \"model\"\n",
" content = message[\"content\"]\n",
" text += f\"{self.start_of_turn_token}{role}\\n{content}{self.end_of_turn_token}\\n\"\n",
"\n",
" if add_generation_prompt:\n",
" text += f\"{self.start_of_turn_token}model\\n\"\n",
"\n",
" if tokenize:\n",
" return self.encode(text)\n",
" return text\n",
"\n",
"\n",
"def apply_chat_template(user_text):\n",
" return f\"<start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n\""
" return tokenizer.apply_chat_template(\n",
" [{\"role\": \"user\", \"content\": user_text}],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
" )"
]
},
{
@@ -1075,7 +1096,11 @@
],
"source": [
"prompt = \"Give me a short introduction to large language models.\"\n",
"prompt = apply_chat_template(\"Give me a short introduction to large language models.\")\n",
"prompt = tokenizer.apply_chat_template(\n",
" [{\"role\": \"user\", \"content\": prompt}],\n",
" tokenize=False,\n",
" add_generation_prompt=True,\n",
")\n",
"\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
@@ -1107,7 +1132,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 24,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1133,7 +1158,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 25,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1143,10 +1168,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within that data, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n",
"\n",
"\n",
"GPU memory used: 1.04 GB\n"
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n"
]
}
],
@@ -1162,7 +1184,7 @@
" model=model,\n",
" token_ids=input_token_ids_tensor,\n",
" max_new_tokens=500,\n",
" eos_token_id=tokenizer.encode(\"<end_of_turn>\")[-1]\n",
" eos_token_id=tokenizer.end_of_turn_token_id\n",
"):\n",
" token_id = token.squeeze(0).tolist()\n",
" print(\n",

View File

@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "book-header",
"metadata": {},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "title-cell",
"metadata": {},
"source": [
"# Gemma 3 270M With Hugging Face Transformers"
]
},
{
"cell_type": "markdown",
"id": "intro-cell",
"metadata": {},
"source": [
"- This notebook uses the minimal `AutoTokenizer` / `AutoModelForCausalLM` workflow from the Transformers tutorials.\n",
"- It uses the same user prompt as [standalone-gemma3.ipynb](../standalone-gemma3.ipynb): `Give me a short introduction to large language models.`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "install-cell",
"metadata": {},
"outputs": [],
"source": [
"# pip install transformers sentencepiece"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "login-cell",
"metadata": {},
"outputs": [],
"source": [
"# Uncomment and run the following code if you are executing the notebook for the first time\n",
"\n",
"# from huggingface_hub import login\n",
"# login()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "load-cell",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c3b335b4a1da4658b90e1ef960de8b49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading weights: 0%| | 0/236 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_id = \"google/gemma-3-270m-it\"\n",
"prompt = \"Give me a short introduction to large language models.\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"model = AutoModelForCausalLM.from_pretrained(model_id)\n",
"model.generation_config.do_sample = False\n",
"model.generation_config.top_p = None\n",
"model.generation_config.top_k = None\n",
"model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "generate-cell",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n",
"\n"
]
}
],
"source": [
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
"\n",
"inputs = tokenizer.apply_chat_template(\n",
" messages,\n",
" tokenize=True,\n",
" add_generation_prompt=True,\n",
" return_tensors=\"pt\",\n",
")\n",
"\n",
"outputs = model.generate(\n",
" **inputs,\n",
" max_new_tokens=500,\n",
" do_sample=False,\n",
" num_beams=1,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
")\n",
"\n",
"response = tokenizer.decode(\n",
" outputs[0][inputs[\"input_ids\"].shape[-1]:],\n",
" skip_special_tokens=True,\n",
")\n",
"print(response)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
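A compatibility note on the generation cell above: `apply_chat_template(..., return_tensors="pt")` is unpacked with `**inputs`, which assumes a dict-like return value (the default on recent transformers releases). On older versions that return a bare id tensor instead, a small guard along these lines (an assumption, not part of the notebook) keeps the call working:

    import torch
    if torch.is_tensor(inputs):  # older transformers: raw ids instead of a dict
        inputs = {"input_ids": inputs}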

View File

@@ -0,0 +1,232 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import importlib
from pathlib import Path
import torch
from llms_from_scratch.utils import import_definitions_from_notebook
try:
from transformers import Gemma3ForCausalLM, Gemma3TextConfig
except ImportError:
Gemma3ForCausalLM = None
Gemma3TextConfig = None
def tiny_debug_config():
return {
"vocab_size": 257,
"context_length": 8,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"head_dim": 8,
"qk_norm": True,
"n_kv_groups": 2,
"rope_base": 1_000_000.0,
"rope_local_base": 10_000.0,
"sliding_window": 4,
"layer_types": ["full_attention", "full_attention"],
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
def _hf_config_from_dict(cfg):
if Gemma3TextConfig is None:
raise ImportError("transformers is required for the Gemma 3 debugger.")
return Gemma3TextConfig(
vocab_size=cfg["vocab_size"],
max_position_embeddings=cfg["context_length"],
hidden_size=cfg["emb_dim"],
num_attention_heads=cfg["n_heads"],
num_hidden_layers=cfg["n_layers"],
intermediate_size=cfg["hidden_dim"],
head_dim=cfg["head_dim"],
num_key_value_heads=cfg["n_kv_groups"],
rope_theta=cfg["rope_base"],
rope_local_base_freq=cfg["rope_local_base"],
layer_types=cfg["layer_types"],
sliding_window=cfg["sliding_window"],
tie_word_embeddings=False,
attn_implementation="eager",
torch_dtype=cfg.get("dtype", torch.float32),
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
rope_scaling={"rope_type": "default"},
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
)
def load_notebook_defs(nb_name="standalone-gemma3.ipynb"):
nb_dir = Path(__file__).resolve().parents[1]
return import_definitions_from_notebook(nb_dir, nb_name)
def build_gemma3_pair(import_notebook_defs, cfg, hf_checkpoint=None):
if Gemma3ForCausalLM is None:
raise ImportError("transformers is required for the Gemma 3 debugger.")
ours = import_notebook_defs.Gemma3Model(cfg)
if hf_checkpoint:
hf_model = Gemma3ForCausalLM.from_pretrained(
hf_checkpoint,
torch_dtype=cfg.get("dtype", torch.float32),
attn_implementation="eager",
)
else:
hf_cfg = _hf_config_from_dict(cfg)
hf_model = Gemma3ForCausalLM(hf_cfg)
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
import_notebook_defs.load_weights_into_gemma(ours, param_config, hf_model.state_dict())
ours.eval()
hf_model.eval()
return ours, hf_model
def _attach_debug_hooks(model, is_hf):
traces = {}
handles = []
def hook(name, scale=None):
def _record(_, __, output):
if isinstance(output, tuple):
output = output[0]
if scale is not None:
output = output * scale
traces[name] = output.detach().to(torch.float32).cpu()
return _record
if is_hf:
core = model.model
handles.append(core.embed_tokens.register_forward_hook(hook("embedding")))
for idx, layer in enumerate(core.layers):
handles.append(layer.register_forward_hook(hook(f"block_{idx}")))
handles.append(core.norm.register_forward_hook(hook("final_norm")))
handles.append(model.lm_head.register_forward_hook(hook("logits")))
else:
emb_scale = float(getattr(model, "cfg", {}).get("emb_dim", model.tok_emb.embedding_dim) ** 0.5)
handles.append(model.tok_emb.register_forward_hook(hook("embedding", scale=emb_scale)))
blocks = getattr(model, "blocks", None)
if blocks is None:
blocks = getattr(model, "trf_blocks", None)
if blocks is None:
raise AttributeError("Could not locate Gemma 3 blocks on the local model.")
for idx, block in enumerate(blocks):
handles.append(block.register_forward_hook(hook(f"block_{idx}")))
handles.append(model.final_norm.register_forward_hook(hook("final_norm")))
handles.append(model.out_head.register_forward_hook(hook("logits")))
return traces, handles
def _layer_sort_key(name):
if name == "embedding":
return (0, 0)
if name.startswith("block_"):
idx = int(name.split("_")[1])
return (1, idx)
if name == "final_norm":
return (2, 0)
if name == "logits":
return (3, 0)
return (4, name)
def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5):
ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False)
hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True)
try:
with torch.inference_mode():
ours(input_ids)
hf_model(input_ids)
finally:
for h in ours_handles + hf_handles:
h.remove()
layer_names = sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key)
results = []
for name in layer_names:
ours_tensor = ours_traces.get(name)
hf_tensor = hf_traces.get(name)
if ours_tensor is None or hf_tensor is None:
results.append(
{
"name": name,
"status": "missing",
"ours_shape": None if ours_tensor is None else tuple(ours_tensor.shape),
"hf_shape": None if hf_tensor is None else tuple(hf_tensor.shape),
"max_diff": None,
"mean_abs_diff": None,
}
)
continue
if ours_tensor.shape != hf_tensor.shape:
results.append(
{
"name": name,
"status": "shape_mismatch",
"ours_shape": tuple(ours_tensor.shape),
"hf_shape": tuple(hf_tensor.shape),
"max_diff": None,
"mean_abs_diff": None,
}
)
continue
diff = (ours_tensor - hf_tensor).abs()
max_diff = float(diff.max().item())
mean_diff = float(diff.mean().item())
allclose = torch.allclose(ours_tensor, hf_tensor, rtol=rtol, atol=atol)
results.append(
{
"name": name,
"status": "ok" if allclose else "mismatch",
"ours_shape": tuple(ours_tensor.shape),
"hf_shape": tuple(hf_tensor.shape),
"max_diff": max_diff,
"mean_abs_diff": mean_diff,
}
)
return results
def format_report(differences):
lines = []
for diff in sorted(differences, key=lambda d: _layer_sort_key(d["name"])):
if diff["status"] == "ok":
lines.append(f"[OK] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}")
elif diff["status"] == "mismatch":
lines.append(f"[DIFF] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}")
elif diff["status"] == "shape_mismatch":
lines.append(f"[SHAPE] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}")
else:
lines.append(f"[MISSING] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}")
return "\n".join(lines)
if __name__ == "__main__":
transformers_available = importlib.util.find_spec("transformers") is not None
if not transformers_available:
raise SystemExit("transformers is not installed; install it to run the debugger.")
import_notebook_defs = load_notebook_defs()
cfg = tiny_debug_config()
ours_model, hf_model = build_gemma3_pair(import_notebook_defs, cfg)
torch.manual_seed(0)
input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
diffs = layerwise_differences(ours_model, hf_model, input_ids)
print(format_report(diffs))
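A usage sketch for the script above: for lower-precision runs (e.g., bfloat16 weights), the default tolerances are likely too strict; both can be relaxed via the keyword arguments that `layerwise_differences` already exposes:

    diffs = layerwise_differences(ours_model, hf_model, input_ids, rtol=1e-3, atol=1e-3)
    print(format_report(diffs))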

View File

@@ -0,0 +1,322 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import importlib
from pathlib import Path
import torch
from llms_from_scratch.utils import import_definitions_from_notebook
try:
from transformers import Gemma3ForCausalLM, Gemma3TextConfig
except ImportError:
Gemma3ForCausalLM = None
Gemma3TextConfig = None
def tiny_debug_config():
return {
"vocab_size": 257,
"context_length": 8,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"head_dim": 8,
"qk_norm": True,
"n_kv_groups": 2,
"rope_base": 1_000_000.0,
"rope_local_base": 10_000.0,
"sliding_window": 4,
"layer_types": ["full_attention", "full_attention"],
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
def _hf_config_from_dict(cfg):
if Gemma3TextConfig is None:
raise ImportError("transformers is required for the Gemma 3 debugger.")
return Gemma3TextConfig(
vocab_size=cfg["vocab_size"],
max_position_embeddings=cfg["context_length"],
hidden_size=cfg["emb_dim"],
num_attention_heads=cfg["n_heads"],
num_hidden_layers=cfg["n_layers"],
intermediate_size=cfg["hidden_dim"],
head_dim=cfg["head_dim"],
num_key_value_heads=cfg["n_kv_groups"],
rope_theta=cfg["rope_base"],
rope_local_base_freq=cfg["rope_local_base"],
layer_types=cfg["layer_types"],
sliding_window=cfg["sliding_window"],
tie_word_embeddings=False,
attn_implementation="eager",
torch_dtype=cfg.get("dtype", torch.float32),
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
rope_scaling={"rope_type": "default"},
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
)
def load_notebook_defs(nb_name="standalone-gemma3.ipynb"):
nb_dir = Path(__file__).resolve().parents[1]
return import_definitions_from_notebook(nb_dir, nb_name)
def build_gemma3_pair(import_notebook_defs, cfg, hf_checkpoint=None):
if Gemma3ForCausalLM is None:
raise ImportError("transformers is required for the Gemma 3 debugger.")
torch.manual_seed(123)
ours = import_notebook_defs.Gemma3Model(cfg)
if hf_checkpoint:
hf_model = Gemma3ForCausalLM.from_pretrained(
hf_checkpoint,
torch_dtype=cfg.get("dtype", torch.float32),
attn_implementation="eager",
)
else:
hf_cfg = _hf_config_from_dict(cfg)
hf_model = Gemma3ForCausalLM(hf_cfg)
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
import_notebook_defs.load_weights_into_gemma(ours, param_config, hf_model.state_dict())
ours.eval()
hf_model.eval()
return ours, hf_model
def _register_trace_hook(handles, traces, name, module, scale=None):
if module is None:
return
def _record(_, __, output):
if isinstance(output, tuple):
output = output[0]
if scale is not None:
output = output * scale
traces[name] = output.detach().to(torch.float32).cpu()
handles.append(module.register_forward_hook(_record))
def _attach_debug_hooks(model, is_hf, include_block_details=True):
traces = {}
handles = []
if is_hf:
core = model.model
_register_trace_hook(handles, traces, "embedding", core.embed_tokens)
for idx, layer in enumerate(core.layers):
block_name = f"block_{idx}"
_register_trace_hook(handles, traces, block_name, layer)
if include_block_details:
_register_trace_hook(handles, traces, f"{block_name}.input_layernorm", layer.input_layernorm)
_register_trace_hook(handles, traces, f"{block_name}.att.q_proj", layer.self_attn.q_proj)
_register_trace_hook(handles, traces, f"{block_name}.att.k_proj", layer.self_attn.k_proj)
_register_trace_hook(handles, traces, f"{block_name}.att.v_proj", layer.self_attn.v_proj)
_register_trace_hook(handles, traces, f"{block_name}.att.q_norm", layer.self_attn.q_norm)
_register_trace_hook(handles, traces, f"{block_name}.att.k_norm", layer.self_attn.k_norm)
_register_trace_hook(handles, traces, f"{block_name}.att.o_proj", layer.self_attn.o_proj)
_register_trace_hook(handles, traces, f"{block_name}.att", layer.self_attn)
_register_trace_hook(handles, traces, f"{block_name}.post_attention_layernorm", layer.post_attention_layernorm)
_register_trace_hook(handles, traces, f"{block_name}.pre_feedforward_layernorm", layer.pre_feedforward_layernorm)
_register_trace_hook(handles, traces, f"{block_name}.ff.gate_proj", layer.mlp.gate_proj)
_register_trace_hook(handles, traces, f"{block_name}.ff.up_proj", layer.mlp.up_proj)
_register_trace_hook(handles, traces, f"{block_name}.ff.down_proj", layer.mlp.down_proj)
_register_trace_hook(handles, traces, f"{block_name}.ff", layer.mlp)
_register_trace_hook(handles, traces, f"{block_name}.post_feedforward_layernorm", layer.post_feedforward_layernorm)
_register_trace_hook(handles, traces, "final_norm", core.norm)
_register_trace_hook(handles, traces, "logits", model.lm_head)
else:
emb_scale = float(getattr(model, "cfg", {}).get("emb_dim", model.tok_emb.embedding_dim) ** 0.5)
_register_trace_hook(handles, traces, "embedding", model.tok_emb, scale=emb_scale)
blocks = getattr(model, "blocks", None)
if blocks is None:
blocks = getattr(model, "trf_blocks", None)
if blocks is None:
raise AttributeError("Could not locate Gemma 3 blocks on the local model.")
for idx, block in enumerate(blocks):
block_name = f"block_{idx}"
_register_trace_hook(handles, traces, block_name, block)
if include_block_details:
_register_trace_hook(handles, traces, f"{block_name}.input_layernorm", block.input_layernorm)
_register_trace_hook(handles, traces, f"{block_name}.att.q_proj", block.att.W_query)
_register_trace_hook(handles, traces, f"{block_name}.att.k_proj", block.att.W_key)
_register_trace_hook(handles, traces, f"{block_name}.att.v_proj", block.att.W_value)
_register_trace_hook(handles, traces, f"{block_name}.att.q_norm", block.att.q_norm)
_register_trace_hook(handles, traces, f"{block_name}.att.k_norm", block.att.k_norm)
_register_trace_hook(handles, traces, f"{block_name}.att.o_proj", block.att.out_proj)
_register_trace_hook(handles, traces, f"{block_name}.att", block.att)
_register_trace_hook(handles, traces, f"{block_name}.post_attention_layernorm", block.post_attention_layernorm)
_register_trace_hook(handles, traces, f"{block_name}.pre_feedforward_layernorm", block.pre_feedforward_layernorm)
_register_trace_hook(handles, traces, f"{block_name}.ff.gate_proj", block.ff.fc1)
_register_trace_hook(handles, traces, f"{block_name}.ff.up_proj", block.ff.fc2)
_register_trace_hook(handles, traces, f"{block_name}.ff.down_proj", block.ff.fc3)
_register_trace_hook(handles, traces, f"{block_name}.ff", block.ff)
_register_trace_hook(handles, traces, f"{block_name}.post_feedforward_layernorm", block.post_feedforward_layernorm)
_register_trace_hook(handles, traces, "final_norm", model.final_norm)
_register_trace_hook(handles, traces, "logits", model.out_head)
return traces, handles
def _layer_sort_key(name):
block_detail_order = {
"input_layernorm": 0,
"att.q_proj": 1,
"att.k_proj": 2,
"att.v_proj": 3,
"att.q_norm": 4,
"att.k_norm": 5,
"att.o_proj": 6,
"att": 7,
"post_attention_layernorm": 8,
"pre_feedforward_layernorm": 9,
"ff.gate_proj": 10,
"ff.up_proj": 11,
"ff.down_proj": 12,
"ff": 13,
"post_feedforward_layernorm": 14,
}
if name == "embedding":
return (0, 0)
if name.startswith("block_"):
block_name, _, detail = name.partition(".")
idx = int(block_name.split("_")[1])
if not detail:
return (1, idx, -1)
return (2, idx, block_detail_order.get(detail, 100), detail)
if name == "final_norm":
return (3, 0)
if name == "logits":
return (4, 0)
return (5, name)
def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5, include_block_details=True):
ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False, include_block_details=include_block_details)
hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True, include_block_details=include_block_details)
try:
with torch.inference_mode():
ours(input_ids)
hf_model(input_ids)
finally:
for h in ours_handles + hf_handles:
h.remove()
layer_names = sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key)
results = []
for name in layer_names:
ours_tensor = ours_traces.get(name)
hf_tensor = hf_traces.get(name)
if ours_tensor is None or hf_tensor is None:
results.append(
{
"name": name,
"status": "missing",
"ours_shape": None if ours_tensor is None else tuple(ours_tensor.shape),
"hf_shape": None if hf_tensor is None else tuple(hf_tensor.shape),
"max_diff": None,
"mean_abs_diff": None,
}
)
continue
if ours_tensor.shape != hf_tensor.shape:
results.append(
{
"name": name,
"status": "shape_mismatch",
"ours_shape": tuple(ours_tensor.shape),
"hf_shape": tuple(hf_tensor.shape),
"max_diff": None,
"mean_abs_diff": None,
}
)
continue
diff = (ours_tensor - hf_tensor).abs()
max_diff = float(diff.max().item())
mean_diff = float(diff.mean().item())
allclose = torch.allclose(ours_tensor, hf_tensor, rtol=rtol, atol=atol)
results.append(
{
"name": name,
"status": "ok" if allclose else "mismatch",
"ours_shape": tuple(ours_tensor.shape),
"hf_shape": tuple(hf_tensor.shape),
"max_diff": max_diff,
"mean_abs_diff": mean_diff,
}
)
return results
def _format_diff_line(diff, indent=""):
if diff["status"] == "ok":
return f"{indent}[OK] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}"
if diff["status"] == "mismatch":
return f"{indent}[DIFF] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}"
if diff["status"] == "shape_mismatch":
return f"{indent}[SHAPE] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}"
return f"{indent}[MISSING] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}"
def format_report(differences, show_block_details=True, details_for_all_blocks=False):
lines = []
top_level_diffs = [diff for diff in differences if "." not in diff["name"]]
for diff in sorted(top_level_diffs, key=lambda d: _layer_sort_key(d["name"])):
lines.append(_format_diff_line(diff))
if not show_block_details or not diff["name"].startswith("block_"):
continue
detail_prefix = f"{diff['name']}."
detail_diffs = [
other for other in differences
if other["name"].startswith(detail_prefix)
]
if not detail_diffs:
continue
has_detail_mismatch = any(other["status"] != "ok" for other in detail_diffs)
if not details_for_all_blocks and diff["status"] == "ok" and not has_detail_mismatch:
continue
for other in sorted(detail_diffs, key=lambda d: _layer_sort_key(d["name"])):
lines.append(_format_diff_line(other, indent=" "))
return "\n".join(lines)
if __name__ == "__main__":
transformers_available = importlib.util.find_spec("transformers") is not None
if not transformers_available:
raise SystemExit("transformers is not installed; install it to run the debugger.")
import_notebook_defs = load_notebook_defs()
cfg = tiny_debug_config()
ours_model, hf_model = build_gemma3_pair(import_notebook_defs, cfg)
torch.manual_seed(0)
input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
diffs = layerwise_differences(ours_model, hf_model, input_ids)
print(format_report(diffs))
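To force per-submodule rows for every block, rather than only for blocks with a mismatch, the report flags defined above can be combined as follows (a sketch reusing the same objects as the `__main__` section):

    diffs = layerwise_differences(ours_model, hf_model, input_ids, include_block_details=True)
    print(format_report(diffs, show_block_details=True, details_for_all_blocks=True))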

View File

@@ -100,6 +100,7 @@ def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
torch_dtype=torch.float32,
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
rope_scaling={"rope_type": "default"},
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
)
hf_model = Gemma3ForCausalLM(hf_cfg)

View File

@@ -100,6 +100,7 @@ def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
torch_dtype=torch.float32,
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
rope_scaling={"rope_type": "default"},
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
)
hf_model = Gemma3ForCausalLM(hf_cfg)