From f571b5e493de1fc6556409c4ddca4a6b63288d82 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Tue, 19 Aug 2025 12:37:49 -0500 Subject: [PATCH] Add Gemma3 KV cache variant (#776) * Add Gemma3 KV cache variant * update --- .github/workflows/basic-tests-linux-uv.yml | 1 + .github/workflows/basic-tests-macos-uv.yml | 1 + ch05/12_gemma3/README.md | 24 +- .../standalone-gemma3-plus-kvcache.ipynb | 1296 +++++++++++++++++ ch05/12_gemma3/standalone-gemma3.ipynb | 63 +- ch05/12_gemma3/tests/test_gemma3_kv_nb.py | 113 ++ 6 files changed, 1470 insertions(+), 28 deletions(-) create mode 100644 ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb create mode 100644 ch05/12_gemma3/tests/test_gemma3_kv_nb.py diff --git a/.github/workflows/basic-tests-linux-uv.yml b/.github/workflows/basic-tests-linux-uv.yml index f02f937..b9cec98 100644 --- a/.github/workflows/basic-tests-linux-uv.yml +++ b/.github/workflows/basic-tests-linux-uv.yml @@ -55,6 +55,7 @@ jobs: pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py + pytest --ruff ch05/12_gemma3/tests/test_gemma3_kv_nb.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks (uv) diff --git a/.github/workflows/basic-tests-macos-uv.yml b/.github/workflows/basic-tests-macos-uv.yml index a3f052e..e673de7 100644 --- a/.github/workflows/basic-tests-macos-uv.yml +++ b/.github/workflows/basic-tests-macos-uv.yml @@ -54,6 +54,7 @@ jobs: pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py + pytest --ruff ch05/12_gemma3/tests/test_gemma3_kv_nb.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks (uv) diff --git a/ch05/12_gemma3/README.md b/ch05/12_gemma3/README.md index 8973273..4205a78 100644 --- a/ch05/12_gemma3/README.md +++ b/ch05/12_gemma3/README.md @@ -1,6 +1,26 @@ # Gemma 3 270M From Scratch -This [standalone-gemma3.ipynb](standalone-gemma3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Gemma 3 270M. It requires about 2 GB of RAM to run. +This [standalone-gemma3.ipynb](standalone-gemma3.ipynb) Jupyter notebook in this folder contains a from-scratch implementation of Gemma 3 270M. It requires about 2 GB of RAM to run. + +The alternative [standalone-gemma3-plus-kvcache.ipynb](standalone-gemma3-plus-kvcache.ipynb) notebook adds a KV cache for better runtime performance (but adds more code complexity). To learn more about KV caching, see my [Understanding and Coding the KV Cache in LLMs from Scratch](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms) article. 
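+
+The gist of the speed-up (a toy sketch of the general idea, not the notebook's exact API): without a cache, decoding step *t* recomputes the key and value projections for all *t* tokens generated so far, whereas with a cache only the newest token's keys and values are computed and appended:
+
+```python
+import torch
+
+d = 4                              # toy head dimension
+W_key = torch.randn(d, d)          # stand-in key projection
+W_value = torch.randn(d, d)        # stand-in value projection
+cache_k, cache_v = [], []
+
+for step in range(3):              # pretend we decode 3 tokens
+    x_new = torch.randn(1, d)      # embedding of the newest token only
+    cache_k.append(x_new @ W_key)  # append instead of recomputing all steps
+    cache_v.append(x_new @ W_value)
+    K = torch.cat(cache_k, dim=0)  # (step + 1, d) keys seen so far
+    V = torch.cat(cache_v, dim=0)  # the new token's query attends over K, V
+```
+
+The benchmarks below compare the regular and KV-cache code paths: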
+ +| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) | +| ----------------- | ----------------- | --------------- | ---------- | ----------------- | +| Gemma3Model 270M | Regular | Mac Mini M4 CPU | 8 | - | +| Gemma3Model 270M | Regular compiled | Mac Mini M4 CPU | 9 | - | +| Gemma3Model 270M | KV cache | Mac Mini M4 CPU | 130 | - | +| Gemma3Model 270M | KV cache compiled | Mac Mini M4 CPU | 224 | - | +| | | | | | +| Gemma3Model 270M | Regular | Mac Mini M4 GPU | 16 | - | +| Gemma3Model 270M | Regular compiled | Mac Mini M4 GPU | Error | - | +| Gemma3Model 270M | KV cache | Mac Mini M4 GPU | 23 | - | +| Gemma3Model 270M | KV cache compiled | Mac Mini M4 GPU | Error | - | +| | | | | | +| Gemma3Model 270M | Regular | Nvidia A100 GPU | 28 | 1.84 GB | +| Gemma3Model 270M | Regular compiled | Nvidia A100 GPU | 128 | 2.12 GB | +| Gemma3Model 270M | KV cache | Nvidia A100 GPU | 26 | 1.77 GB | +| Gemma3Model 270M | KV cache compiled | Nvidia A100 GPU | 99 | 2.12 GB | + Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3). @@ -10,7 +30,7 @@ Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you
-To learn more about the architecture differences and read about comparisons with other architectures, see my [The Big LLM Architecture Comparison: From DeepSeek-V3 to Kimi K2: A Look At Modern LLM Architecture Design](https://magazine.sebastianraschka.com/p/the-big-llm-architecture-comparison)article. +To learn more about the architecture differences and read about comparisons with other architectures, see my [The Big LLM Architecture Comparison: From DeepSeek-V3 to Kimi K2: A Look At Modern LLM Architecture Design](https://magazine.sebastianraschka.com/p/the-big-llm-architecture-comparison) article. diff --git a/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb b/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb new file mode 100644 index 0000000..fbde3a5 --- /dev/null +++ b/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb @@ -0,0 +1,1296 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", + "metadata": { + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", + "metadata": { + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52" + }, + "source": [ + "# Gemma 3 270M From Scratch (A Standalone Notebook)" + ] + }, + { + "cell_type": "markdown", + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", + "metadata": { + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d" + }, + "source": [ + "- This notebook is purposefully minimal and focuses on the code to re-implement Gemma 3 270M in pure PyTorch without relying on other external LLM libraries\n", + "- For more information, see the official [Gemma 3 270M model card](https://huggingface.co/google/gemma-3-270m)\n", + "\n", + "- Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + "- About the code:\n", + " - all code is my own code, mapping the Gemma 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7c201adb-747e-437b-9a62-442802941e01", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface_hub version: 0.34.4\n", + "tokenizers version: 0.21.4\n", + "torch version: 2.8.0\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tokenizers\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "07e96fbb-8e16-4f6d-835f-c6159321280b", + "metadata": {}, + "source": [ + "- This notebook supports both the base model and the instructmodel; which model to use can be controlled via the following flag:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "70a90338-624a-4706-aa55-6b4358070194", + "metadata": {}, + "outputs": [], + "source": [ + "USE_INSTRUCT_MODEL = True" + ] + }, + { + "cell_type": "markdown", + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", + "metadata": { + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d" + }, + "source": [ + " \n", + "# 1. 
Architecture code" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "82076c21-9331-4dcd-b017-42b046cf1a60", + "metadata": { + "id": "82076c21-9331-4dcd-b017-42b046cf1a60" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = nn.functional.gelu(x_fc1, approximate=\"tanh\") * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "56715760-37e1-433e-89da-04864c139a9e", + "metadata": {}, + "outputs": [], + "source": [ + "class RMSNorm(nn.Module):\n", + " def __init__(self, emb_dim, eps=1e-6, bias=False):\n", + " super().__init__()\n", + " self.eps = eps\n", + " # Gemma3 stores zero-centered weights and uses (1 + weight) during forward\n", + " self.scale = nn.Parameter(torch.zeros(emb_dim))\n", + " self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None\n", + "\n", + " def forward(self, x):\n", + " # Match HF Gemma3: compute norm in float32, then scale by (1 + w)\n", + " input_dtype = x.dtype\n", + " x_f = x.float()\n", + " var = x_f.pow(2).mean(dim=-1, keepdim=True)\n", + " x_norm = x_f * torch.rsqrt(var + self.eps)\n", + " out = x_norm * (1.0 + self.scale.float())\n", + " \n", + " if self.shift is not None:\n", + " out = out + self.shift.float()\n", + " \n", + " return out.to(input_dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4b9a346f-5826-4083-9162-abd56afc03f0", + "metadata": { + "id": "4b9a346f-5826-4083-9162-abd56afc03f0" + }, + "outputs": [], + "source": [ + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", + "\n", + " # Compute the inverse frequencies\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n", + "\n", + " # Generate position indices\n", + " positions = torch.arange(context_length, dtype=dtype)\n", + "\n", + " # Compute the angles\n", + " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + "\n", + " # Expand angles to match the head_dim\n", + " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", + "\n", + " # Precompute sine and cosine\n", + " cos = torch.cos(angles)\n", + " sin = torch.sin(angles)\n", + "\n", + " return cos, sin\n", + "\n", + "\n", + "def apply_rope(x, cos, sin, offset=0):\n", + " # x: (batch_size, num_heads, seq_len, head_dim)\n", + " batch_size, num_heads, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " # Split x into first half and second half\n", + " x1 = x[..., : head_dim // 2] # First half\n", + " x2 = x[..., head_dim // 2 :] # Second half\n", + "\n", + " # Adjust sin and cos shapes\n", + " cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", + " sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " # 
Apply the rotary transformation\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x * cos) + (rotated * sin)\n", + "\n", + " # It's ok to use lower-precision after applying cos and sin rotation\n", + " return x_rotated.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", + "metadata": { + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" + }, + "outputs": [], + "source": [ + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(\n", + " self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False,\n", + " query_pre_attn_scalar=None, dtype=None,\n", + " ):\n", + " super().__init__()\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.num_kv_groups = num_kv_groups\n", + " self.group_size = num_heads // num_kv_groups\n", + "\n", + " if head_dim is None:\n", + " assert d_in % num_heads == 0, \"`d_in` must be divisible by `num_heads` if `head_dim` is not set\"\n", + " head_dim = d_in // num_heads\n", + "\n", + " self.head_dim = head_dim\n", + " self.d_out = num_heads * head_dim\n", + "\n", + " self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)\n", + "\n", + " self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)\n", + "\n", + " if qk_norm:\n", + " self.q_norm = RMSNorm(head_dim, eps=1e-6)\n", + " self.k_norm = RMSNorm(head_dim, eps=1e-6)\n", + " else:\n", + " self.q_norm = self.k_norm = None\n", + "\n", + " if query_pre_attn_scalar is not None:\n", + " self.scaling = (query_pre_attn_scalar) ** -0.5\n", + " else:\n", + " self.scaling = (head_dim) ** -0.5\n", + "\n", + " def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n", + " b, num_tokens, _ = x.shape\n", + "\n", + " # Apply projections\n", + " queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n", + " keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n", + " values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n", + "\n", + " # Reshape\n", + " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n", + " keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + " values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + "\n", + " # Optional Q/K normalization (applied to raw tensors)\n", + " if self.q_norm:\n", + " queries = self.q_norm(queries)\n", + " if self.k_norm:\n", + " keys_new = self.k_norm(keys_new)\n", + "\n", + " # Keep unrotated in cache; rotate after concatenation\n", + " prev_len = 0\n", + " if cache is not None:\n", + " prev_k, prev_v = cache # cached as unrotated\n", + " if prev_k is not None:\n", + " prev_len = prev_k.size(2)\n", + " keys_cat_raw = torch.cat([prev_k, keys_new], dim=2) # unrotated\n", + " values_cat_raw = torch.cat([prev_v, values_new], dim=2) # raw V\n", + " else:\n", + " keys_cat_raw = keys_new\n", + " values_cat_raw = values_new\n", + " else:\n", + " keys_cat_raw = keys_new\n", + " values_cat_raw = values_new\n", + "\n", + " # RoPE: queries at absolute start_pos; keys with offset corrected by prev_len\n", + " queries = apply_rope(queries, cos, sin, offset=start_pos)\n", + " keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - 
prev_len)\n", + "\n", + " # Scale queries\n", + " queries = queries * self.scaling\n", + "\n", + " # Update cache with unrotated keys and unscaled raw values\n", + " if cache is not None and cache[0] is not None:\n", + " next_cache = (\n", + " torch.cat([cache[0], keys_new], dim=2),\n", + " torch.cat([cache[1], values_new], dim=2),\n", + " )\n", + " else:\n", + " next_cache = (keys_new, values_new)\n", + "\n", + " # Expand K and V to match number of heads\n", + " keys = keys.repeat_interleave(self.group_size, dim=1)\n", + " values = values_cat_raw.repeat_interleave(self.group_size, dim=1)\n", + "\n", + " # Attention\n", + " attn_scores = queries @ keys.transpose(2, 3)\n", + " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", + " attn_weights = torch.softmax(attn_scores, dim=-1)\n", + "\n", + " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n", + " out = self.out_proj(context)\n", + "\n", + " return out, next_cache" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", + "metadata": { + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" + }, + "outputs": [], + "source": [ + "class TransformerBlock(nn.Module):\n", + "\n", + " def __init__(self, cfg, attn_type):\n", + " super().__init__()\n", + " self.attn_type = attn_type\n", + " self.sliding_window = cfg[\"sliding_window\"]\n", + "\n", + " self.att = GroupedQueryAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " num_kv_groups=cfg[\"n_kv_groups\"],\n", + " head_dim=cfg[\"head_dim\"],\n", + " qk_norm=cfg[\"qk_norm\"],\n", + " query_pre_attn_scalar=cfg[\"query_pre_attn_scalar\"],\n", + " dtype=cfg[\"dtype\"],\n", + " )\n", + " self.ff = FeedForward(cfg)\n", + " self.input_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", + " self.post_attention_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", + " self.pre_feedforward_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", + " self.post_feedforward_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", + "\n", + " def forward(\n", + " self,\n", + " x,\n", + " mask_global,\n", + " mask_local,\n", + " cos_global,\n", + " sin_global,\n", + " cos_local,\n", + " sin_local,\n", + " start_pos=0,\n", + " cache=None\n", + " ):\n", + " # Shortcut connection for attention block\n", + " shortcut = x\n", + " x = self.input_layernorm(x)\n", + "\n", + " if self.attn_type == \"sliding_attention\":\n", + " if cache is not None and isinstance(cache, tuple):\n", + " prev_k, _ = cache\n", + " eff_kv_len = prev_k.size(2) + x.size(1)\n", + " else:\n", + " eff_kv_len = x.size(1)\n", + " # Take the last `eff_kv_len` columns so mask width equals K length\n", + " attn_mask = mask_local[..., -eff_kv_len:]\n", + " cos = cos_local\n", + " sin = sin_local\n", + " else:\n", + " attn_mask = mask_global\n", + " cos = cos_global\n", + " sin = sin_global\n", + " \n", + " x_attn, next_cache = self.att(x, attn_mask, cos, sin, start_pos=start_pos, cache=cache)\n", + " if next_cache is not None and self.attn_type == \"sliding_attention\":\n", + " k, v = next_cache\n", + " if k.size(2) > self.sliding_window:\n", + " k = k[:, :, -self.sliding_window:, :]\n", + " v = v[:, :, -self.sliding_window:, :]\n", + " next_cache = (k, v)\n", + "\n", + " x_attn = self.post_attention_layernorm(x_attn)\n", + " x = shortcut + x_attn\n", + "\n", + " # Shortcut connection for feed forward block\n", + " shortcut = x\n", + " x_ffn = self.pre_feedforward_layernorm(x)\n", + " x_ffn = self.ff(x_ffn)\n", + " x_ffn = 
self.post_feedforward_layernorm(x_ffn)\n", + " x = shortcut + x_ffn\n", + " return x, next_cache" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", + "metadata": { + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" + }, + "outputs": [], + "source": [ + "class Gemma3Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " assert cfg[\"layer_types\"] is not None and len(cfg[\"layer_types\"]) == cfg[\"n_layers\"]\n", + "\n", + " # Main model parameters\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + "\n", + " self.blocks = nn.ModuleList([\n", + " TransformerBlock(cfg, attn_type) for attn_type in cfg[\"layer_types\"]\n", + " ])\n", + "\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-6)\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.cfg = cfg\n", + " self.current_pos = 0 # Track current position in KV cache\n", + "\n", + " # Reusable utilities\n", + " cos_local, sin_local = compute_rope_params(\n", + " head_dim=cfg[\"head_dim\"],\n", + " theta_base=cfg[\"rope_local_base\"],\n", + " context_length=cfg[\"context_length\"],\n", + " dtype=torch.float32,\n", + " )\n", + " cos_global, sin_global = compute_rope_params(\n", + " head_dim=cfg[\"head_dim\"],\n", + " theta_base=cfg[\"rope_base\"],\n", + " context_length=cfg[\"context_length\"],\n", + " dtype=torch.float32,\n", + " )\n", + " self.register_buffer(\"cos_local\", cos_local, persistent=False)\n", + " self.register_buffer(\"sin_local\", sin_local, persistent=False)\n", + " self.register_buffer(\"cos_global\", cos_global, persistent=False)\n", + " self.register_buffer(\"sin_global\", sin_global, persistent=False)\n", + "\n", + " def _create_masks(self, cur_len, device, pos_start=0, pos_end=None):\n", + " if pos_end is None:\n", + " pos_end = cur_len\n", + " total_len = pos_end\n", + "\n", + " ones = torch.ones((total_len, total_len), dtype=torch.bool, device=device)\n", + "\n", + " # mask_global_full (future is masked: j > i)\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 1 1 1 1 1 1 1\n", + " # 1: 0 0 1 1 1 1 1 1\n", + " # 2: 0 0 0 1 1 1 1 1\n", + " # 3: 0 0 0 0 1 1 1 1\n", + " # 4: 0 0 0 0 0 1 1 1\n", + " # 5: 0 0 0 0 0 0 1 1\n", + " # 6: 0 0 0 0 0 0 0 1\n", + " # 7: 0 0 0 0 0 0 0 0\n", + " mask_global_full = torch.triu(ones, diagonal=1)\n", + "\n", + " # far_past (too far back is masked: i - j >= sliding_window)\n", + " # where sliding_window = 4\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 0 0 0 0 0 0 0\n", + " # 1: 0 0 0 0 0 0 0 0\n", + " # 2: 0 0 0 0 0 0 0 0\n", + " # 3: 0 0 0 0 0 0 0 0\n", + " # 4: 1 0 0 0 0 0 0 0\n", + " # 5: 1 1 0 0 0 0 0 0\n", + " # 6: 1 1 1 0 0 0 0 0\n", + " # 7: 1 1 1 1 0 0 0 0\n", + " far_past_full = torch.triu(ones, diagonal=self.cfg[\"sliding_window\"]).T\n", + "\n", + " # Local (sliding_window) = future OR far-past\n", + " # mask_local\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 1 1 1 1 1 1 1\n", + " # 1: 0 0 1 1 1 1 1 1\n", + " # 2: 0 0 0 1 1 1 1 1\n", + " # 3: 0 0 0 0 1 1 1 1\n", + " # 4: 1 0 0 0 0 1 1 1\n", + " # 5: 1 1 0 0 0 0 1 1\n", + " # 6: 1 1 1 0 0 0 0 1\n", + " # 7: 1 1 1 1 0 0 0 0\n", + " mask_local_full = mask_global_full | far_past_full\n", + "\n", + " row_slice = slice(pos_start, pos_end)\n", + " mask_global = mask_global_full[row_slice, :pos_end][None, None, :, :]\n", + " mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]\n", + " 
return mask_global, mask_local\n", + "\n", + "\n", + " def forward(self, input_ids, cache=None):\n", + " b, seq_len = input_ids.shape\n", + " x = self.tok_emb(input_ids) * (self.cfg[\"emb_dim\"] ** 0.5)\n", + "\n", + " if cache is not None:\n", + " pos_start = self.current_pos\n", + " pos_end = pos_start + seq_len\n", + " self.current_pos = pos_end\n", + " mask_global, mask_local = self._create_masks(\n", + " cur_len=seq_len, device=x.device, pos_start=pos_start, pos_end=pos_end\n", + " )\n", + " else:\n", + " pos_start = 0\n", + " mask_global, mask_local = self._create_masks(\n", + " cur_len=seq_len, device=x.device, pos_start=0, pos_end=seq_len\n", + " )\n", + "\n", + " for i, block in enumerate(self.blocks):\n", + " blk_cache = cache.get(i) if cache is not None else None\n", + " x, new_blk_cache = block(\n", + " x,\n", + " mask_global=mask_global,\n", + " mask_local=mask_local,\n", + " cos_global=self.cos_global,\n", + " sin_global=self.sin_global,\n", + " cos_local=self.cos_local,\n", + " sin_local=self.sin_local,\n", + " start_pos=pos_start, # position of first new token\n", + " cache=blk_cache,\n", + " )\n", + "\n", + " if cache is not None:\n", + " cache.update(i, new_blk_cache)\n", + "\n", + " # Final layernorm + projection\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", + " return logits\n", + "\n", + " def reset_kv_cache(self):\n", + " self.current_pos = 0" + ] + }, + { + "cell_type": "markdown", + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", + "metadata": { + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48" + }, + "source": [ + " \n", + "# 2. Initialize model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "caa142fa-b375-4e78-b392-2072ced666f3", + "metadata": { + "id": "caa142fa-b375-4e78-b392-2072ced666f3" + }, + "outputs": [], + "source": [ + "GEMMA3_CONFIG_270M = {\n", + " \"vocab_size\": 262_144,\n", + " \"context_length\": 32_768,\n", + " \"emb_dim\": 640,\n", + " \"n_heads\": 4,\n", + " \"n_layers\": 18,\n", + " \"hidden_dim\": 2048,\n", + " \"head_dim\": 256,\n", + " \"qk_norm\": True,\n", + " \"n_kv_groups\": 1,\n", + " \"rope_local_base\": 10_000.0,\n", + " \"rope_base\": 1_000_000.0,\n", + " \"sliding_window\": 512,\n", + " \"layer_types\": [\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\"\n", + " ],\n", + " \"dtype\": torch.bfloat16,\n", + " \"query_pre_attn_scalar\": 256,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", + "metadata": { + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" + }, + "outputs": [], + "source": [ + "torch.manual_seed(123)\n", + "model = Gemma3Model(GEMMA3_CONFIG_270M)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "eaf86265-4e9d-4024-9ed0-99076944e304", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Gemma3Model(\n", + " (tok_emb): Embedding(262144, 640)\n", + " (blocks): ModuleList(\n", + " (0-17): 18 x TransformerBlock(\n", + " (att): GroupedQueryAttention(\n", + " (W_query): 
Linear(in_features=640, out_features=1024, bias=False)\n", + " (W_key): Linear(in_features=640, out_features=256, bias=False)\n", + " (W_value): Linear(in_features=640, out_features=256, bias=False)\n", + " (out_proj): Linear(in_features=1024, out_features=640, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=640, out_features=2048, bias=False)\n", + " (fc2): Linear(in_features=640, out_features=2048, bias=False)\n", + " (fc3): Linear(in_features=2048, out_features=640, bias=False)\n", + " )\n", + " (input_layernorm): RMSNorm()\n", + " (post_attention_layernorm): RMSNorm()\n", + " (pre_feedforward_layernorm): RMSNorm()\n", + " (post_feedforward_layernorm): RMSNorm()\n", + " )\n", + " )\n", + " (final_norm): RMSNorm()\n", + " (out_head): Linear(in_features=640, out_features=262144, bias=False)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b", + "metadata": {}, + "source": [ + "- A quick check that the forward pass works before continuing:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.7500, 0.1060, 0.4844, ..., 0.9414, 0.3984, -0.2324],\n", + " [-0.3438, -0.0549, 0.8984, ..., -0.2402, 0.4570, 0.8242],\n", + " [-0.2676, -0.3281, 0.4121, ..., 0.8711, -0.9648, 0.9844]]],\n", + " dtype=torch.bfloat16, grad_fn=)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(torch.tensor([1, 2, 3]).unsqueeze(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", + "outputId": "00d7e983-262e-4c65-f322-f4d999311988" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters: 435,870,336\n", + "\n", + "Total number of unique parameters: 268,098,176\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total number of parameters: {total_params:,}\")\n", + "\n", + "# Account for weight tying\n", + "total_params_normalized = total_params - model.tok_emb.weight.numel()\n", + "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", + "jupyter": { + "source_hidden": true + }, + "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float32 (PyTorch default): 3.37 GB\n", + "bfloat16: 1.69 GB\n" + ] + } + ], + "source": [ + "def model_memory_size(model, input_dtype=torch.float32):\n", + " total_params = 0\n", + " total_grads = 0\n", + " for param in model.parameters():\n", + " # Calculate total number of elements per parameter\n", + " param_size = param.numel()\n", + " total_params += param_size\n", + " # Check if gradients are stored for this parameter\n", + " if param.requires_grad:\n", + " total_grads += param_size\n", + "\n", + " # Calculate 
buffer size (non-parameters that require memory)\n", + " total_buffers = sum(buf.numel() for buf in model.buffers())\n", + "\n", + " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", + " # We assume parameters and gradients are stored in the same type as input dtype\n", + " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", + " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", + "\n", + " # Convert bytes to gigabytes\n", + " total_memory_gb = total_memory_bytes / (1024**3)\n", + "\n", + " return total_memory_gb\n", + "\n", + "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", + "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", + "metadata": { + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "c172f89f-d301-439f-b809-46169e5f5945", + "metadata": { + "id": "c172f89f-d301-439f-b809-46169e5f5945" + }, + "source": [ + " \n", + "# 3. Load pretrained weights" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "75166128-5899-4995-9b88-9672e135650e", + "metadata": { + "id": "75166128-5899-4995-9b88-9672e135650e" + }, + "outputs": [], + "source": [ + "def load_weights_into_gemma(model, param_config, params):\n", + "\n", + " def assign(left, right, tensor_name=\"unknown\"):\n", + " if left.shape != right.shape:\n", + " raise ValueError(\n", + " f\"Shape mismatch in tensor '{tensor_name}'. 
Left: {left.shape}, Right: {right.shape}\"\n", + " )\n", + " return torch.nn.Parameter(right.clone().detach() if isinstance(right, torch.Tensor) else torch.tensor(right))\n", + "\n", + " # Embedding weights\n", + " if \"model.embed_tokens.weight\" in params:\n", + " model.tok_emb.weight = assign(\n", + " model.tok_emb.weight,\n", + " params[\"model.embed_tokens.weight\"],\n", + " \"model.embed_tokens.weight\",\n", + " )\n", + "\n", + " # Iterate over transformer layers\n", + " for l in range(param_config[\"n_layers\"]):\n", + " block = model.blocks[l]\n", + " att = block.att\n", + " # Attention projections\n", + " att.W_query.weight = assign(\n", + " att.W_query.weight,\n", + " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_proj.weight\",\n", + " )\n", + " att.W_key.weight = assign(\n", + " att.W_key.weight,\n", + " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_proj.weight\",\n", + " )\n", + " att.W_value.weight = assign(\n", + " att.W_value.weight,\n", + " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.v_proj.weight\",\n", + " )\n", + " att.out_proj.weight = assign(\n", + " att.out_proj.weight,\n", + " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.o_proj.weight\",\n", + " )\n", + " # QK normalization weights\n", + " att.q_norm.scale = assign(\n", + " att.q_norm.scale,\n", + " params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_norm.weight\",\n", + " )\n", + " att.k_norm.scale = assign(\n", + " att.k_norm.scale,\n", + " params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_norm.weight\",\n", + " )\n", + " # Feed forward weights\n", + " block.ff.fc1.weight = assign(\n", + " block.ff.fc1.weight,\n", + " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.gate_proj.weight\",\n", + " )\n", + " block.ff.fc2.weight = assign(\n", + " block.ff.fc2.weight,\n", + " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.up_proj.weight\",\n", + " )\n", + " block.ff.fc3.weight = assign(\n", + " block.ff.fc3.weight,\n", + " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.down_proj.weight\",\n", + " )\n", + " # LayerNorm weights\n", + " block.input_layernorm.scale = assign(\n", + " block.input_layernorm.scale,\n", + " params[f\"model.layers.{l}.input_layernorm.weight\"],\n", + " f\"model.layers.{l}.input_layernorm.weight\",\n", + " )\n", + " block.post_attention_layernorm.scale = assign(\n", + " block.post_attention_layernorm.scale,\n", + " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n", + " f\"model.layers.{l}.post_attention_layernorm.weight\",\n", + " )\n", + " # Pre‑ and post‑feed forward norms\n", + " pre_key = f\"model.layers.{l}.pre_feedforward_layernorm.weight\"\n", + " post_key = f\"model.layers.{l}.post_feedforward_layernorm.weight\"\n", + " if pre_key in params:\n", + " block.pre_feedforward_layernorm.scale = assign(\n", + " block.pre_feedforward_layernorm.scale,\n", + " params[pre_key],\n", + " pre_key,\n", + " )\n", + " if post_key in params:\n", + " block.post_feedforward_layernorm.scale = assign(\n", + " block.post_feedforward_layernorm.scale,\n", + " params[post_key],\n", + " post_key,\n", + " )\n", + "\n", + " # Final LayerNorm\n", + " if \"model.norm.weight\" in params:\n", + " model.final_norm.scale = 
assign(\n", + " model.final_norm.scale,\n", + " params[\"model.norm.weight\"],\n", + " \"model.norm.weight\",\n", + " )\n", + " # Output head\n", + " if \"lm_head.weight\" in params:\n", + " model.out_head.weight = assign(\n", + " model.out_head.weight,\n", + " params[\"lm_head.weight\"],\n", + " \"lm_head.weight\",\n", + " )\n", + " elif \"model.embed_tokens.weight\" in params:\n", + " # Weight tying: reuse the embedding weights\n", + " model.out_head.weight = assign(\n", + " model.out_head.weight,\n", + " params[\"model.embed_tokens.weight\"],\n", + " \"model.embed_tokens.weight\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9d88b23d-fc3a-4903-b97e-8ac5160d7e7d", + "metadata": {}, + "outputs": [], + "source": [ + "class KVCache:\n", + " def __init__(self, n_layers):\n", + " self.cache = [None] * n_layers\n", + "\n", + " def get(self, layer_idx):\n", + " return self.cache[layer_idx]\n", + "\n", + " def update(self, layer_idx, value):\n", + " self.cache[layer_idx] = value\n", + "\n", + " def get_all(self):\n", + " return self.cache\n", + "\n", + " def reset(self):\n", + " for i in range(len(self.cache)):\n", + " self.cache[i] = None" + ] + }, + { + "cell_type": "markdown", + "id": "430340f2-78b9-4983-b74e-8395bbd7e574", + "metadata": {}, + "source": [ + "- Please note that Google requires that you accept the Gemma 3 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [google/gemma-3-270m](https://huggingface.co/google/gemma-3-270m) repository to accept the terms\n", + "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n", + "\n", + "\n", + "\n", + "\n", + "- Then, create and copy the access token so you can copy & paste it into the next code cell\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and run the following code if you are executing the notebook for the first time\n", + "\n", + "#from huggingface_hub import login\n", + "#login()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "9881b6995c3f49dc89e6992fd9ab660b", + "17a3174e65c54476b2e0d1faf8f011ca", + "1bbf2e62c0754d1593beb4105a7f1ac1", + "b82112e1dec645d98aa1c1ba64abcb61", + "271e2bd6a35e4a8b92de8697f7c0be5f", + "90a79523187446dfa692723b2e5833a7", + "431ffb83b8c14bf182f0430e07ea6154", + "a8f1b72a33dd4b548de23fbd95e0da18", + "25cc36132d384189acfbecc59483134b", + "bfd06423ad544218968648016e731a46", + "d029630b63ff44cf807ade428d2eb421" + ] + }, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" + }, + "outputs": [], + "source": [ + "import json\n", + "import os\n", + "from pathlib import Path\n", + "from safetensors.torch import load_file\n", + "from huggingface_hub import hf_hub_download, snapshot_download\n", + "\n", + "CHOOSE_MODEL = \"270m\"\n", + "\n", + "if USE_INSTRUCT_MODEL:\n", + " repo_id = f\"google/gemma-3-{CHOOSE_MODEL}-it\"\n", + "else:\n", + " repo_id = f\"google/gemma-3-{CHOOSE_MODEL}\"\n", + "\n", + "\n", + "local_dir = Path(repo_id).parts[-1]\n", + "\n", + "if CHOOSE_MODEL == \"270m\":\n", + " weights_file = hf_hub_download(\n",
+ " repo_id=repo_id,\n", + " filename=\"model.safetensors\",\n", + " local_dir=local_dir,\n", + " )\n", + " weights_dict = load_file(weights_file)\n", + "else:\n", + " repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n", + " index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n", + " with open(index_path, \"r\") as f:\n", + " index = json.load(f)\n", + "\n", + " weights_dict = {}\n", + " for filename in set(index[\"weight_map\"].values()):\n", + " shard_path = os.path.join(repo_dir, filename)\n", + " shard = load_file(shard_path)\n", + " weights_dict.update(shard)\n", + "\n", + "load_weights_into_gemma(model, GEMMA3_CONFIG_270M, weights_dict)\n", + "model.to(device)\n", + "del weights_dict" + ] + }, + { + "cell_type": "markdown", + "id": "6b345491-3510-4397-92d3-cd0a3fa3deee", + "metadata": {}, + "source": [ + " \n", + "# 4. Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "b68ab489-48e5-471e-a814-56cda2d60f81", + "metadata": {}, + "outputs": [], + "source": [ + "from tokenizers import Tokenizer\n", + "\n", + "\n", + "class GemmaTokenizer:\n", + " def __init__(self, tokenizer_file_path: str):\n", + " tok_file = Path(tokenizer_file_path)\n", + " self._tok = Tokenizer.from_file(str(tok_file))\n", + " # Attempt to identify EOS and padding tokens\n", + " eos_token = \"\"\n", + " self.pad_token_id = eos_token\n", + " self.eos_token_id = eos_token\n", + "\n", + " def encode(self, text: str) -> list[int]:\n", + " return self._tok.encode(text).ids\n", + "\n", + " def decode(self, ids: list[int]) -> str:\n", + " return self._tok.decode(ids, skip_special_tokens=False)\n", + "\n", + "\n", + "def apply_chat_template(user_text):\n", + " return f\"user\\n{user_text}\\nmodel\\n\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_file_path = os.path.join(local_dir, \"tokenizer.json\")\n", + "if not os.path.exists(tokenizer_file_path):\n", + " try:\n", + " tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename=\"tokenizer.json\", local_dir=local_dir)\n", + " except Exception as e:\n", + " print(f\"Warning: failed to download tokenizer.json: {e}\")\n", + " tokenizer_file_path = \"tokenizer.json\"\n", + "\n", + "tokenizer = GemmaTokenizer(tokenizer_file_path=tokenizer_file_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "1946b534-e3af-431a-a222-391a60bfa892", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'user\\nGive me a short introduction to large language models.\\nmodel\\n'" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = \"Give me a short introduction to large language models.\"\n", + "prompt = apply_chat_template(\"Give me a short introduction to large language models.\")\n", + "\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "text = tokenizer.decode(input_token_ids)\n", + "text" + ] + }, + { + "cell_type": "markdown", + "id": "57d07df1-4401-4792-b549-7c4cc5632323", + "metadata": { + "id": "57d07df1-4401-4792-b549-7c4cc5632323" + }, + "source": [ + " \n", + "# 5. 
Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "988f55e2-0f60-4bd8-ae55-db116ff2b26d", + "metadata": {}, + "outputs": [], + "source": [ + "# Optionally use torch.compile for an extra speed-up\n", + "# model = torch.compile(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", + "metadata": { + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" + }, + "outputs": [], + "source": [ + "def generate_text_basic_stream(model, token_ids, max_new_tokens, \n", + " eos_token_id=None):\n", + "\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for _ in range(max_new_tokens):\n", + " out = model(token_ids)[:, -1]\n", + " next_token = torch.argmax(out, dim=-1, keepdim=True)\n", + "\n", + " if (eos_token_id is not None\n", + " and torch.all(next_token == eos_token_id)):\n", + " break\n", + "\n", + " yield next_token # New: Yield each token as it's generated\n", + " \n", + " token_ids = torch.cat([token_ids, next_token], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "56c9d0cf-25e9-4375-8d5c-368fa6911fdf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n" + ] + } + ], + "source": [ + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=150,\n", + " eos_token_id=tokenizer.encode(\"\")[-1]\n", + "):\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "549324d6-5c71-4147-ae21-2e67675faa3d", + "metadata": { + "id": "549324d6-5c71-4147-ae21-2e67675faa3d" + }, + "source": [ + " \n", + "# What's next?" 
+ ] + }, + { + "cell_type": "markdown", + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", + "metadata": { + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" + }, + "source": [ + "- Check out the [README.md](./README.md), to use this model via the `llms_from_scratch` package\n", + "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/12_gemma3/standalone-gemma3.ipynb b/ch05/12_gemma3/standalone-gemma3.ipynb index ec1ebe7..b3d0941 100644 --- a/ch05/12_gemma3/standalone-gemma3.ipynb +++ b/ch05/12_gemma3/standalone-gemma3.ipynb @@ -78,9 +78,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "huggingface_hub version: 0.33.2\n", - "tokenizers version: 0.21.2\n", - "torch version: 2.6.0\n" + "huggingface_hub version: 0.34.4\n", + "tokenizers version: 0.21.4\n", + "torch version: 2.8.0\n" ] } ], @@ -235,7 +235,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 14, "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", "metadata": { "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" @@ -320,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 15, "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", "metadata": { "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" @@ -329,7 +329,7 @@ "source": [ "class TransformerBlock(nn.Module):\n", "\n", - " def __init__(self, cfg: dict, attn_type: str):\n", + " def __init__(self, cfg, attn_type):\n", " super().__init__()\n", " self.attn_type = attn_type \n", "\n", @@ -386,7 +386,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 16, "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", "metadata": { "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" @@ -507,7 +507,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 17, "id": "caa142fa-b375-4e78-b392-2072ced666f3", "metadata": { "id": "caa142fa-b375-4e78-b392-2072ced666f3" @@ -554,7 +554,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 18, "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", "metadata": { "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" @@ -567,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 19, "id": "eaf86265-4e9d-4024-9ed0-99076944e304", "metadata": {}, "outputs": [ @@ -602,7 +602,7 @@ ")" ] }, - "execution_count": 12, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -621,7 +621,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 20, "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", "metadata": {}, "outputs": [ @@ -634,7 +634,7 @@ " dtype=torch.bfloat16, grad_fn=)" ] }, - "execution_count": 13, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -645,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 21, "id": 
"364e76ca-52f8-4fa5-af37-c4069f9694bc", "metadata": { "colab": { @@ -676,7 +676,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 22, "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "metadata": { "colab": { @@ -726,7 +726,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 23, "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", "metadata": { "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" @@ -756,7 +756,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 24, "id": "75166128-5899-4995-9b88-9672e135650e", "metadata": { "id": "75166128-5899-4995-9b88-9672e135650e" @@ -900,7 +900,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 25, "id": "7cee5292-f756-41dd-9b8d-c9b5c25d23f8", "metadata": {}, "outputs": [], @@ -913,7 +913,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 26, "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "metadata": { "colab": { @@ -989,7 +989,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 27, "id": "b68ab489-48e5-471e-a814-56cda2d60f81", "metadata": {}, "outputs": [], @@ -1019,7 +1019,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 28, "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", "metadata": {}, "outputs": [], @@ -1037,7 +1037,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 29, "id": "1946b534-e3af-431a-a222-391a60bfa892", "metadata": {}, "outputs": [ @@ -1047,7 +1047,7 @@ "'user\\nGive me a short introduction to large language models.\\nmodel\\n'" ] }, - "execution_count": 22, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1075,7 +1075,18 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, + "id": "a6250333-9cf0-4f36-8e28-76be2eac1c43", + "metadata": {}, + "outputs": [], + "source": [ + "# Optionally use torch.compile for an extra speed-up\n", + "# model = torch.compile(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": { "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" @@ -1101,7 +1112,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 31, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": { "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" @@ -1111,7 +1122,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within that data, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n" + "Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n" ] } ], diff --git a/ch05/12_gemma3/tests/test_gemma3_kv_nb.py b/ch05/12_gemma3/tests/test_gemma3_kv_nb.py new file mode 100644 index 0000000..ca2f857 --- /dev/null +++ b/ch05/12_gemma3/tests/test_gemma3_kv_nb.py @@ -0,0 +1,113 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
+# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +from pathlib import Path + +import pytest +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +transformers_installed = importlib.util.find_spec("transformers") is not None + + +@pytest.fixture +def nb_imports(): + nb_dir = Path(__file__).resolve().parents[1] + mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3-plus-kvcache.ipynb") + return mod + + +@pytest.fixture +def dummy_input(): + torch.manual_seed(123) + return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8 + + +@pytest.fixture +def dummy_cfg_base(): + return { + "vocab_size": 100, + "emb_dim": 32, + "hidden_dim": 64, + "n_layers": 2, + "n_heads": 4, + "head_dim": 8, + "n_kv_groups": 1, + "qk_norm": True, # Gemma3 uses q/k RMSNorm + "dtype": torch.float32, + "rope_base": 1_000_000.0, # global RoPE base + "rope_local_base": 10_000.0, # local RoPE base (unused in these tests) + "context_length": 64, + "sliding_window": 16, + "layer_types": ["full_attention", "full_attention"], + "query_pre_attn_scalar": 256, + } + + +@torch.inference_mode() +def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports): + torch.manual_seed(123) + model = nb_imports.Gemma3Model(dummy_cfg_base) + out = model(dummy_input) + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]) + + +@torch.inference_mode() +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_gemma3_base_equivalence_with_transformers(nb_imports): + from transformers import Gemma3TextConfig, Gemma3ForCausalLM + + # Tiny config so the test is fast + cfg = { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_groups": 2, + "rope_base": 1_000_000.0, + "rope_local_base": 10_000.0, + "sliding_window": 4, + "layer_types": ["full_attention", "full_attention"], + "dtype": torch.float32, + "query_pre_attn_scalar": 256, + } + model = nb_imports.Gemma3Model(cfg) + + hf_cfg = Gemma3TextConfig( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_groups"], + rope_theta=cfg["rope_base"], + rope_local_base_freq=cfg["rope_local_base"], + layer_types=cfg["layer_types"], + sliding_window=cfg["sliding_window"], + tie_word_embeddings=False, + attn_implementation="eager", + torch_dtype=torch.float32, + query_pre_attn_scalar=cfg["query_pre_attn_scalar"], + rope_scaling={"rope_type": "default"}, + ) + hf_model = Gemma3ForCausalLM(hf_cfg) + + hf_state = hf_model.state_dict() + param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]} + nb_imports.load_weights_into_gemma(model, param_config, hf_state) + + x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long) + ours_logits = model(x) + theirs_logits = hf_model(x).logits + torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5)
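+
+
+# Sketch of one possible additional check (an assumption, not part of this PR;
+# it presumes import_definitions_from_notebook also exposes the notebook's
+# KVCache class): incremental decoding through the cache should reproduce the
+# logits of a single full forward pass on the full-attention dummy config.
+@torch.inference_mode()
+def test_dummy_gemma3_kv_cache_matches_full_forward(dummy_cfg_base, dummy_input, nb_imports):
+    torch.manual_seed(123)
+    model = nb_imports.Gemma3Model(dummy_cfg_base)
+    model.eval()
+
+    full_logits = model(dummy_input)  # all 8 positions at once, no cache
+
+    cache = nb_imports.KVCache(n_layers=dummy_cfg_base["n_layers"])
+    model.reset_kv_cache()
+    step_logits = [
+        model(dummy_input[:, t:t + 1], cache=cache)  # one token per step
+        for t in range(dummy_input.size(1))
+    ]
+    cached_logits = torch.cat(step_logits, dim=1)
+
+    torch.testing.assert_close(full_logits, cached_logits, rtol=1e-4, atol=1e-4)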