Add Gemma3 KV cache variant (#776)

* Add Gemma3 KV cache variant

* update
This commit is contained in:
Sebastian Raschka
2025-08-19 12:37:49 -05:00
committed by GitHub
parent 8c1f9ccf54
commit f571b5e493
6 changed files with 1470 additions and 28 deletions

View File

@@ -55,6 +55,7 @@ jobs:
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
pytest --ruff ch05/12_gemma3/tests/test_gemma3_kv_nb.py
pytest --ruff ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks (uv)

View File

@@ -54,6 +54,7 @@ jobs:
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
pytest --ruff ch05/12_gemma3/tests/test_gemma3_kv_nb.py
pytest --ruff ch06/01_main-chapter-code/tests.py
- name: Validate Selected Jupyter Notebooks (uv)

View File

@@ -2,6 +2,26 @@
The [standalone-gemma3.ipynb](standalone-gemma3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Gemma 3 270M. It requires about 2 GB of RAM to run.
The alternative [standalone-gemma3-plus-kvcache.ipynb](standalone-gemma3-plus-kvcache.ipynb) notebook adds a KV cache for better runtime performance (but adds more code complexity). To learn more about KV caching, see my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article.
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
| ----------------- | ----------------- | --------------- | ---------- | ----------------- |
| Gemma3Model 270M | Regular | Mac Mini M4 CPU | 8 | - |
| Gemma3Model 270M | Regular compiled | Mac Mini M4 CPU | 9 | - |
| Gemma3Model 270M | KV cache | Mac Mini M4 CPU | 130 | - |
| Gemma3Model 270M | KV cache compiled | Mac Mini M4 CPU | 224 | - |
| | | | | |
| Gemma3Model 270M | Regular | Mac Mini M4 GPU | 16 | - |
| Gemma3Model 270M | Regular compiled | Mac Mini M4 GPU | Error | - |
| Gemma3Model 270M | KV cache | Mac Mini M4 GPU | 23 | - |
| Gemma3Model 270M | KV cache compiled | Mac Mini M4 GPU | Error | - |
| | | | | |
| Gemma3Model 270M | Regular | Nvidia A100 GPU | 28 | 1.84 GB |
| Gemma3Model 270M | Regular compiled | Nvidia A100 GPU | 128 | 2.12 GB |
| Gemma3Model 270M | KV cache | Nvidia A100 GPU | 26 | 1.77 GB |
| Gemma3Model 270M | KV cache compiled | Nvidia A100 GPU | 99 | 2.12 GB |
Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3).
<br>
@@ -10,7 +30,7 @@ Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you
<br>
To learn more about the architecture differences and read about comparisons with other architectures, see my [The Big LLM Architecture Comparison: From DeepSeek-V3 to Kimi K2: A Look At Modern LLM Architecture Design](https://magazine.sebastianraschka.com/p/the-big-llm-architecture-comparison)article.
To learn more about the architecture differences and read about comparisons with other architectures, see my [The Big LLM Architecture Comparison: From DeepSeek-V3 to Kimi K2: A Look At Modern LLM Architecture Design](https://magazine.sebastianraschka.com/p/the-big-llm-architecture-comparison) article.

File diff suppressed because it is too large Load Diff

View File

@@ -78,9 +78,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.33.2\n",
"tokenizers version: 0.21.2\n",
"torch version: 2.6.0\n"
"huggingface_hub version: 0.34.4\n",
"tokenizers version: 0.21.4\n",
"torch version: 2.8.0\n"
]
}
],
@@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 14,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
@@ -320,7 +320,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 15,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
@@ -329,7 +329,7 @@
"source": [
"class TransformerBlock(nn.Module):\n",
"\n",
" def __init__(self, cfg: dict, attn_type: str):\n",
" def __init__(self, cfg, attn_type):\n",
" super().__init__()\n",
" self.attn_type = attn_type \n",
"\n",
@@ -386,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 16,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
@@ -507,7 +507,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 17,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
@@ -554,7 +554,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 18,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
@@ -567,7 +567,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 19,
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
"metadata": {},
"outputs": [
@@ -602,7 +602,7 @@
")"
]
},
"execution_count": 12,
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
@@ -621,7 +621,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 20,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
@@ -634,7 +634,7 @@
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 13,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@@ -645,7 +645,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 21,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {
"colab": {
@@ -676,7 +676,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 22,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {
"colab": {
@@ -726,7 +726,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 23,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
@@ -756,7 +756,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 24,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
@@ -900,7 +900,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 25,
"id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8",
"metadata": {},
"outputs": [],
@@ -913,7 +913,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 26,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
@@ -989,7 +989,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 27,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
@@ -1019,7 +1019,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 28,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [],
@@ -1037,7 +1037,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 29,
"id": "1946b534-e3af-431a-a222-391a60bfa892",
"metadata": {},
"outputs": [
@@ -1047,7 +1047,7 @@
"'<bos><start_of_turn>user\\nGive me a short introduction to large language models.<end_of_turn>\\n<start_of_turn>model\\n'"
]
},
"execution_count": 22,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
@@ -1075,7 +1075,18 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"id": "a6250333-9cf0-4f36-8e28-76be2eac1c43",
"metadata": {},
"outputs": [],
"source": [
"# Optionally use torch.compile for an extra speed-up\n",
"# model = torch.compile(model)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1101,7 +1112,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 31,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1111,7 +1122,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within that data, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n"
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n"
]
}
],

View File

@@ -0,0 +1,113 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import importlib
from pathlib import Path
import pytest
import torch
from llms_from_scratch.utils import import_definitions_from_notebook
# True when the optional `transformers` package is importable; used below to
# skip the Hugging Face equivalence test in environments that lack it.
transformers_installed = bool(importlib.util.find_spec("transformers"))
@pytest.fixture
def nb_imports():
    """Load every definition from the KV-cache Gemma3 notebook as a module."""
    # The notebook sits one directory above this tests/ folder.
    notebook_dir = Path(__file__).resolve().parents[1]
    return import_definitions_from_notebook(
        notebook_dir, "standalone-gemma3-plus-kvcache.ipynb"
    )
@pytest.fixture
def dummy_input():
    """Deterministic dummy token batch of shape (1, 8) with ids in [0, 100)."""
    torch.manual_seed(123)  # fixed seed so every test sees the same token ids
    batch_size, seq_len = 1, 8
    return torch.randint(0, 100, (batch_size, seq_len))
@pytest.fixture
def dummy_cfg_base():
    """Tiny Gemma3 config used by the forward-pass smoke test."""
    return dict(
        vocab_size=100,
        emb_dim=32,
        hidden_dim=64,
        n_layers=2,
        n_heads=4,
        head_dim=8,
        n_kv_groups=1,
        qk_norm=True,                     # Gemma3 uses q/k RMSNorm
        dtype=torch.float32,
        rope_base=1_000_000.0,            # global RoPE base
        rope_local_base=10_000.0,         # local RoPE base (unused in these tests)
        context_length=64,
        sliding_window=16,
        layer_types=["full_attention", "full_attention"],
        query_pre_attn_scalar=256,
    )
@torch.inference_mode()
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
    """Smoke test: a tiny Gemma3Model produces logits of the expected shape."""
    torch.manual_seed(123)  # deterministic weight init
    model = nb_imports.Gemma3Model(dummy_cfg_base)
    logits = model(dummy_input)
    expected_shape = (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
    assert logits.shape == expected_shape
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_gemma3_base_equivalence_with_transformers(nb_imports):
    """Verify the notebook's Gemma3Model matches HF's Gemma3ForCausalLM logits.

    Builds matching tiny configs for both implementations, copies the HF
    weights into our model, and compares logits on random token ids.
    """
    from transformers import Gemma3TextConfig, Gemma3ForCausalLM

    # Deliberately tiny configuration so the comparison runs in seconds.
    cfg = {
        "vocab_size": 257,
        "context_length": 8,
        "emb_dim": 32,
        "n_heads": 4,
        "n_layers": 2,
        "hidden_dim": 64,
        "head_dim": 8,
        "qk_norm": True,
        "n_kv_groups": 2,
        "rope_base": 1_000_000.0,
        "rope_local_base": 10_000.0,
        "sliding_window": 4,
        "layer_types": ["full_attention", "full_attention"],
        "dtype": torch.float32,
        "query_pre_attn_scalar": 256,
    }
    our_model = nb_imports.Gemma3Model(cfg)

    # Mirror the same hyperparameters in the Hugging Face config.
    hf_cfg = Gemma3TextConfig(
        vocab_size=cfg["vocab_size"],
        max_position_embeddings=cfg["context_length"],
        hidden_size=cfg["emb_dim"],
        num_attention_heads=cfg["n_heads"],
        num_hidden_layers=cfg["n_layers"],
        intermediate_size=cfg["hidden_dim"],
        head_dim=cfg["head_dim"],
        num_key_value_heads=cfg["n_kv_groups"],
        rope_theta=cfg["rope_base"],
        rope_local_base_freq=cfg["rope_local_base"],
        layer_types=cfg["layer_types"],
        sliding_window=cfg["sliding_window"],
        tie_word_embeddings=False,
        attn_implementation="eager",  # eager attention for exact comparability
        torch_dtype=torch.float32,
        query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
        rope_scaling={"rope_type": "default"},
    )
    hf_model = Gemma3ForCausalLM(hf_cfg)

    # Copy the randomly initialized HF weights into our implementation.
    hf_weights = hf_model.state_dict()
    param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
    nb_imports.load_weights_into_gemma(our_model, param_config, hf_weights)

    token_ids = torch.randint(
        0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long
    )
    ours_logits = our_model(token_ids)
    theirs_logits = hf_model(token_ids).logits
    torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)