mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Add Gemma3 KV cache variant (#776)
* Add Gemma3 KV cache variant * update
This commit is contained in:
committed by
GitHub
parent
8c1f9ccf54
commit
f571b5e493
1
.github/workflows/basic-tests-linux-uv.yml
vendored
1
.github/workflows/basic-tests-linux-uv.yml
vendored
@@ -55,6 +55,7 @@ jobs:
|
||||
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
|
||||
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
|
||||
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
|
||||
pytest --ruff ch05/12_gemma3/tests/test_gemma3_kv_nb.py
|
||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||
|
||||
- name: Validate Selected Jupyter Notebooks (uv)
|
||||
|
||||
1
.github/workflows/basic-tests-macos-uv.yml
vendored
1
.github/workflows/basic-tests-macos-uv.yml
vendored
@@ -54,6 +54,7 @@ jobs:
|
||||
pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py
|
||||
pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py
|
||||
pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py
|
||||
pytest --ruff ch05/12_gemma3/tests/test_gemma3_kv_nb.py
|
||||
pytest --ruff ch06/01_main-chapter-code/tests.py
|
||||
|
||||
- name: Validate Selected Jupyter Notebooks (uv)
|
||||
|
||||
@@ -2,6 +2,26 @@
|
||||
|
||||
This [standalone-gemma3.ipynb](standalone-gemma3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Gemma 3 270M. It requires about 2 GB of RAM to run.
|
||||
|
||||
The alternative [standalone-gemma3-plus-kvcache.ipynb](standalone-gemma3-plus-kvcache.ipynb) notebook adds a KV cache for better runtime performance (but adds more code complexity). To learn more about KV caching, see my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article.
|
||||
|
||||
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|
||||
| ----------------- | ----------------- | --------------- | ---------- | ----------------- |
|
||||
| Gemma3Model 270M | Regular | Mac Mini M4 CPU | 8 | - |
|
||||
| Gemma3Model 270M | Regular compiled | Mac Mini M4 CPU | 9 | - |
|
||||
| Gemma3Model 270M | KV cache | Mac Mini M4 CPU | 130 | - |
|
||||
| Gemma3Model 270M | KV cache compiled | Mac Mini M4 CPU | 224 | - |
|
||||
| | | | | |
|
||||
| Gemma3Model 270M | Regular | Mac Mini M4 GPU | 16 | - |
|
||||
| Gemma3Model 270M | Regular compiled | Mac Mini M4 GPU | Error | - |
|
||||
| Gemma3Model 270M | KV cache | Mac Mini M4 GPU | 23 | - |
|
||||
| Gemma3Model 270M | KV cache compiled | Mac Mini M4 GPU | Error | - |
|
||||
| | | | | |
|
||||
| Gemma3Model 270M | Regular | Nvidia A100 GPU | 28 | 1.84 GB |
|
||||
| Gemma3Model 270M | Regular compiled | Nvidia A100 GPU | 128 | 2.12 GB |
|
||||
| Gemma3Model 270M | KV cache | Nvidia A100 GPU | 26 | 1.77 GB |
|
||||
| Gemma3Model 270M | KV cache compiled | Nvidia A100 GPU | 99 | 2.12 GB |
|
||||
|
||||
|
||||
Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3).
|
||||
|
||||
<br>
|
||||
@@ -10,7 +30,7 @@ Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you
|
||||
|
||||
<br>
|
||||
|
||||
To learn more about the architecture differences and read about comparisons with other architectures, see my [The Big LLM Architecture Comparison: From DeepSeek-V3 to Kimi K2: A Look At Modern LLM Architecture Design](https://magazine.sebastianraschka.com/p/the-big-llm-architecture-comparison)article.
|
||||
To learn more about the architecture differences and read about comparisons with other architectures, see my [The Big LLM Architecture Comparison: From DeepSeek-V3 to Kimi K2: A Look At Modern LLM Architecture Design](https://magazine.sebastianraschka.com/p/the-big-llm-architecture-comparison) article.
|
||||
|
||||
|
||||
|
||||
|
||||
1296
ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb
Normal file
1296
ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
@@ -78,9 +78,9 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"huggingface_hub version: 0.33.2\n",
|
||||
"tokenizers version: 0.21.2\n",
|
||||
"torch version: 2.6.0\n"
|
||||
"huggingface_hub version: 0.34.4\n",
|
||||
"tokenizers version: 0.21.4\n",
|
||||
"torch version: 2.8.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -235,7 +235,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 14,
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
|
||||
"metadata": {
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
|
||||
@@ -320,7 +320,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 15,
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
|
||||
"metadata": {
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
|
||||
@@ -329,7 +329,7 @@
|
||||
"source": [
|
||||
"class TransformerBlock(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, cfg: dict, attn_type: str):\n",
|
||||
" def __init__(self, cfg, attn_type):\n",
|
||||
" super().__init__()\n",
|
||||
" self.attn_type = attn_type \n",
|
||||
"\n",
|
||||
@@ -386,7 +386,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 16,
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
||||
"metadata": {
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
||||
@@ -507,7 +507,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 17,
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
|
||||
"metadata": {
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
|
||||
@@ -554,7 +554,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 18,
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
|
||||
"metadata": {
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
|
||||
@@ -567,7 +567,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 19,
|
||||
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -602,7 +602,7 @@
|
||||
")"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"execution_count": 19,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -621,7 +621,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 20,
|
||||
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -634,7 +634,7 @@
|
||||
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -645,7 +645,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 21,
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -676,7 +676,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 22,
|
||||
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -726,7 +726,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 23,
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
|
||||
"metadata": {
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
|
||||
@@ -756,7 +756,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 24,
|
||||
"id": "75166128-5899-4995-9b88-9672e135650e",
|
||||
"metadata": {
|
||||
"id": "75166128-5899-4995-9b88-9672e135650e"
|
||||
@@ -900,7 +900,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 25,
|
||||
"id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -913,7 +913,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 26,
|
||||
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -989,7 +989,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 27,
|
||||
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1019,7 +1019,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 28,
|
||||
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@@ -1037,7 +1037,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 29,
|
||||
"id": "1946b534-e3af-431a-a222-391a60bfa892",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -1047,7 +1047,7 @@
|
||||
"'<bos><start_of_turn>user\\nGive me a short introduction to large language models.<end_of_turn>\\n<start_of_turn>model\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 22,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -1075,7 +1075,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": null,
|
||||
"id": "a6250333-9cf0-4f36-8e28-76be2eac1c43",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optionally use torch.compile for an extra speed-up\n",
|
||||
"# model = torch.compile(model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
|
||||
"metadata": {
|
||||
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
|
||||
@@ -1101,7 +1112,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 31,
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
|
||||
"metadata": {
|
||||
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
|
||||
@@ -1111,7 +1122,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within that data, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n"
|
||||
"Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
113
ch05/12_gemma3/tests/test_gemma3_kv_nb.py
Normal file
113
ch05/12_gemma3/tests/test_gemma3_kv_nb.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from llms_from_scratch.utils import import_definitions_from_notebook
|
||||
|
||||
|
||||
transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nb_imports():
|
||||
nb_dir = Path(__file__).resolve().parents[1]
|
||||
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3-plus-kvcache.ipynb")
|
||||
return mod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_input():
|
||||
torch.manual_seed(123)
|
||||
return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_cfg_base():
|
||||
return {
|
||||
"vocab_size": 100,
|
||||
"emb_dim": 32,
|
||||
"hidden_dim": 64,
|
||||
"n_layers": 2,
|
||||
"n_heads": 4,
|
||||
"head_dim": 8,
|
||||
"n_kv_groups": 1,
|
||||
"qk_norm": True, # Gemma3 uses q/k RMSNorm
|
||||
"dtype": torch.float32,
|
||||
"rope_base": 1_000_000.0, # global RoPE base
|
||||
"rope_local_base": 10_000.0, # local RoPE base (unused in these tests)
|
||||
"context_length": 64,
|
||||
"sliding_window": 16,
|
||||
"layer_types": ["full_attention", "full_attention"],
|
||||
"query_pre_attn_scalar": 256,
|
||||
}
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
|
||||
torch.manual_seed(123)
|
||||
model = nb_imports.Gemma3Model(dummy_cfg_base)
|
||||
out = model(dummy_input)
|
||||
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
|
||||
|
||||
# Tiny config so the test is fast
|
||||
cfg = {
|
||||
"vocab_size": 257,
|
||||
"context_length": 8,
|
||||
"emb_dim": 32,
|
||||
"n_heads": 4,
|
||||
"n_layers": 2,
|
||||
"hidden_dim": 64,
|
||||
"head_dim": 8,
|
||||
"qk_norm": True,
|
||||
"n_kv_groups": 2,
|
||||
"rope_base": 1_000_000.0,
|
||||
"rope_local_base": 10_000.0,
|
||||
"sliding_window": 4,
|
||||
"layer_types": ["full_attention", "full_attention"],
|
||||
"dtype": torch.float32,
|
||||
"query_pre_attn_scalar": 256,
|
||||
}
|
||||
model = nb_imports.Gemma3Model(cfg)
|
||||
|
||||
hf_cfg = Gemma3TextConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
max_position_embeddings=cfg["context_length"],
|
||||
hidden_size=cfg["emb_dim"],
|
||||
num_attention_heads=cfg["n_heads"],
|
||||
num_hidden_layers=cfg["n_layers"],
|
||||
intermediate_size=cfg["hidden_dim"],
|
||||
head_dim=cfg["head_dim"],
|
||||
num_key_value_heads=cfg["n_kv_groups"],
|
||||
rope_theta=cfg["rope_base"],
|
||||
rope_local_base_freq=cfg["rope_local_base"],
|
||||
layer_types=cfg["layer_types"],
|
||||
sliding_window=cfg["sliding_window"],
|
||||
tie_word_embeddings=False,
|
||||
attn_implementation="eager",
|
||||
torch_dtype=torch.float32,
|
||||
query_pre_attn_scalar=cfg["query_pre_attn_scalar"],
|
||||
rope_scaling={"rope_type": "default"},
|
||||
)
|
||||
hf_model = Gemma3ForCausalLM(hf_cfg)
|
||||
|
||||
hf_state = hf_model.state_dict()
|
||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
|
||||
|
||||
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
||||
ours_logits = model(x)
|
||||
theirs_logits = hf_model(x).logits
|
||||
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
|
||||
Reference in New Issue
Block a user