mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 04:23:41 +00:00
Some gemma 3 improvements (#1000)
* some gemma 3 improvements * update url
This commit is contained in:
committed by
GitHub
parent
afc6a3da07
commit
8447d70b18
@@ -1137,20 +1137,54 @@
|
||||
" def __init__(self, tokenizer_file_path: str):\n",
|
||||
" tok_file = Path(tokenizer_file_path)\n",
|
||||
" self._tok = Tokenizer.from_file(str(tok_file))\n",
|
||||
" # Attempt to identify EOS and padding tokens\n",
|
||||
" eos_token = \"<end_of_turn>\"\n",
|
||||
" self.pad_token_id = eos_token\n",
|
||||
" self.eos_token_id = eos_token\n",
|
||||
"\n",
|
||||
" def encode(self, text: str) -> list[int]:\n",
|
||||
" return self._tok.encode(text).ids\n",
|
||||
" self.bos_token = \"<bos>\"\n",
|
||||
" self.eos_token = \"<eos>\"\n",
|
||||
" self.pad_token = \"<pad>\"\n",
|
||||
" self.start_of_turn_token = \"<start_of_turn>\"\n",
|
||||
" self.end_of_turn_token = \"<end_of_turn>\"\n",
|
||||
"\n",
|
||||
" def decode(self, ids: list[int]) -> str:\n",
|
||||
" return self._tok.decode(ids, skip_special_tokens=False)\n",
|
||||
" self.bos_token_id = self._tok.token_to_id(self.bos_token)\n",
|
||||
" self.eos_token_id = self._tok.token_to_id(self.eos_token)\n",
|
||||
" self.pad_token_id = self._tok.token_to_id(self.pad_token)\n",
|
||||
" self.start_of_turn_token_id = self._tok.token_to_id(self.start_of_turn_token)\n",
|
||||
" self.end_of_turn_token_id = self._tok.token_to_id(self.end_of_turn_token)\n",
|
||||
"\n",
|
||||
" self.add_bos_token = True\n",
|
||||
" self.add_eos_token = False\n",
|
||||
" self.clean_up_tokenization_spaces = False\n",
|
||||
"\n",
|
||||
" def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:\n",
|
||||
" return self._tok.encode(text, add_special_tokens=add_special_tokens).ids\n",
|
||||
"\n",
|
||||
" def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:\n",
|
||||
" if isinstance(ids, int):\n",
|
||||
" ids = [ids]\n",
|
||||
" return self._tok.decode(ids, skip_special_tokens=skip_special_tokens)\n",
|
||||
"\n",
|
||||
" def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):\n",
|
||||
" text = \"\"\n",
|
||||
" for message in messages:\n",
|
||||
" role = message[\"role\"]\n",
|
||||
" if role == \"assistant\":\n",
|
||||
" role = \"model\"\n",
|
||||
" content = message[\"content\"]\n",
|
||||
" text += f\"{self.start_of_turn_token}{role}\\n{content}{self.end_of_turn_token}\\n\"\n",
|
||||
"\n",
|
||||
" if add_generation_prompt:\n",
|
||||
" text += f\"{self.start_of_turn_token}model\\n\"\n",
|
||||
"\n",
|
||||
" if tokenize:\n",
|
||||
" return self.encode(text)\n",
|
||||
" return text\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def apply_chat_template(user_text):\n",
|
||||
" return f\"<start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n\""
|
||||
" return tokenizer.apply_chat_template(\n",
|
||||
" [{\"role\": \"user\", \"content\": user_text}],\n",
|
||||
" tokenize=False,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1205,7 +1239,11 @@
|
||||
],
|
||||
"source": [
|
||||
"prompt = \"Give me a short introduction to large language models.\"\n",
|
||||
"prompt = apply_chat_template(\"Give me a short introduction to large language models.\")\n",
|
||||
"prompt = tokenizer.apply_chat_template(\n",
|
||||
" [{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" tokenize=False,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"input_token_ids = tokenizer.encode(prompt)\n",
|
||||
@@ -1297,7 +1335,7 @@
|
||||
" model=model,\n",
|
||||
" token_ids=input_token_ids_tensor,\n",
|
||||
" max_new_tokens=500,\n",
|
||||
" eos_token_id=tokenizer.encode(\"<end_of_turn>\")[-1]\n",
|
||||
" eos_token_id=tokenizer.end_of_turn_token_id\n",
|
||||
"):\n",
|
||||
" token_id = token.squeeze(0).tolist()\n",
|
||||
" print(\n",
|
||||
|
||||
@@ -77,9 +77,9 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"huggingface_hub version: 0.35.0\n",
|
||||
"tokenizers version: 0.22.1\n",
|
||||
"torch version: 2.9.0+cu130\n"
|
||||
"huggingface_hub version: 1.3.4\n",
|
||||
"tokenizers version: 0.22.2\n",
|
||||
"torch version: 2.10.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -627,9 +627,9 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[[ 0.7500, 0.1011, 0.4863, ..., 0.9414, 0.3984, -0.2285],\n",
|
||||
" [-0.3398, -0.0564, 0.9023, ..., -0.2480, 0.4551, 0.8203],\n",
|
||||
" [-0.2695, -0.3242, 0.4121, ..., 0.8672, -0.9688, 0.9844]]],\n",
|
||||
"tensor([[[ 0.7500, 0.1060, 0.4844, ..., 0.9414, 0.3984, -0.2324],\n",
|
||||
" [-0.3438, -0.0549, 0.8984, ..., -0.2402, 0.4570, 0.8242],\n",
|
||||
" [-0.2676, -0.3281, 0.4121, ..., 0.8711, -0.9648, 0.9844]]],\n",
|
||||
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
|
||||
]
|
||||
},
|
||||
@@ -730,20 +730,7 @@
|
||||
"metadata": {
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n",
|
||||
" Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n",
|
||||
" Minimum and Maximum cuda capability supported by this version of PyTorch is\n",
|
||||
" (8.0) - (12.0)\n",
|
||||
" \n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if torch.cuda.is_available():\n",
|
||||
" device = torch.device(\"cuda\")\n",
|
||||
@@ -1022,20 +1009,54 @@
|
||||
" def __init__(self, tokenizer_file_path: str):\n",
|
||||
" tok_file = Path(tokenizer_file_path)\n",
|
||||
" self._tok = Tokenizer.from_file(str(tok_file))\n",
|
||||
" # Attempt to identify EOS and padding tokens\n",
|
||||
" eos_token = \"<end_of_turn>\"\n",
|
||||
" self.pad_token_id = eos_token\n",
|
||||
" self.eos_token_id = eos_token\n",
|
||||
"\n",
|
||||
" def encode(self, text: str) -> list[int]:\n",
|
||||
" return self._tok.encode(text).ids\n",
|
||||
" self.bos_token = \"<bos>\"\n",
|
||||
" self.eos_token = \"<eos>\"\n",
|
||||
" self.pad_token = \"<pad>\"\n",
|
||||
" self.start_of_turn_token = \"<start_of_turn>\"\n",
|
||||
" self.end_of_turn_token = \"<end_of_turn>\"\n",
|
||||
"\n",
|
||||
" def decode(self, ids: list[int]) -> str:\n",
|
||||
" return self._tok.decode(ids, skip_special_tokens=False)\n",
|
||||
" self.bos_token_id = self._tok.token_to_id(self.bos_token)\n",
|
||||
" self.eos_token_id = self._tok.token_to_id(self.eos_token)\n",
|
||||
" self.pad_token_id = self._tok.token_to_id(self.pad_token)\n",
|
||||
" self.start_of_turn_token_id = self._tok.token_to_id(self.start_of_turn_token)\n",
|
||||
" self.end_of_turn_token_id = self._tok.token_to_id(self.end_of_turn_token)\n",
|
||||
"\n",
|
||||
" self.add_bos_token = True\n",
|
||||
" self.add_eos_token = False\n",
|
||||
" self.clean_up_tokenization_spaces = False\n",
|
||||
"\n",
|
||||
" def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:\n",
|
||||
" return self._tok.encode(text, add_special_tokens=add_special_tokens).ids\n",
|
||||
"\n",
|
||||
" def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str:\n",
|
||||
" if isinstance(ids, int):\n",
|
||||
" ids = [ids]\n",
|
||||
" return self._tok.decode(ids, skip_special_tokens=skip_special_tokens)\n",
|
||||
"\n",
|
||||
" def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False):\n",
|
||||
" text = \"\"\n",
|
||||
" for message in messages:\n",
|
||||
" role = message[\"role\"]\n",
|
||||
" if role == \"assistant\":\n",
|
||||
" role = \"model\"\n",
|
||||
" content = message[\"content\"]\n",
|
||||
" text += f\"{self.start_of_turn_token}{role}\\n{content}{self.end_of_turn_token}\\n\"\n",
|
||||
"\n",
|
||||
" if add_generation_prompt:\n",
|
||||
" text += f\"{self.start_of_turn_token}model\\n\"\n",
|
||||
"\n",
|
||||
" if tokenize:\n",
|
||||
" return self.encode(text)\n",
|
||||
" return text\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def apply_chat_template(user_text):\n",
|
||||
" return f\"<start_of_turn>user\\n{user_text}<end_of_turn>\\n<start_of_turn>model\\n\""
|
||||
" return tokenizer.apply_chat_template(\n",
|
||||
" [{\"role\": \"user\", \"content\": user_text}],\n",
|
||||
" tokenize=False,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1075,7 +1096,11 @@
|
||||
],
|
||||
"source": [
|
||||
"prompt = \"Give me a short introduction to large language models.\"\n",
|
||||
"prompt = apply_chat_template(\"Give me a short introduction to large language models.\")\n",
|
||||
"prompt = tokenizer.apply_chat_template(\n",
|
||||
" [{\"role\": \"user\", \"content\": prompt}],\n",
|
||||
" tokenize=False,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"input_token_ids = tokenizer.encode(prompt)\n",
|
||||
@@ -1107,7 +1132,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 24,
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||
"metadata": {
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
|
||||
@@ -1133,7 +1158,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 25,
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
||||
"metadata": {
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
|
||||
@@ -1143,10 +1168,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within that data, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"GPU memory used: 1.04 GB\n"
|
||||
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -1162,7 +1184,7 @@
|
||||
" model=model,\n",
|
||||
" token_ids=input_token_ids_tensor,\n",
|
||||
" max_new_tokens=500,\n",
|
||||
" eos_token_id=tokenizer.encode(\"<end_of_turn>\")[-1]\n",
|
||||
" eos_token_id=tokenizer.end_of_turn_token_id\n",
|
||||
"):\n",
|
||||
" token_id = token.squeeze(0).tolist()\n",
|
||||
" print(\n",
|
||||
|
||||
161
ch05/12_gemma3/tests/gemma3-transformers-ref.ipynb
Normal file
161
ch05/12_gemma3/tests/gemma3-transformers-ref.ipynb
Normal file
@@ -0,0 +1,161 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "book-header",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<table style=\"width:100%\">\n",
|
||||
"<tr>\n",
|
||||
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
||||
"<font size=\"2\">\n",
|
||||
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
|
||||
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
|
||||
"</font>\n",
|
||||
"</td>\n",
|
||||
"<td style=\"vertical-align:middle; text-align:left;\">\n",
|
||||
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
|
||||
"</td>\n",
|
||||
"</tr>\n",
|
||||
"</table>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "title-cell",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Gemma 3 270M With Hugging Face Transformers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "intro-cell",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- This notebook uses the minimal `AutoTokenizer` / `AutoModelForCausalLM` workflow from the Transformers tutorials.\n",
|
||||
"- It uses the same user prompt as [standalone-gemma3.ipynb](../standalone-gemma3.ipynb): `Give me a short introduction to large language models.`"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "install-cell",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# pip install transformers sentencepiece"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "login-cell",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Uncomment and run the following code if you are executing the notebook for the first time\n",
|
||||
"\n",
|
||||
"# from huggingface_hub import login\n",
|
||||
"# login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "load-cell",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c3b335b4a1da4658b90e1ef960de8b49",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/236 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
|
||||
"\n",
|
||||
"model_id = \"google/gemma-3-270m-it\"\n",
|
||||
"prompt = \"Give me a short introduction to large language models.\"\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
|
||||
"model = AutoModelForCausalLM.from_pretrained(model_id)\n",
|
||||
"model.generation_config.do_sample = False\n",
|
||||
"model.generation_config.top_p = None\n",
|
||||
"model.generation_config.top_k = None\n",
|
||||
"model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
|
||||
"model.eval();"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "generate-cell",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"messages = [{\"role\": \"user\", \"content\": prompt}]\n",
|
||||
"\n",
|
||||
"inputs = tokenizer.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_tensors=\"pt\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"outputs = model.generate(\n",
|
||||
" **inputs,\n",
|
||||
" max_new_tokens=500,\n",
|
||||
" do_sample=False,\n",
|
||||
" num_beams=1,\n",
|
||||
" pad_token_id=tokenizer.pad_token_id,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = tokenizer.decode(\n",
|
||||
" outputs[0][inputs[\"input_ids\"].shape[-1]:],\n",
|
||||
" skip_special_tokens=True,\n",
|
||||
")\n",
|
||||
"print(response)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
232
ch05/12_gemma3/tests/gemma3_layer_debugger.py
Normal file
232
ch05/12_gemma3/tests/gemma3_layer_debugger.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from llms_from_scratch.utils import import_definitions_from_notebook
|
||||
|
||||
try:
|
||||
from transformers import Gemma3ForCausalLM, Gemma3TextConfig
|
||||
except ImportError:
|
||||
Gemma3ForCausalLM = None
|
||||
Gemma3TextConfig = None
|
||||
|
||||
|
||||
def tiny_debug_config():
|
||||
return {
|
||||
"vocab_size": 257,
|
||||
"context_length": 8,
|
||||
"emb_dim": 32,
|
||||
"n_heads": 4,
|
||||
"n_layers": 2,
|
||||
"hidden_dim": 64,
|
||||
"head_dim": 8,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 2,
|
||||
"rope_base": 1_000_000.0,
|
||||
"rope_local_base": 10_000.0,
|
||||
"sliding_window": 4,
|
||||
"layer_types": ["full_attention", "full_attention"],
|
||||
"dtype": torch.float32,
|
||||
"query_pre_attn_scalar": 256,
|
||||
}
|
||||
|
||||
|
||||
def _hf_config_from_dict(cfg):
|
||||
if Gemma3TextConfig is None:
|
||||
raise ImportError("transformers is required for the Gemma 3 debugger.")
|
||||
|
||||
return Gemma3TextConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
max_position_embeddings=cfg["context_length"],
|
||||
hidden_size=cfg["emb_dim"],
|
||||
num_attention_heads=cfg["n_heads"],
|
||||
num_hidden_layers=cfg["n_layers"],
|
||||
intermediate_size=cfg["hidden_dim"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_key_value_heads=cfg["n_kv_groups"],
|
||||
rope_theta=cfg["rope_base"],
|
||||
rope_local_base_freq=cfg["rope_local_base"],
|
||||
layer_types=cfg["layer_types"],
|
||||
sliding_window=cfg["sliding_window"],
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=cfg.get("dtype", torch.float32),
|
||||
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
|
||||
rope_scaling={"rope_type": "default"},
|
||||
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
|
||||
)
|
||||
|
||||
|
||||
def load_notebook_defs(nb_name="standalone-gemma3.ipynb"):
|
||||
nb_dir = Path(__file__).resolve().parents[1]
|
||||
return import_definitions_from_notebook(nb_dir, nb_name)
|
||||
|
||||
|
||||
def build_gemma3_pair(import_notebook_defs, cfg, hf_checkpoint=None):
|
||||
if Gemma3ForCausalLM is None:
|
||||
raise ImportError("transformers is required for the Gemma 3 debugger.")
|
||||
|
||||
ours = import_notebook_defs.Gemma3Model(cfg)
|
||||
|
||||
if hf_checkpoint:
|
||||
hf_model = Gemma3ForCausalLM.from_pretrained(
|
||||
hf_checkpoint,
|
||||
torch_dtype=cfg.get("dtype", torch.float32),
|
||||
attn_implementation="eager",
|
||||
)
|
||||
else:
|
||||
hf_cfg = _hf_config_from_dict(cfg)
|
||||
hf_model = Gemma3ForCausalLM(hf_cfg)
|
||||
|
||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||
import_notebook_defs.load_weights_into_gemma(ours, param_config, hf_model.state_dict())
|
||||
|
||||
ours.eval()
|
||||
hf_model.eval()
|
||||
return ours, hf_model
|
||||
|
||||
|
||||
def _attach_debug_hooks(model, is_hf):
|
||||
traces = {}
|
||||
handles = []
|
||||
|
||||
def hook(name, scale=None):
|
||||
def _record(_, __, output):
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
if scale is not None:
|
||||
output = output * scale
|
||||
traces[name] = output.detach().to(torch.float32).cpu()
|
||||
|
||||
return _record
|
||||
|
||||
if is_hf:
|
||||
core = model.model
|
||||
handles.append(core.embed_tokens.register_forward_hook(hook("embedding")))
|
||||
for idx, layer in enumerate(core.layers):
|
||||
handles.append(layer.register_forward_hook(hook(f"block_{idx}")))
|
||||
handles.append(core.norm.register_forward_hook(hook("final_norm")))
|
||||
handles.append(model.lm_head.register_forward_hook(hook("logits")))
|
||||
else:
|
||||
emb_scale = float(getattr(model, "cfg", {}).get("emb_dim", model.tok_emb.embedding_dim) ** 0.5)
|
||||
handles.append(model.tok_emb.register_forward_hook(hook("embedding", scale=emb_scale)))
|
||||
blocks = getattr(model, "blocks", None)
|
||||
if blocks is None:
|
||||
blocks = getattr(model, "trf_blocks", None)
|
||||
if blocks is None:
|
||||
raise AttributeError("Could not locate Gemma 3 blocks on the local model.")
|
||||
for idx, block in enumerate(blocks):
|
||||
handles.append(block.register_forward_hook(hook(f"block_{idx}")))
|
||||
handles.append(model.final_norm.register_forward_hook(hook("final_norm")))
|
||||
handles.append(model.out_head.register_forward_hook(hook("logits")))
|
||||
|
||||
return traces, handles
|
||||
|
||||
|
||||
def _layer_sort_key(name):
|
||||
if name == "embedding":
|
||||
return (0, 0)
|
||||
if name.startswith("block_"):
|
||||
idx = int(name.split("_")[1])
|
||||
return (1, idx)
|
||||
if name == "final_norm":
|
||||
return (2, 0)
|
||||
if name == "logits":
|
||||
return (3, 0)
|
||||
return (4, name)
|
||||
|
||||
|
||||
def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5):
|
||||
ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False)
|
||||
hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True)
|
||||
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
ours(input_ids)
|
||||
hf_model(input_ids)
|
||||
finally:
|
||||
for h in ours_handles + hf_handles:
|
||||
h.remove()
|
||||
|
||||
layer_names = sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key)
|
||||
results = []
|
||||
for name in layer_names:
|
||||
ours_tensor = ours_traces.get(name)
|
||||
hf_tensor = hf_traces.get(name)
|
||||
|
||||
if ours_tensor is None or hf_tensor is None:
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"status": "missing",
|
||||
"ours_shape": None if ours_tensor is None else tuple(ours_tensor.shape),
|
||||
"hf_shape": None if hf_tensor is None else tuple(hf_tensor.shape),
|
||||
"max_diff": None,
|
||||
"mean_abs_diff": None,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if ours_tensor.shape != hf_tensor.shape:
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"status": "shape_mismatch",
|
||||
"ours_shape": tuple(ours_tensor.shape),
|
||||
"hf_shape": tuple(hf_tensor.shape),
|
||||
"max_diff": None,
|
||||
"mean_abs_diff": None,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
diff = (ours_tensor - hf_tensor).abs()
|
||||
max_diff = float(diff.max().item())
|
||||
mean_diff = float(diff.mean().item())
|
||||
allclose = torch.allclose(ours_tensor, hf_tensor, rtol=rtol, atol=atol)
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"status": "ok" if allclose else "mismatch",
|
||||
"ours_shape": tuple(ours_tensor.shape),
|
||||
"hf_shape": tuple(hf_tensor.shape),
|
||||
"max_diff": max_diff,
|
||||
"mean_abs_diff": mean_diff,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def format_report(differences):
|
||||
lines = []
|
||||
for diff in sorted(differences, key=lambda d: _layer_sort_key(d["name"])):
|
||||
if diff["status"] == "ok":
|
||||
lines.append(f"[OK] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}")
|
||||
elif diff["status"] == "mismatch":
|
||||
lines.append(f"[DIFF] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}")
|
||||
elif diff["status"] == "shape_mismatch":
|
||||
lines.append(f"[SHAPE] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}")
|
||||
else:
|
||||
lines.append(f"[MISSING] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
if not transformers_available:
|
||||
raise SystemExit("transformers is not installed; install it to run the debugger.")
|
||||
|
||||
import_notebook_defs = load_notebook_defs()
|
||||
cfg = tiny_debug_config()
|
||||
|
||||
ours_model, hf_model = build_gemma3_pair(import_notebook_defs, cfg)
|
||||
torch.manual_seed(0)
|
||||
input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
|
||||
diffs = layerwise_differences(ours_model, hf_model, input_ids)
|
||||
print(format_report(diffs))
|
||||
322
ch05/12_gemma3/tests/gemma3_layer_debugger_detailed.py
Normal file
322
ch05/12_gemma3/tests/gemma3_layer_debugger_detailed.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from llms_from_scratch.utils import import_definitions_from_notebook
|
||||
|
||||
try:
|
||||
from transformers import Gemma3ForCausalLM, Gemma3TextConfig
|
||||
except ImportError:
|
||||
Gemma3ForCausalLM = None
|
||||
Gemma3TextConfig = None
|
||||
|
||||
|
||||
def tiny_debug_config():
|
||||
return {
|
||||
"vocab_size": 257,
|
||||
"context_length": 8,
|
||||
"emb_dim": 32,
|
||||
"n_heads": 4,
|
||||
"n_layers": 2,
|
||||
"hidden_dim": 64,
|
||||
"head_dim": 8,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 2,
|
||||
"rope_base": 1_000_000.0,
|
||||
"rope_local_base": 10_000.0,
|
||||
"sliding_window": 4,
|
||||
"layer_types": ["full_attention", "full_attention"],
|
||||
"dtype": torch.float32,
|
||||
"query_pre_attn_scalar": 256,
|
||||
}
|
||||
|
||||
|
||||
def _hf_config_from_dict(cfg):
|
||||
if Gemma3TextConfig is None:
|
||||
raise ImportError("transformers is required for the Gemma 3 debugger.")
|
||||
|
||||
return Gemma3TextConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
max_position_embeddings=cfg["context_length"],
|
||||
hidden_size=cfg["emb_dim"],
|
||||
num_attention_heads=cfg["n_heads"],
|
||||
num_hidden_layers=cfg["n_layers"],
|
||||
intermediate_size=cfg["hidden_dim"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_key_value_heads=cfg["n_kv_groups"],
|
||||
rope_theta=cfg["rope_base"],
|
||||
rope_local_base_freq=cfg["rope_local_base"],
|
||||
layer_types=cfg["layer_types"],
|
||||
sliding_window=cfg["sliding_window"],
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=cfg.get("dtype", torch.float32),
|
||||
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
|
||||
rope_scaling={"rope_type": "default"},
|
||||
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
|
||||
)
|
||||
|
||||
|
||||
def load_notebook_defs(nb_name="standalone-gemma3.ipynb"):
|
||||
nb_dir = Path(__file__).resolve().parents[1]
|
||||
return import_definitions_from_notebook(nb_dir, nb_name)
|
||||
|
||||
|
||||
def build_gemma3_pair(import_notebook_defs, cfg, hf_checkpoint=None):
|
||||
if Gemma3ForCausalLM is None:
|
||||
raise ImportError("transformers is required for the Gemma 3 debugger.")
|
||||
|
||||
torch.manual_seed(123)
|
||||
ours = import_notebook_defs.Gemma3Model(cfg)
|
||||
|
||||
if hf_checkpoint:
|
||||
hf_model = Gemma3ForCausalLM.from_pretrained(
|
||||
hf_checkpoint,
|
||||
torch_dtype=cfg.get("dtype", torch.float32),
|
||||
attn_implementation="eager",
|
||||
)
|
||||
else:
|
||||
hf_cfg = _hf_config_from_dict(cfg)
|
||||
hf_model = Gemma3ForCausalLM(hf_cfg)
|
||||
|
||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||
import_notebook_defs.load_weights_into_gemma(ours, param_config, hf_model.state_dict())
|
||||
|
||||
ours.eval()
|
||||
hf_model.eval()
|
||||
return ours, hf_model
|
||||
|
||||
|
||||
def _register_trace_hook(handles, traces, name, module, scale=None):
|
||||
if module is None:
|
||||
return
|
||||
|
||||
def _record(_, __, output):
|
||||
if isinstance(output, tuple):
|
||||
output = output[0]
|
||||
if scale is not None:
|
||||
output = output * scale
|
||||
traces[name] = output.detach().to(torch.float32).cpu()
|
||||
|
||||
handles.append(module.register_forward_hook(_record))
|
||||
|
||||
|
||||
def _attach_debug_hooks(model, is_hf, include_block_details=True):
|
||||
traces = {}
|
||||
handles = []
|
||||
|
||||
if is_hf:
|
||||
core = model.model
|
||||
_register_trace_hook(handles, traces, "embedding", core.embed_tokens)
|
||||
for idx, layer in enumerate(core.layers):
|
||||
block_name = f"block_{idx}"
|
||||
_register_trace_hook(handles, traces, block_name, layer)
|
||||
|
||||
if include_block_details:
|
||||
_register_trace_hook(handles, traces, f"{block_name}.input_layernorm", layer.input_layernorm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.q_proj", layer.self_attn.q_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.k_proj", layer.self_attn.k_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.v_proj", layer.self_attn.v_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.q_norm", layer.self_attn.q_norm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.k_norm", layer.self_attn.k_norm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.o_proj", layer.self_attn.o_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att", layer.self_attn)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.post_attention_layernorm", layer.post_attention_layernorm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.pre_feedforward_layernorm", layer.pre_feedforward_layernorm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff.gate_proj", layer.mlp.gate_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff.up_proj", layer.mlp.up_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff.down_proj", layer.mlp.down_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff", layer.mlp)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.post_feedforward_layernorm", layer.post_feedforward_layernorm)
|
||||
|
||||
_register_trace_hook(handles, traces, "final_norm", core.norm)
|
||||
_register_trace_hook(handles, traces, "logits", model.lm_head)
|
||||
else:
|
||||
emb_scale = float(getattr(model, "cfg", {}).get("emb_dim", model.tok_emb.embedding_dim) ** 0.5)
|
||||
_register_trace_hook(handles, traces, "embedding", model.tok_emb, scale=emb_scale)
|
||||
blocks = getattr(model, "blocks", None)
|
||||
if blocks is None:
|
||||
blocks = getattr(model, "trf_blocks", None)
|
||||
if blocks is None:
|
||||
raise AttributeError("Could not locate Gemma 3 blocks on the local model.")
|
||||
for idx, block in enumerate(blocks):
|
||||
block_name = f"block_{idx}"
|
||||
_register_trace_hook(handles, traces, block_name, block)
|
||||
|
||||
if include_block_details:
|
||||
_register_trace_hook(handles, traces, f"{block_name}.input_layernorm", block.input_layernorm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.q_proj", block.att.W_query)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.k_proj", block.att.W_key)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.v_proj", block.att.W_value)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.q_norm", block.att.q_norm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.k_norm", block.att.k_norm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att.o_proj", block.att.out_proj)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.att", block.att)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.post_attention_layernorm", block.post_attention_layernorm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.pre_feedforward_layernorm", block.pre_feedforward_layernorm)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff.gate_proj", block.ff.fc1)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff.up_proj", block.ff.fc2)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff.down_proj", block.ff.fc3)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.ff", block.ff)
|
||||
_register_trace_hook(handles, traces, f"{block_name}.post_feedforward_layernorm", block.post_feedforward_layernorm)
|
||||
|
||||
_register_trace_hook(handles, traces, "final_norm", model.final_norm)
|
||||
_register_trace_hook(handles, traces, "logits", model.out_head)
|
||||
|
||||
return traces, handles
|
||||
|
||||
|
||||
def _layer_sort_key(name):
|
||||
block_detail_order = {
|
||||
"input_layernorm": 0,
|
||||
"att.q_proj": 1,
|
||||
"att.k_proj": 2,
|
||||
"att.v_proj": 3,
|
||||
"att.q_norm": 4,
|
||||
"att.k_norm": 5,
|
||||
"att.o_proj": 6,
|
||||
"att": 7,
|
||||
"post_attention_layernorm": 8,
|
||||
"pre_feedforward_layernorm": 9,
|
||||
"ff.gate_proj": 10,
|
||||
"ff.up_proj": 11,
|
||||
"ff.down_proj": 12,
|
||||
"ff": 13,
|
||||
"post_feedforward_layernorm": 14,
|
||||
}
|
||||
|
||||
if name == "embedding":
|
||||
return (0, 0)
|
||||
if name.startswith("block_"):
|
||||
block_name, _, detail = name.partition(".")
|
||||
idx = int(block_name.split("_")[1])
|
||||
if not detail:
|
||||
return (1, idx, -1)
|
||||
return (2, idx, block_detail_order.get(detail, 100), detail)
|
||||
if name == "final_norm":
|
||||
return (3, 0)
|
||||
if name == "logits":
|
||||
return (4, 0)
|
||||
return (5, name)
|
||||
|
||||
|
||||
def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5, include_block_details=True):
|
||||
ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False, include_block_details=include_block_details)
|
||||
hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True, include_block_details=include_block_details)
|
||||
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
ours(input_ids)
|
||||
hf_model(input_ids)
|
||||
finally:
|
||||
for h in ours_handles + hf_handles:
|
||||
h.remove()
|
||||
|
||||
layer_names = sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key)
|
||||
results = []
|
||||
for name in layer_names:
|
||||
ours_tensor = ours_traces.get(name)
|
||||
hf_tensor = hf_traces.get(name)
|
||||
|
||||
if ours_tensor is None or hf_tensor is None:
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"status": "missing",
|
||||
"ours_shape": None if ours_tensor is None else tuple(ours_tensor.shape),
|
||||
"hf_shape": None if hf_tensor is None else tuple(hf_tensor.shape),
|
||||
"max_diff": None,
|
||||
"mean_abs_diff": None,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
if ours_tensor.shape != hf_tensor.shape:
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"status": "shape_mismatch",
|
||||
"ours_shape": tuple(ours_tensor.shape),
|
||||
"hf_shape": tuple(hf_tensor.shape),
|
||||
"max_diff": None,
|
||||
"mean_abs_diff": None,
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
diff = (ours_tensor - hf_tensor).abs()
|
||||
max_diff = float(diff.max().item())
|
||||
mean_diff = float(diff.mean().item())
|
||||
allclose = torch.allclose(ours_tensor, hf_tensor, rtol=rtol, atol=atol)
|
||||
results.append(
|
||||
{
|
||||
"name": name,
|
||||
"status": "ok" if allclose else "mismatch",
|
||||
"ours_shape": tuple(ours_tensor.shape),
|
||||
"hf_shape": tuple(hf_tensor.shape),
|
||||
"max_diff": max_diff,
|
||||
"mean_abs_diff": mean_diff,
|
||||
}
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def _format_diff_line(diff, indent=""):
|
||||
if diff["status"] == "ok":
|
||||
return f"{indent}[OK] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}"
|
||||
if diff["status"] == "mismatch":
|
||||
return f"{indent}[DIFF] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}"
|
||||
if diff["status"] == "shape_mismatch":
|
||||
return f"{indent}[SHAPE] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}"
|
||||
return f"{indent}[MISSING] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}"
|
||||
|
||||
|
||||
def format_report(differences, show_block_details=True, details_for_all_blocks=False):
|
||||
lines = []
|
||||
top_level_diffs = [diff for diff in differences if "." not in diff["name"]]
|
||||
|
||||
for diff in sorted(top_level_diffs, key=lambda d: _layer_sort_key(d["name"])):
|
||||
lines.append(_format_diff_line(diff))
|
||||
|
||||
if not show_block_details or not diff["name"].startswith("block_"):
|
||||
continue
|
||||
|
||||
detail_prefix = f"{diff['name']}."
|
||||
detail_diffs = [
|
||||
other for other in differences
|
||||
if other["name"].startswith(detail_prefix)
|
||||
]
|
||||
if not detail_diffs:
|
||||
continue
|
||||
|
||||
has_detail_mismatch = any(other["status"] != "ok" for other in detail_diffs)
|
||||
if not details_for_all_blocks and diff["status"] == "ok" and not has_detail_mismatch:
|
||||
continue
|
||||
if not details_for_all_blocks and diff["status"] == "ok":
|
||||
continue
|
||||
|
||||
for other in sorted(detail_diffs, key=lambda d: _layer_sort_key(d["name"])):
|
||||
lines.append(_format_diff_line(other, indent=" "))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
transformers_available = importlib.util.find_spec("transformers") is not None
|
||||
if not transformers_available:
|
||||
raise SystemExit("transformers is not installed; install it to run the debugger.")
|
||||
|
||||
import_notebook_defs = load_notebook_defs()
|
||||
cfg = tiny_debug_config()
|
||||
|
||||
ours_model, hf_model = build_gemma3_pair(import_notebook_defs, cfg)
|
||||
torch.manual_seed(0)
|
||||
input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
|
||||
diffs = layerwise_differences(ours_model, hf_model, input_ids)
|
||||
print(format_report(diffs))
|
||||
@@ -100,6 +100,7 @@ def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
|
||||
torch_dtype=torch.float32,
|
||||
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
|
||||
rope_scaling={"rope_type": "default"},
|
||||
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
|
||||
)
|
||||
hf_model = Gemma3ForCausalLM(hf_cfg)
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
|
||||
torch_dtype=torch.float32,
|
||||
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
|
||||
rope_scaling={"rope_type": "default"},
|
||||
ignore_keys_at_rope_validation={"full_attention", "sliding_attention"},
|
||||
)
|
||||
hf_model = Gemma3ForCausalLM(hf_cfg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user