Readability and code quality improvements (#959)

* Consistent dataset naming

* consistent section headers
This commit is contained in:
Sebastian Raschka
2026-02-17 19:44:56 -05:00
committed by GitHub
parent 7b1f740f74
commit be5e2a3331
48 changed files with 419 additions and 297 deletions

3
.gitignore vendored
View File

@@ -1,3 +1,6 @@
# Reports
reports/
# Configs and keys
.chainlit
ch05/07_gpt_to_llama/config.json

View File

@@ -79,6 +79,7 @@
"id": "2417139b-2357-44d2-bd67-23f5d7f52ae7",
"metadata": {},
"source": [
" \n",
"## 2.1 Understanding word embeddings"
]
},
@@ -128,6 +129,7 @@
"id": "eddbb984-8d23-40c5-bbfa-c3c379e7eec3",
"metadata": {},
"source": [
" \n",
"## 2.2 Tokenizing text"
]
},
@@ -445,6 +447,7 @@
"id": "0b5ce8fe-3a07-4f2a-90f1-a0321ce3a231",
"metadata": {},
"source": [
" \n",
"## 2.3 Converting tokens into token IDs"
]
},
@@ -738,6 +741,7 @@
"id": "4b821ef8-4d53-43b6-a2b2-aef808c343c7",
"metadata": {},
"source": [
" \n",
"## 2.4 Adding special context tokens"
]
},
@@ -1013,6 +1017,7 @@
"id": "5c4ba34b-170f-4e71-939b-77aabb776f14",
"metadata": {},
"source": [
" \n",
"## 2.5 BytePair encoding"
]
},
@@ -1528,6 +1533,7 @@
"id": "2cd2fcda-2fda-4aa8-8bc8-de1e496f9db1",
"metadata": {},
"source": [
" \n",
"## 2.7 Creating token embeddings"
]
},
@@ -1715,6 +1721,7 @@
"id": "c393d270-b950-4bc8-99ea-97d74f2ea0f6",
"metadata": {},
"source": [
" \n",
"## 2.8 Encoding word positions"
]
},
@@ -1945,7 +1952,8 @@
"id": "63230f2e-258f-4497-9e2e-8deee4530364",
"metadata": {},
"source": [
"# Summary and takeaways"
" \n",
"## Summary and takeaways"
]
},
{
@@ -1977,7 +1985,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -63,7 +63,8 @@
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
"metadata": {},
"source": [
"# Exercise 2.1"
" \n",
"## Exercise 2.1"
]
},
{
@@ -273,7 +274,8 @@
"id": "29e5034a-95ed-46d8-9972-589354dc9fd4",
"metadata": {},
"source": [
"# Exercise 2.2"
" \n",
"## Exercise 2.2"
]
},
{
@@ -407,7 +409,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -54,7 +54,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## Using BPE from `tiktoken`"
"## 1. Using BPE from `tiktoken`"
]
},
{
@@ -157,7 +157,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## Using the original BPE implementation used in GPT-2"
"## 2. Using the original BPE implementation used in GPT-2"
]
},
{
@@ -247,7 +247,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## Using the BPE via Hugging Face transformers"
"## 3. Using the BPE via Hugging Face transformers"
]
},
{
@@ -355,7 +355,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## Using my own from-scratch BPE tokenizer"
"## 4. Using my own from-scratch BPE tokenizer"
]
},
{
@@ -449,7 +449,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## A quick performance benchmark"
"## 5. A quick performance benchmark"
]
},
{
@@ -468,7 +468,8 @@
"id": "9c0ae9f0-47a1-4e7f-a210-e1d2721f4d1e",
"metadata": {},
"source": [
"### Original OpenAI GPT-2 tokenizer"
"&nbsp;\n",
"### 5.1 Original OpenAI GPT-2 tokenizer"
]
},
{
@@ -494,7 +495,8 @@
"id": "ef2ce3f3-1f81-47ce-b563-99fe2c7a1e90",
"metadata": {},
"source": [
"### Tiktoken OpenAI GPT-2 tokenizer"
"&nbsp;\n",
"### 5.2 Tiktoken OpenAI GPT-2 tokenizer"
]
},
{
@@ -520,7 +522,8 @@
"id": "0c748de8-273e-42df-b078-3a510106da60",
"metadata": {},
"source": [
"### Hugging Face OpenAI GPT-2 tokenizer"
"&nbsp;\n",
"### 5.3 Hugging Face OpenAI GPT-2 tokenizer"
]
},
{
@@ -614,7 +617,8 @@
"id": "91ac2876-f36e-498c-bd75-8597a39f2d4b",
"metadata": {},
"source": [
"### My own GPT-2 tokenizer (for educational purposes)"
"&nbsp;\n",
"### 5.4 My own GPT-2 tokenizer (for educational purposes)"
]
},
{
@@ -652,7 +656,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -85,6 +85,7 @@
"id": "ecc4dcee-34ea-4c05-9085-2f8887f70363",
"metadata": {},
"source": [
"&nbsp;\n",
"## 3.1 The problem with modeling long sequences"
]
},
@@ -127,6 +128,7 @@
"id": "3602c585-b87a-41c7-a324-c5e8298849df",
"metadata": {},
"source": [
"&nbsp;\n",
"## 3.2 Capturing data dependencies with attention mechanisms"
]
},
@@ -168,6 +170,7 @@
"id": "5efe05ff-b441-408e-8d66-cde4eb3397e3",
"metadata": {},
"source": [
"&nbsp;\n",
"## 3.3 Attending to different parts of the input with self-attention"
]
},
@@ -176,6 +179,7 @@
"id": "6d9af516-7c37-4400-ab53-34936d5495a9",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.3.1 A simple self-attention mechanism without trainable weights"
]
},
@@ -216,7 +220,7 @@
"id": "ff856c58-8382-44c7-827f-798040e6e697",
"metadata": {},
"source": [
"- By convention, the unnormalized attention weights are referred to as **\"attention scores\"** whereas the normalized attention scores, which sum to 1, are referred to as **\"attention weights\"**\n"
"- By convention, the unnormalized attention weights are referred to as **\"attention scores\"** whereas the normalized attention scores, which sum to 1, are referred to as **\"attention weights\"**"
]
},
{
@@ -503,6 +507,7 @@
"id": "5a454262-40eb-430e-9ca4-e43fb8d6cd89",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.3.2 Computing attention weights for all input tokens"
]
},
@@ -739,6 +744,7 @@
"id": "a303b6fb-9f7e-42bb-9fdb-2adabf0a6525",
"metadata": {},
"source": [
"&nbsp;\n",
"## 3.4 Implementing self-attention with trainable weights"
]
},
@@ -763,6 +769,7 @@
"id": "2b90a77e-d746-4704-9354-1ddad86e6298",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.4.1 Computing the attention weights step by step"
]
},
@@ -1046,6 +1053,7 @@
"id": "9d7b2907-e448-473e-b46c-77735a7281d8",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.4.2 Implementing a compact SelfAttention class"
]
},
@@ -1179,6 +1187,7 @@
"id": "c5025b37-0f2c-4a67-a7cb-1286af7026ab",
"metadata": {},
"source": [
"&nbsp;\n",
"## 3.5 Hiding future words with causal attention"
]
},
@@ -1203,6 +1212,7 @@
"id": "82f405de-cd86-4e72-8f3c-9ea0354946ba",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.5.1 Applying a causal attention mask"
]
},
@@ -1455,6 +1465,7 @@
"id": "7636fc5f-6bc6-461e-ac6a-99ec8e3c0912",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.5.2 Masking additional attention weights with dropout"
]
},
@@ -1554,6 +1565,7 @@
"id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.5.3 Implementing a compact causal self-attention class"
]
},
@@ -1679,6 +1691,7 @@
"id": "c8bef90f-cfd4-4289-b0e8-6a00dc9be44c",
"metadata": {},
"source": [
"&nbsp;\n",
"## 3.6 Extending single-head attention to multi-head attention"
]
},
@@ -1687,6 +1700,7 @@
"id": "11697757-9198-4a1c-9cee-f450d8bbd3b9",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.6.1 Stacking multiple single-head attention layers"
]
},
@@ -1776,6 +1790,7 @@
"id": "6836b5da-ef82-4b4c-bda1-72a462e48d4e",
"metadata": {},
"source": [
"&nbsp;\n",
"### 3.6.2 Implementing multi-head attention with weight splits"
]
},
@@ -2032,7 +2047,8 @@
"id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
"metadata": {},
"source": [
"# Summary and takeaways"
"&nbsp;\n",
"## Summary and takeaways"
]
},
{
@@ -2061,7 +2077,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -54,7 +54,8 @@
"id": "33dfa199-9aee-41d4-a64b-7e3811b9a616",
"metadata": {},
"source": [
"# Exercise 3.1"
"&nbsp;\n",
"## Exercise 3.1"
]
},
{
@@ -209,7 +210,8 @@
"id": "33543edb-46b5-4b01-8704-f7f101230544",
"metadata": {},
"source": [
"# Exercise 3.2"
"&nbsp;\n",
"## Exercise 3.2"
]
},
{
@@ -266,7 +268,8 @@
"id": "92bdabcb-06cf-4576-b810-d883bbd313ba",
"metadata": {},
"source": [
"# Exercise 3.3"
"&nbsp;\n",
"## Exercise 3.3"
]
},
{
@@ -339,7 +342,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -117,7 +117,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 1) CausalAttention MHA wrapper class from chapter 3"
"## 1. CausalAttention MHA wrapper class from chapter 3"
]
},
{
@@ -208,7 +208,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 2) The multi-head attention class from chapter 3"
"## 2. The multi-head attention class from chapter 3"
]
},
{
@@ -311,7 +311,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 3) An alternative multi-head attention with combined weights"
"## 3. An alternative multi-head attention with combined weights"
]
},
{
@@ -435,7 +435,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 4) Multi-head attention with Einsum\n",
"## 4. Multi-head attention with Einsum\n",
"\n",
"- Implementing multi-head attention using Einstein summation via [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html)"
]
@@ -567,7 +567,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 5) Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
"## 5. Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
]
},
{
@@ -676,7 +676,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
"## 6. PyTorch's scaled dot product attention without FlashAttention\n",
"\n",
"- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
]
@@ -785,7 +785,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 7) Using PyTorch's torch.nn.MultiheadAttention"
"## 7. Using PyTorch's torch.nn.MultiheadAttention"
]
},
{
@@ -883,7 +883,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 8) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
"## 8. Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
]
},
{
@@ -948,7 +948,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## 9) Using PyTorch's FlexAttention\n",
"## 9. Using PyTorch's FlexAttention\n",
"\n",
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
"- FlexAttention caveat: It currently doesn't support dropout\n",
@@ -1108,7 +1108,18 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## Quick speed comparison (M3 Macbook Air CPU)"
"## 10. Quick speed comparisons"
]
},
{
"cell_type": "markdown",
"id": "992e28f4-a6b9-4dd3-9705-30d0b9f4b5f0",
"metadata": {},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"### 10.1 Speed comparisons on M3 Macbook Air CPU"
]
},
{
@@ -1361,7 +1372,7 @@
"<br>\n",
"&nbsp;\n",
"\n",
"## Quick speed comparison (Nvidia A100 GPU)"
"### 10.2 Quick speed comparison on Nvidia A100 GPU"
]
},
{
@@ -1643,7 +1654,18 @@
"&nbsp;\n",
"\n",
"\n",
"# Visualizations"
"## 11. Visualizations"
]
},
{
"cell_type": "markdown",
"id": "e6baf5ce-45ac-4e26-9523-5c32b82dc784",
"metadata": {},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"### 11.1 Visualization utility functions"
]
},
{
@@ -1752,7 +1774,8 @@
"id": "4df834dc"
},
"source": [
"## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
"&nbsp;\n",
"### 11.2 Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
]
},
{
@@ -1834,7 +1857,7 @@
"&nbsp;\n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
"### 11.3 Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
]
},
{
@@ -1920,7 +1943,7 @@
"&nbsp;\n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
"### 11.4 Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
]
},
{

View File

@@ -7,7 +7,7 @@ from llms_from_scratch.utils import import_definitions_from_notebook
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "mha-implementations.ipynb")
return mod
@@ -31,12 +31,12 @@ def copy_weights(from_mha, to_mha):
(1024, 512, 2, 4, 8, 789), # d_in > d_out
],
)
def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, nb_imports):
def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, import_notebook_defs):
torch.manual_seed(seed)
x = torch.randn(batch, seq_len, d_in)
mha_linear = nb_imports.Ch03_MHA(
mha_linear = import_notebook_defs.Ch03_MHA(
d_in=d_in,
d_out=d_out,
context_length=seq_len,
@@ -45,7 +45,7 @@ def test_mha_einsum_matches_ch03(d_in, d_out, batch, seq_len, num_heads, seed, n
qkv_bias=False,
).eval()
mha_einsum = nb_imports.MHAEinsum(
mha_einsum = import_notebook_defs.MHAEinsum(
d_in=d_in,
d_out=d_out,
context_length=seq_len,

View File

@@ -73,6 +73,7 @@
"id": "53fe99ab-0bcf-4778-a6b5-6db81fb826ef",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.1 Coding an LLM architecture"
]
},
@@ -323,6 +324,7 @@
"id": "f8332a00-98da-4eb4-b882-922776a89917",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.2 Normalizing activations with layer normalization"
]
},
@@ -606,6 +608,7 @@
"id": "11190e7d-8c29-4115-824a-e03702f9dd54",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.3 Implementing a feed forward network with GELU activations"
]
},
@@ -789,6 +792,7 @@
"id": "4ffcb905-53c7-4886-87d2-4464c5fecf89",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.4 Adding shortcut connections"
]
},
@@ -950,6 +954,7 @@
"id": "cae578ca-e564-42cf-8635-a2267047cdff",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.5 Connecting attention and linear layers in a transformer block"
]
},
@@ -1068,6 +1073,7 @@
"id": "46618527-15ac-4c32-ad85-6cfea83e006e",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.6 Coding the GPT model"
]
},
@@ -1332,6 +1338,7 @@
"id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
"metadata": {},
"source": [
"&nbsp;\n",
"## 4.7 Generating text"
]
},
@@ -1519,11 +1526,20 @@
"id": "a35278b6-9e5c-480f-83e5-011a1173648f",
"metadata": {},
"source": [
"&nbsp;\n",
"## Summary and takeaways\n",
"\n",
"- See the [./gpt.py](./gpt.py) script, a self-contained script containing the GPT model we implement in this Jupyter notebook\n",
"- You can find the exercise solutions in [./exercise-solutions.ipynb](./exercise-solutions.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4821ac83-ef84-42c4-a327-32bf2820a8e5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -53,7 +53,8 @@
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
"metadata": {},
"source": [
"# Exercise 4.1: Parameters in the feed forward versus attention module"
"&nbsp;\n",
"## Exercise 4.1: Parameters in the feed forward versus attention module"
]
},
{
@@ -182,7 +183,8 @@
"id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d",
"metadata": {},
"source": [
"# Exercise 4.2: Initialize larger GPT models"
"&nbsp;\n",
"## Exercise 4.2: Initialize larger GPT models"
]
},
{
@@ -329,7 +331,8 @@
"id": "f5f2306e-5dc8-498e-92ee-70ae7ec37ac1",
"metadata": {},
"source": [
"# Exercise 4.3: Using separate dropout parameters"
"&nbsp;\n",
"## Exercise 4.3: Using separate dropout parameters"
]
},
{
@@ -451,7 +454,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -18,13 +18,13 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
def kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
head_dim = math.ceil(emb_dim / n_heads)
per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
return per_layer * n_layers
@@ -58,7 +58,7 @@ def main():
n_kv_heads_mha = cfg["n_heads"]
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
total_mha = kv_bytes_total(
total_mha = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -68,7 +68,7 @@ def main():
bytes_per_elem,
)
total_gqa = kv_bytes_total(
total_gqa = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -91,8 +91,8 @@ def main():
print()
print("==== KV-cache totals across all layers ====")
print(f"MHA total KV cache : {bytes_convert(total_mha)}")
print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
print(f"MHA total KV cache : {convert_bytes(total_mha)}")
print(f"GQA total KV cache : {convert_bytes(total_gqa)}")
print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")

View File

@@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
# Import from ./memory_estimator.py
from memory_estimator_gqa import kv_bytes_total, DTYPE_BYTES
from memory_estimator_gqa import calc_kv_bytes_total, DTYPE_BYTES
def bytes_convert(n):
@@ -36,7 +36,7 @@ def plot_abs_kv_vs_context_multi_groups():
mha_gb = []
for L in context_lengths:
total_mha = kv_bytes_total(
total_mha = calc_kv_bytes_total(
batch_size, L, emb_dim, n_heads,
n_heads, # MHA: n_kv_heads = n_heads
n_layers, bytes_per_elem
@@ -52,7 +52,7 @@ def plot_abs_kv_vs_context_multi_groups():
n_kv_heads = n_heads // g
gqa_gb = []
for L in context_lengths:
total_gqa = kv_bytes_total(
total_gqa = calc_kv_bytes_total(
batch_size, L, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem
)

View File

@@ -17,20 +17,20 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
def kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
def calc_kv_bytes_total(batch, context_length, emb_dim, n_heads,
n_kv_heads, n_layers, bytes_per_elem):
# Generic KV-cache: per-head dim is embed_dim / n_heads, times 2 for K and V
head_dim = math.ceil(emb_dim / n_heads)
per_layer = batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
return per_layer * n_layers
def mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem):
def calc_mla_bytes_total(batch, context_length, n_layers, latent_dim, bytes_per_elem):
# Simple MLA (per-token compressed latent)
# bytes ≈ batch × seqlen × n_layers × latent_dim × bytes_per_elem
return batch * context_length * n_layers * latent_dim * bytes_per_elem
@@ -66,7 +66,7 @@ def main():
n_kv_heads_mha = cfg["n_heads"]
n_kv_heads_gqa = cfg["n_heads"] // cfg["n_kv_groups"]
total_mha = kv_bytes_total(
total_mha = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -76,7 +76,7 @@ def main():
bytes_per_elem,
)
total_gqa = kv_bytes_total(
total_gqa = calc_kv_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["emb_dim"],
@@ -86,7 +86,7 @@ def main():
bytes_per_elem,
)
total_mla = mla_bytes_total(
total_mla = calc_mla_bytes_total(
args.batch_size,
cfg["context_length"],
cfg["n_layers"],
@@ -110,9 +110,9 @@ def main():
print()
print("==== KV-cache totals across all layers ====")
print(f"MHA total KV cache : {bytes_convert(total_mha)}")
print(f"GQA total KV cache : {bytes_convert(total_gqa)}")
print(f"MLA total KV cache : {bytes_convert(total_mla)}")
print(f"MHA total KV cache : {convert_bytes(total_mha)}")
print(f"GQA total KV cache : {convert_bytes(total_gqa)}")
print(f"MLA total KV cache : {convert_bytes(total_mla)}")
print(f"Ratio (MHA / GQA) : {ratio:,.2f}x")
print(f"Savings (GQA vs MHA): {savings*100:,.2f}%")
print(f"Ratio (MHA / MLA) : {ratio_mha_mla:,.2f}x")

View File

@@ -15,18 +15,18 @@ DTYPE_BYTES = {
}
def bytes_to_gb(n_bytes):
def convert_bytes_to_gb(n_bytes):
return n_bytes / (1000. ** 3)
def kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
n_layers, bytes_per_elem):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_heads,
n_layers, bytes_per_elem):
head_dim = emb_dim / n_heads
per_layer = batch * context_length * head_dim * n_heads * 2 * bytes_per_elem
return per_layer * n_layers
def kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
def calc_kv_bytes_total_mla(batch, context_length, n_layers, latent_dim, bytes_per_elem):
return batch * context_length * n_layers * latent_dim * bytes_per_elem
@@ -45,27 +45,27 @@ def plot_abs_kv_vs_context_multiple():
mha_gb = []
for L in context_lengths:
total_mha = kv_bytes_total_mha(
total_mha = calc_kv_bytes_total_mha(
batch_size, L, emb_dim, n_heads, n_layers, bytes_per_elem
)
mha_gb.append(bytes_to_gb(total_mha))
mha_gb.append(convert_bytes_to_gb(total_mha))
latent_dims = [1024, 512, 256, 64]
plt.figure()
plt.plot(context_lengths, mha_gb, marker="o", label="MHA (KV total)")
L_ref = context_lengths[-1]
total_mha_ref = kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
total_mha_ref = calc_kv_bytes_total_mha(batch_size, L_ref, emb_dim, n_heads, n_layers, bytes_per_elem)
for latent_dim in latent_dims:
mla_gb = []
for L in context_lengths:
total_mla = kv_bytes_total_mla(
total_mla = calc_kv_bytes_total_mla(
batch_size, L, n_layers, latent_dim, bytes_per_elem
)
mla_gb.append(bytes_to_gb(total_mla))
mla_gb.append(convert_bytes_to_gb(total_mla))
total_mla_ref = kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
total_mla_ref = calc_kv_bytes_total_mla(batch_size, L_ref, n_layers, latent_dim, bytes_per_elem)
comp = total_mha_ref / total_mla_ref if total_mla_ref != 0 else float("inf")
plt.plot(context_lengths, mla_gb, marker="o",

View File

@@ -17,12 +17,12 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
def kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem):
def calc_kv_bytes_per_layer(batch, context_length, head_dim, n_kv_heads, bytes_per_elem):
# KV = batch * tokens * head_dim * n_kv_heads * 2 (K,V) * bytes
return batch * context_length * head_dim * n_kv_heads * 2 * bytes_per_elem
@@ -64,10 +64,10 @@ def estimate_totals(context_length, sliding_window_size, emb_dim, n_heads, n_lay
L = context_length
# Per-layer costs
per_mha_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_full = kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem)
per_mha_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_swa = kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem)
per_mha_full = calc_kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_full = calc_kv_bytes_per_layer(batch_size, L, head_dim, n_kv_heads_gqa, bytes_per_elem)
per_mha_swa = calc_kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_mha, bytes_per_elem)
per_gqa_swa = calc_kv_bytes_per_layer(batch_size, eff_W, head_dim, n_kv_heads_gqa, bytes_per_elem)
# Totals
total_mha_allfull = per_mha_full * n_layers
@@ -140,10 +140,10 @@ def main():
print()
print("==== KV-cache totals across all layers ====")
print(f"MHA KV total : {bytes_convert(res['total_mha_allfull'])}")
print(f"GQA KV total : {bytes_convert(res['total_gqa_allfull'])}")
print(f"MHA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_mha'])}")
print(f"GQA + SWA (ratio {args.swa_ratio}) : {bytes_convert(res['total_mixed_gqa'])}")
print(f"MHA KV total : {convert_bytes(res['total_mha_allfull'])}")
print(f"GQA KV total : {convert_bytes(res['total_gqa_allfull'])}")
print(f"MHA + SWA (ratio {args.swa_ratio}) : {convert_bytes(res['total_mixed_mha'])}")
print(f"GQA + SWA (ratio {args.swa_ratio}) : {convert_bytes(res['total_mixed_gqa'])}")
print()

View File

@@ -24,7 +24,7 @@ DTYPE_BYTES = {
}
def bytes_to_gb(n_bytes):
def convert_bytes_to_gb(n_bytes):
return n_bytes / (1000.0 ** 3)
@@ -39,22 +39,22 @@ def parse_ratio(ratio_str):
raise ValueError("--swa_ratio must be in the form 'a:b' with nonnegative integers and a+b>0")
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem):
# For MHA, n_kv_heads = n_heads, which cancels out:
# total = B * L * E * 2 (K,V) * bytes * n_layers
return batch * context_length * emb_dim * 2 * bytes_per_elem * n_layers
def kv_bytes_total_gqa(
def calc_kv_bytes_total_gqa(
batch, context_length, emb_dim, n_layers, bytes_per_elem, n_kv_groups
):
# For GQA, n_kv_heads = n_heads / n_kv_groups
# => scale the MHA total by 1 / n_kv_groups
base = kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem)
base = calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem)
return base / n_kv_groups
def kv_bytes_total_mha_swa(
def calc_kv_bytes_total_mha_swa(
batch, context_length, emb_dim, n_layers, bytes_per_elem, window, swa_ratio
):
# Split layers into SWA vs Full
@@ -63,16 +63,16 @@ def kv_bytes_total_mha_swa(
n_swa_layers = int(round(n_layers * (a / total_blocks)))
n_full_layers = n_layers - n_swa_layers
total_full = kv_bytes_total_mha(
total_full = calc_kv_bytes_total_mha(
batch, context_length, emb_dim, n_full_layers, bytes_per_elem
)
total_swa = kv_bytes_total_mha(
total_swa = calc_kv_bytes_total_mha(
batch, window, emb_dim, n_swa_layers, bytes_per_elem
)
return total_full + total_swa
def kv_bytes_total_gqa_swa(
def calc_kv_bytes_total_gqa_swa(
batch,
context_length,
emb_dim,
@@ -87,7 +87,7 @@ def kv_bytes_total_gqa_swa(
n_swa_layers = int(round(n_layers * (a / total_blocks)))
n_full_layers = n_layers - n_swa_layers
total_full = kv_bytes_total_gqa(
total_full = calc_kv_bytes_total_gqa(
batch,
context_length,
emb_dim,
@@ -95,7 +95,7 @@ def kv_bytes_total_gqa_swa(
bytes_per_elem,
n_kv_groups,
)
total_swa = kv_bytes_total_gqa(
total_swa = calc_kv_bytes_total_gqa(
batch, window, emb_dim, n_swa_layers, bytes_per_elem, n_kv_groups
)
return total_full + total_swa
@@ -144,10 +144,10 @@ def main():
] = []
for L in context_lengths:
total_mha = kv_bytes_total_mha(
total_mha = calc_kv_bytes_total_mha(
batch_size, L, emb_dim, n_layers, bytes_per_elem
)
total_mha_swa = kv_bytes_total_mha_swa(
total_mha_swa = calc_kv_bytes_total_mha_swa(
batch_size,
L,
emb_dim,
@@ -156,16 +156,16 @@ def main():
window=args.sliding_window_size,
swa_ratio=args.swa_ratio,
)
series["MHA (KV total)"].append(bytes_to_gb(total_mha))
series["MHA (KV total)"].append(convert_bytes_to_gb(total_mha))
series[
f"SWA on MHA (ratio {args.swa_ratio}, W={args.sliding_window_size})"
].append(bytes_to_gb(total_mha_swa))
].append(convert_bytes_to_gb(total_mha_swa))
if valid_g4:
total_gqa = kv_bytes_total_gqa(
total_gqa = calc_kv_bytes_total_gqa(
batch_size, L, emb_dim, n_layers, bytes_per_elem, n_kv_groups=kv_groups
)
total_gqa_swa = kv_bytes_total_gqa_swa(
total_gqa_swa = calc_kv_bytes_total_gqa_swa(
batch_size,
L,
emb_dim,
@@ -175,10 +175,10 @@ def main():
window=args.sliding_window_size,
swa_ratio=args.swa_ratio,
)
series["GQA kv_groups=4 (full)"].append(bytes_to_gb(total_gqa))
series["GQA kv_groups=4 (full)"].append(convert_bytes_to_gb(total_gqa))
series[
f"SWA on GQA kv_groups=4 (ratio {args.swa_ratio}, W={args.sliding_window_size})"
].append(bytes_to_gb(total_gqa_swa))
].append(convert_bytes_to_gb(total_gqa_swa))
plt.figure(figsize=(10, 5))
x = np.array(context_lengths, dtype=float)

View File

@@ -14,7 +14,7 @@ DTYPE_BYTES = {
}
def bytes_convert(n):
def convert_bytes(n):
gb = n / (1000 ** 3)
return f"{gb:,.2f} GB"
@@ -28,19 +28,19 @@ def get_num_param_matrices(ffn_type):
raise ValueError("--ffn_type must be 'gelu' or 'swiglu'")
def ffn_params(emb_dim, hidden_dim, ffn_type):
def calc_ffn_params(emb_dim, hidden_dim, ffn_type):
return get_num_param_matrices(ffn_type) * emb_dim * hidden_dim
def router_params(emb_dim, num_experts):
def calc_router_params(emb_dim, num_experts):
return emb_dim * num_experts
def estimate_params_and_hidden(
emb_dim, hidden_dim, ffn_type, num_experts, match_dense=False
):
P_dense = ffn_params(emb_dim, hidden_dim, ffn_type)
R = router_params(emb_dim, num_experts)
P_dense = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
R = calc_router_params(emb_dim, num_experts)
if match_dense:
num_param_matrices = get_num_param_matrices(ffn_type)
@@ -52,7 +52,7 @@ def estimate_params_and_hidden(
else:
moe_hidden_dim = hidden_dim
per_expert_params = ffn_params(emb_dim, moe_hidden_dim, ffn_type)
per_expert_params = calc_ffn_params(emb_dim, moe_hidden_dim, ffn_type)
moe_total = num_experts * per_expert_params + R
return {
@@ -110,15 +110,15 @@ def main():
print("==== Model weights (parameters) ====")
print(f"{'Dense FFN params':23}: {res['dense_params']:,} "
f"({bytes_convert(res['dense_params'] * bytes_per_elem)})")
f"({convert_bytes(res['dense_params'] * bytes_per_elem)})")
print(f"{'Per-expert params':23}: {res['per_expert_params']:,} "
f"({bytes_convert(res['per_expert_params'] * bytes_per_elem)})")
f"({convert_bytes(res['per_expert_params'] * bytes_per_elem)})")
print(f"{'Router params':23}: {res['router']:,} "
f"({bytes_convert(res['router'] * bytes_per_elem)})")
f"({convert_bytes(res['router'] * bytes_per_elem)})")
print(f"{'MoE TOTAL params':23}: {res['moe_total']:,} "
f"({bytes_convert(res['moe_total'] * bytes_per_elem)})")
f"({convert_bytes(res['moe_total'] * bytes_per_elem)})")
print(f"{'MoE ACTIVE/Token':23}: {moe_active_params_per_token:,} "
f"({bytes_convert(moe_active_params_per_token * bytes_per_elem)})")
f"({convert_bytes(moe_active_params_per_token * bytes_per_elem)})")
print(f"{'moe_hidden_dim':23}: {res['moe_hidden_dim']}")
print()

View File

@@ -6,14 +6,14 @@
import argparse
import matplotlib.pyplot as plt
from ffn_moe_memory_estimator import (
from memory_estimator_moe import (
estimate_params_and_hidden,
ffn_params,
router_params,
calc_ffn_params,
calc_router_params,
)
def moe_active_and_total(
def calc_moe_active_and_total(
emb_dim,
hidden_dim,
ffn_type,
@@ -22,8 +22,8 @@ def moe_active_and_total(
match_dense=True,
):
if match_dense:
dense_params = ffn_params(emb_dim, hidden_dim, ffn_type)
router = router_params(emb_dim, num_experts)
dense_params = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
router = calc_router_params(emb_dim, num_experts)
if dense_params <= router:
match_dense = False
@@ -52,11 +52,11 @@ def plot_active_params_vs_experts(
experts = [1, 2, 4, 8, 16, 32, 64, 128, 192, 256, 384, 512]
experts = [e for e in experts if e <= max_experts]
dense_active = ffn_params(emb_dim, hidden_dim, ffn_type)
dense_active = calc_ffn_params(emb_dim, hidden_dim, ffn_type)
moe_active = []
moe_total = []
for e in experts:
active, total = moe_active_and_total(
active, total = calc_moe_active_and_total(
emb_dim=emb_dim,
hidden_dim=hidden_dim,
ffn_type=ffn_type,

View File

@@ -17,21 +17,21 @@ DTYPE_BYTES = {
}
def kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
def calc_kv_bytes_total_mha(batch, context_length, emb_dim, n_layers, bytes_per_elem, n_heads):
# Full attention (MHA)
d_head = emb_dim // n_heads
per_layer = batch * context_length * n_heads * d_head * 2 * bytes_per_elem
return per_layer * n_layers
def kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
def calc_kv_bytes_total_deltanet_no_conv(batch, emb_dim, n_layers, bytes_per_elem, n_heads):
# Simple Gated DeltaNet (no convolutional mixing)
d_head = emb_dim // n_heads
per_layer = batch * n_heads * d_head * d_head * bytes_per_elem
return per_layer * n_layers
def gb(x):
def convert_to_gb(x):
return x / 1e9
@@ -52,13 +52,13 @@ def main():
# 1) Full attention only
mha_bytes = np.array([
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads)
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads)
for t in ctx
], dtype=float)
# 2) DeltaNet only
dnet_bytes_const = kv_bytes_total_deltanet_no_conv(
dnet_bytes_const = calc_kv_bytes_total_deltanet_no_conv(
args.batch, args.emb_dim, args.n_layers,
bytes_per_elem, args.n_heads
)
@@ -68,17 +68,17 @@ def main():
n_mha_layers = args.n_layers / 4
n_dnet_layers = args.n_layers - n_mha_layers
mix_bytes = np.array([
kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
bytes_per_elem, args.n_heads)
+ kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
bytes_per_elem, args.n_heads)
calc_kv_bytes_total_mha(args.batch, int(t), args.emb_dim, n_mha_layers,
bytes_per_elem, args.n_heads)
+ calc_kv_bytes_total_deltanet_no_conv(args.batch, args.emb_dim, n_dnet_layers,
bytes_per_elem, args.n_heads)
for t in ctx
], dtype=float)
# Convert to GB
mha_gb = gb(mha_bytes)
dnet_gb = gb(dnet_bytes)
mix_gb = gb(mix_bytes)
mha_gb = convert_to_gb(mha_bytes)
dnet_gb = convert_to_gb(dnet_bytes)
mix_gb = convert_to_gb(mix_bytes)
# Plot
fig, ax = plt.subplots(figsize=(7, 4.5))

View File

@@ -101,6 +101,7 @@
"id": "0d824183-145c-4865-89e1-1f0d0a338f19"
},
"source": [
"&nbsp;\n",
"## 5.1 Evaluating generative text models"
]
},
@@ -121,6 +122,7 @@
"id": "bdc1cf3f-82d8-46c7-9ecc-58979ce87cdd"
},
"source": [
"&nbsp;\n",
"### 5.1.1 Using GPT to generate text"
]
},
@@ -253,14 +255,6 @@
"- The next chapters on finetuning LLMs will also introduce additional ways to measure model quality"
]
},
{
"cell_type": "markdown",
"id": "955f9e1a-7bf7-40d8-b1fa-eacabdee8d8e",
"metadata": {},
"source": [
"<br>"
]
},
{
"cell_type": "markdown",
"id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98",
@@ -268,6 +262,7 @@
"id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98"
},
"source": [
"&nbsp;\n",
"### 5.1.2 Calculating the text generation loss: cross-entropy and perplexity"
]
},
@@ -763,6 +758,7 @@
"id": "2ec6c217-e429-40c7-ad71-5d0a9da8e487"
},
"source": [
"&nbsp;\n",
"### 5.1.3 Calculating the training and validation set losses"
]
},
@@ -1220,6 +1216,7 @@
"id": "b9339f8d-00cb-4206-af67-58c32bd72055"
},
"source": [
"&nbsp;\n",
"## 5.2 Training an LLM"
]
},
@@ -1490,6 +1487,7 @@
"id": "699f45fc-bf78-42f2-bd24-2355db41b28f"
},
"source": [
"&nbsp;\n",
"## 5.3 Decoding strategies to control randomness"
]
},
@@ -1558,6 +1556,7 @@
"id": "4bb6f380-a798-4fd9-825c-17b7cd29a994",
"metadata": {},
"source": [
"&nbsp;\n",
"### 5.3.1 Temperature scaling"
]
},
@@ -1837,6 +1836,7 @@
"id": "c6e4873e-07e4-4abb-85df-bdaedcc1a6f7",
"metadata": {},
"source": [
"&nbsp;\n",
"### 5.3.2 Top-k sampling"
]
},
@@ -1957,6 +1957,7 @@
"id": "56056503-a15d-4315-a3ff-46647a4c7c45",
"metadata": {},
"source": [
"&nbsp;\n",
"### 5.3.3 Modifying the text generation function"
]
},
@@ -2054,6 +2055,7 @@
"id": "4e2002ca-f4c1-48af-9e0a-88bfc163ba0b",
"metadata": {},
"source": [
"&nbsp;\n",
"## 5.4 Loading and saving model weights in PyTorch"
]
},
@@ -2164,6 +2166,7 @@
"id": "4194350e-0409-4a63-8ffd-d3a896509032",
"metadata": {},
"source": [
"&nbsp;\n",
"## 5.5 Loading pretrained weights from OpenAI"
]
},
@@ -2615,6 +2618,7 @@
"id": "f2a66474-230d-4180-a8ff-843e04f1f1c4",
"metadata": {},
"source": [
"&nbsp;\n",
"## Summary and takeaways"
]
},

View File

@@ -62,7 +62,8 @@
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
"metadata": {},
"source": [
"# Exercise 5.1: Temperature-scaled softmax scores and sampling probabilities"
"&nbsp;\n",
"## Exercise 5.1: Temperature-scaled softmax scores and sampling probabilities"
]
},
{
@@ -239,7 +240,8 @@
"id": "b510ffb0-adca-4d64-8a12-38c4646fd736",
"metadata": {},
"source": [
"# Exercise 5.2: Different temperature and top-k settings"
"&nbsp;\n",
"## Exercise 5.2: Different temperature and top-k settings"
]
},
{
@@ -258,7 +260,8 @@
"id": "3f35425d-529d-4179-a1c4-63cb8b25b156",
"metadata": {},
"source": [
"# Exercise 5.3: Deterministic behavior in the decoding functions"
"&nbsp;\n",
"## Exercise 5.3: Deterministic behavior in the decoding functions"
]
},
{
@@ -425,7 +428,8 @@
"id": "6d0480e5-fb4e-41f8-a161-7ac980d71d47",
"metadata": {},
"source": [
"# Exercise 5.4: Continued pretraining"
"&nbsp;\n",
"## Exercise 5.4: Continued pretraining"
]
},
{
@@ -598,7 +602,8 @@
"id": "3384e788-f5a1-407c-8dd1-87959b75026d",
"metadata": {},
"source": [
"# Exercise 5.5: Training and validation set losses of the pretrained model"
"&nbsp;\n",
"## Exercise 5.5: Training and validation set losses of the pretrained model"
]
},
{
@@ -874,7 +879,8 @@
"id": "3a76a1e0-9635-480a-9391-3bda7aea402d",
"metadata": {},
"source": [
"# Exercise 5.6: Trying larger models"
"&nbsp;\n",
"## Exercise 5.6: Trying larger models"
]
},
{
@@ -1028,7 +1034,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -977,7 +977,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -1001,8 +1001,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -1659,7 +1659,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -1050,7 +1050,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -1074,8 +1074,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -2120,7 +2120,7 @@
},
"source": [
"&nbsp;\n",
"# Llama 3.1 8B"
"# 6. Llama 3.1 8B"
]
},
{
@@ -2460,7 +2460,7 @@
},
"source": [
"&nbsp;\n",
"# Llama 3.2 1B"
"# 7. Llama 3.2 1B"
]
},
{

View File

@@ -492,7 +492,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -516,8 +516,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb")
return mod
@@ -51,16 +51,16 @@ def dummy_cfg_base():
@torch.inference_mode()
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Llama3Model(dummy_cfg_base)
model = import_notebook_defs.Llama3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_llama3_base_equivalence_with_transformers(nb_imports):
def test_llama3_base_equivalence_with_transformers(import_notebook_defs):
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
cfg = {
"vocab_size": 257,
@@ -80,7 +80,7 @@ def test_llama3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
}
ours = nb_imports.Llama3Model(cfg)
ours = import_notebook_defs.Llama3Model(cfg)
hf_cfg = LlamaConfig(
vocab_size=cfg["vocab_size"],
@@ -107,7 +107,7 @@ def test_llama3_base_equivalence_with_transformers(nb_imports):
theirs = LlamaForCausalLM(hf_cfg)
hf_state = theirs.state_dict()
nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
import_notebook_defs.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
ours_logits = ours(x)

View File

@@ -681,7 +681,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -705,8 +705,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -725,7 +725,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1223,7 +1223,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -623,7 +623,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -647,8 +647,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -879,7 +879,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1016,7 +1016,7 @@
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
"# 4. Generate text"
]
},
{

View File

@@ -734,7 +734,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -758,8 +758,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -977,7 +977,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1131,7 +1131,7 @@
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
"# 4. Generate text"
]
},
{
@@ -1253,7 +1253,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -676,7 +676,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -700,8 +700,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -731,7 +731,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1064,7 +1064,7 @@
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
"# 4. Generate text"
]
},
{

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3-plus-kvcache.ipynb")
return mod
@@ -58,9 +58,9 @@ def dummy_cfg_moe(dummy_cfg_base):
@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Qwen3Model(dummy_cfg_base)
model = import_notebook_defs.Qwen3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -68,7 +68,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers(nb_imports):
def test_qwen3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Qwen3Config, Qwen3ForCausalLM
# Tiny config so the test is fast
@@ -89,7 +89,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Qwen3Model(cfg)
model = import_notebook_defs.Qwen3Model(cfg)
hf_cfg = Qwen3Config(
vocab_size=cfg["vocab_size"],
@@ -114,7 +114,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_qwen(model, param_config, hf_state)
import_notebook_defs.load_weights_into_qwen(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3.ipynb")
return mod
@@ -58,9 +58,9 @@ def dummy_cfg_moe(dummy_cfg_base):
@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Qwen3Model(dummy_cfg_base)
model = import_notebook_defs.Qwen3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -68,7 +68,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers(nb_imports):
def test_qwen3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Qwen3Config, Qwen3ForCausalLM
# Tiny config so the test is fast
@@ -89,7 +89,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Qwen3Model(cfg)
model = import_notebook_defs.Qwen3Model(cfg)
hf_cfg = Qwen3Config(
vocab_size=cfg["vocab_size"],
@@ -114,7 +114,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_qwen(model, param_config, hf_state)
import_notebook_defs.load_weights_into_qwen(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -771,7 +771,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -795,8 +795,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -1120,7 +1120,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1307,10 +1307,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1358,7 +1358,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -695,7 +695,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -719,8 +719,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -1005,7 +1005,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1172,10 +1172,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1223,7 +1223,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3-plus-kvcache.ipynb")
return mod
@@ -50,16 +50,16 @@ def dummy_cfg_base():
@torch.inference_mode()
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Gemma3Model(dummy_cfg_base)
model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_gemma3_base_equivalence_with_transformers(nb_imports):
def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
# Tiny config so the test is fast
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Gemma3Model(cfg)
model = import_notebook_defs.Gemma3Model(cfg)
hf_cfg = Gemma3TextConfig(
vocab_size=cfg["vocab_size"],
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
return mod
@@ -50,16 +50,16 @@ def dummy_cfg_base():
@torch.inference_mode()
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Gemma3Model(dummy_cfg_base)
model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_gemma3_base_equivalence_with_transformers(nb_imports):
def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
# Tiny config so the test is fast
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Gemma3Model(cfg)
model = import_notebook_defs.Gemma3Model(cfg)
hf_cfg = Gemma3TextConfig(
vocab_size=cfg["vocab_size"],
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -904,7 +904,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1269,10 +1269,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1320,7 +1320,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -801,7 +801,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1160,10 +1160,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1211,7 +1211,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -116,11 +116,11 @@ def load_notebook_defs(nb_name="standalone-olmo3.ipynb"):
return import_definitions_from_notebook(nb_dir, nb_name)
def build_olmo3_pair(nb_imports, cfg, hf_checkpoint=None):
def build_olmo3_pair(import_notebook_defs, cfg, hf_checkpoint=None):
if Olmo3ForCausalLM is None:
raise ImportError("transformers is required for the Olmo-3 debugger.")
ours = nb_imports.Olmo3Model(cfg)
ours = import_notebook_defs.Olmo3Model(cfg)
hf_cfg = _hf_config_from_dict(cfg)
if hf_checkpoint:
@@ -133,7 +133,7 @@ def build_olmo3_pair(nb_imports, cfg, hf_checkpoint=None):
hf_model = Olmo3ForCausalLM(hf_cfg)
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_olmo(ours, param_config, hf_model.state_dict())
import_notebook_defs.load_weights_into_olmo(ours, param_config, hf_model.state_dict())
ours.eval()
hf_model.eval()
@@ -271,10 +271,10 @@ if __name__ == "__main__":
if not transformers_available:
raise SystemExit("transformers is not installed; install it to run the debugger.")
nb_imports = load_notebook_defs()
import_notebook_defs = load_notebook_defs()
cfg = yarn_debug_config()
ours_model, hf_model = build_olmo3_pair(nb_imports, cfg)
ours_model, hf_model = build_olmo3_pair(import_notebook_defs, cfg)
torch.manual_seed(0)
input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
diffs = layerwise_differences(ours_model, hf_model, input_ids)

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-olmo3-plus-kv-cache.ipynb")
return mod
@@ -55,9 +55,9 @@ def dummy_cfg_base():
}
@torch.inference_mode()
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Olmo3Model(dummy_cfg_base)
model = import_notebook_defs.Olmo3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -65,7 +65,7 @@ def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_olmo3_base_equivalence_with_transformers(nb_imports):
def test_olmo3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Olmo3Config, Olmo3ForCausalLM
# Tiny config so the test is fast
@@ -99,7 +99,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"rope_local_base": 10_000.0,
}
model = nb_imports.Olmo3Model(cfg)
model = import_notebook_defs.Olmo3Model(cfg)
hf_cfg = Olmo3Config(
vocab_size=cfg["vocab_size"],
@@ -129,7 +129,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"n_layers": cfg["n_layers"],
"hidden_dim": cfg["hidden_dim"],
}
nb_imports.load_weights_into_olmo(model, param_config, hf_state)
import_notebook_defs.load_weights_into_olmo(model, param_config, hf_state)
x = torch.randint(
0,

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-olmo3.ipynb")
return mod
@@ -55,9 +55,9 @@ def dummy_cfg_base():
}
@torch.inference_mode()
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Olmo3Model(dummy_cfg_base)
model = import_notebook_defs.Olmo3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -65,7 +65,7 @@ def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_olmo3_base_equivalence_with_transformers(nb_imports):
def test_olmo3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Olmo3Config, Olmo3ForCausalLM
# Tiny config so the test is fast
@@ -99,7 +99,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"rope_local_base": 10_000.0,
}
model = nb_imports.Olmo3Model(cfg)
model = import_notebook_defs.Olmo3Model(cfg)
hf_cfg = Olmo3Config(
vocab_size=cfg["vocab_size"],
@@ -129,7 +129,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"n_layers": cfg["n_layers"],
"hidden_dim": cfg["hidden_dim"],
}
nb_imports.load_weights_into_olmo(model, param_config, hf_state)
import_notebook_defs.load_weights_into_olmo(model, param_config, hf_state)
x = torch.randint(
0,

View File

@@ -86,7 +86,8 @@
"id": "3a84cf35-b37f-4c15-8972-dfafc9fadc1c"
},
"source": [
"## 6.1 Different categories of finetuning"
"&nbsp;\n",
"### 6.1 Different categories of finetuning"
]
},
{
@@ -142,7 +143,8 @@
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf"
},
"source": [
"## 6.2 Preparing the dataset"
"&nbsp;\n",
"### 6.2 Preparing the dataset"
]
},
{
@@ -699,7 +701,8 @@
"id": "a8d7a0c5-1d5f-458a-b685-3f49520b0094",
"metadata": {},
"source": [
"## 6.3 Creating data loaders"
"&nbsp;\n",
"### 6.3 Creating data loaders"
]
},
{
@@ -1019,7 +1022,8 @@
"id": "d1c4f61a-5f5d-4b3b-97cf-151b617d1d6c"
},
"source": [
"## 6.4 Initializing a model with pretrained weights"
"&nbsp;\n",
"### 6.4 Initializing a model with pretrained weights"
]
},
{
@@ -1219,7 +1223,8 @@
"id": "4c9ae440-32f9-412f-96cf-fd52cc3e2522"
},
"source": [
"## 6.5 Adding a classification head"
"&nbsp;\n",
"### 6.5 Adding a classification head"
]
},
{
@@ -1722,7 +1727,8 @@
"id": "32aa4aef-e1e9-491b-9adf-5aa973e59b8c",
"metadata": {},
"source": [
"## 6.6 Calculating the classification loss and accuracy"
"&nbsp;\n",
"### 6.6 Calculating the classification loss and accuracy"
]
},
{
@@ -2042,7 +2048,8 @@
"id": "456ae0fd-6261-42b4-ab6a-d24289953083"
},
"source": [
"## 6.7 Finetuning the model on supervised data"
"&nbsp;\n",
"### 6.7 Finetuning the model on supervised data"
]
},
{
@@ -2372,7 +2379,8 @@
"id": "a74d9ad7-3ec1-450e-8c9f-4fc46d3d5bb0",
"metadata": {},
"source": [
"## 6.8 Using the LLM as a spam classifier"
"&nbsp;\n",
"### 6.8 Using the LLM as a spam classifier"
]
},
{
@@ -2564,6 +2572,7 @@
"id": "5b70ac71-234f-4eeb-b33d-c62726d50cd4"
},
"source": [
"&nbsp;\n",
"## Summary and takeaways"
]
},

View File

@@ -130,20 +130,20 @@ def download_and_unzip(url, zip_path, extract_to, new_file_path):
print(f"File downloaded and saved as {new_file_path}")
def random_split(df, train_frac, validation_frac):
def random_split(df, train_frac, val_frac):
# Shuffle the entire DataFrame
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
# Calculate split indices
train_end = int(len(df) * train_frac)
validation_end = train_end + int(len(df) * validation_frac)
val_end = train_end + int(len(df) * val_frac)
# Split the DataFrame
train_df = df[:train_end]
validation_df = df[train_end:validation_end]
test_df = df[validation_end:]
val_df = df[train_end:val_end]
test_df = df[val_end:]
return train_df, validation_df, test_df
return train_df, val_df, test_df
def create_dataset_csvs(new_file_path):
@@ -157,9 +157,9 @@ def create_dataset_csvs(new_file_path):
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
# Sample and save csv files
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df, val_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
val_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)
@@ -611,7 +611,7 @@ if __name__ == "__main__":
base_path = Path(".")
file_names = ["train.csv", "validation.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
if not all_exist:
try:
download_and_unzip(url, zip_path, extract_to, new_file_path)

View File

@@ -144,6 +144,7 @@
"id": "fae87bc1-14ca-4f89-8e12-49f77b0ec00d",
"metadata": {},
"source": [
"&nbsp;\n",
"## Scikit-learn baseline"
]
},
@@ -269,7 +270,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -79,20 +79,20 @@ def download_and_unzip(url, zip_path, extract_to, new_file_path):
print(f"File downloaded and saved as {new_file_path}")
def random_split(df, train_frac, validation_frac):
def random_split(df, train_frac, val_frac):
# Shuffle the entire DataFrame
df = df.sample(frac=1, random_state=123).reset_index(drop=True)
# Calculate split indices
train_end = int(len(df) * train_frac)
validation_end = train_end + int(len(df) * validation_frac)
val_end = train_end + int(len(df) * val_frac)
# Split the DataFrame
train_df = df[:train_end]
validation_df = df[train_end:validation_end]
test_df = df[validation_end:]
val_df = df[train_end:val_end]
test_df = df[val_end:]
return train_df, validation_df, test_df
return train_df, val_df, test_df
def create_dataset_csvs(new_file_path):
@@ -106,9 +106,9 @@ def create_dataset_csvs(new_file_path):
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
# Sample and save csv files
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df, val_df, test_df = random_split(balanced_df, 0.7, 0.1)
train_df.to_csv("train.csv", index=None)
validation_df.to_csv("validation.csv", index=None)
val_df.to_csv("validation.csv", index=None)
test_df.to_csv("test.csv", index=None)

View File

@@ -89,6 +89,7 @@
"id": "8bbc68e9-75b3-41f1-ac2c-e071c3cd0813"
},
"source": [
"&nbsp;\n",
"## 7.1 Introduction to instruction finetuning"
]
},
@@ -133,6 +134,7 @@
"id": "5384f0cf-ef3c-4436-a5fa-59bd25649f86"
},
"source": [
"&nbsp;\n",
"## 7.2 Preparing a dataset for supervised instruction finetuning"
]
},
@@ -499,6 +501,7 @@
"id": "fcaaf606-f913-4445-8301-632ae10d387d"
},
"source": [
"&nbsp;\n",
"## 7.3 Organizing data into training batches"
]
},
@@ -1492,6 +1495,7 @@
"id": "d6aad445-8f19-4238-b9bf-db80767fb91a"
},
"source": [
"&nbsp;\n",
"## 7.5 Loading a pretrained LLM"
]
},
@@ -1724,6 +1728,7 @@
"id": "70d27b9d-a942-4cf5-b797-848c5f01e723"
},
"source": [
"&nbsp;\n",
"## 7.6 Finetuning the LLM on instruction data"
]
},
@@ -1995,6 +2000,7 @@
"id": "87b79a47-13f9-4d1f-87b1-3339bafaf2a3"
},
"source": [
"&nbsp;\n",
"## 7.7 Extracting and saving responses"
]
},
@@ -2251,6 +2257,7 @@
"id": "obgoGI89dgPm"
},
"source": [
"&nbsp;\n",
"## 7.8 Evaluating the finetuned LLM"
]
},
@@ -2847,6 +2854,7 @@
"id": "tIbNMluCDjVM"
},
"source": [
"&nbsp;\n",
"### 7.9.1 What's next\n",
"\n",
"- This marks the final chapter of this book\n",
@@ -2857,12 +2865,26 @@
"- An optional step that is sometimes followed after instruction finetuning, as described in this chapter, is preference finetuning\n",
"- Preference finetuning process can be particularly useful for customizing a model to better align with specific user preferences; see the [../04_preference-tuning-with-dpo](../04_preference-tuning-with-dpo) folder if you are interested in this\n",
"\n",
"- This GitHub repository also contains a large selection of additional bonus material you may enjoy; for more information, please see the [Bonus Material](https://github.com/rasbt/LLMs-from-scratch?tab=readme-ov-file#bonus-material) section on this repository's README page\n",
"\n",
"- This GitHub repository also contains a large selection of additional bonus material you may enjoy; for more information, please see the [Bonus Material](https://github.com/rasbt/LLMs-from-scratch?tab=readme-ov-file#bonus-material) section on this repository's README page"
]
},
{
"cell_type": "markdown",
"id": "0e2b7bc2-2e8d-483f-a8f5-e2aa093db189",
"metadata": {},
"source": [
"&nbsp;\n",
"### 7.9.2 Staying up to date in a fast-moving field\n",
"\n",
"- No code in this section\n",
"\n",
"- No code in this section"
]
},
{
"cell_type": "markdown",
"id": "e3d8327d-afb5-4d24-88af-e253889251cf",
"metadata": {},
"source": [
"&nbsp;\n",
"### 7.9.3 Final words\n",
"\n",
"- I hope you enjoyed this journey of implementing an LLM from the ground up and coding the pretraining and finetuning functions\n",

View File

@@ -88,7 +88,8 @@
"id": "8bcdcb34-ac75-4f4f-9505-3ce0666c42d5",
"metadata": {},
"source": [
"## Test OpenAI API"
"&nbsp;\n",
"## 1. Test OpenAI API"
]
},
{
@@ -177,7 +178,8 @@
"id": "162a4739-6f03-4092-a5c2-f57a0b6a4c4d",
"metadata": {},
"source": [
"## Create JSON Entries"
"&nbsp;\n",
"## 2. Create JSON Entries"
]
},
{
@@ -418,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -1303,7 +1303,7 @@
},
"source": [
"&nbsp;\n",
"## 2.4) Creating training, validation, and test set data loaders"
"## 2.4 Creating training, validation, and test set data loaders"
]
},
{

View File

@@ -306,7 +306,7 @@
"metadata": {},
"outputs": [],
"source": [
"def instr_prompt_no_input(ins, outp):\n",
"def build_instruction_reflection_prompt_no_input(ins, outp):\n",
"\n",
" sys_prompt = \"You are a helpful, precise but picky assistant for checking the quality of a given instruction.\"\n",
" prompt_template = \"[Instruction]\\n{ins}\\n\\n[The Start of Answer]\\n{outp}\\n\\n[The End of Answer]\\n\\n[System]\\n{criteria}\\n\\n\"\n",
@@ -356,7 +356,7 @@
"id": "9572a1aa-532a-4a76-9fa3-3b59d996ba13",
"metadata": {},
"source": [
"- We can refine the instruction as follows, using `instr_prompt_no_input` function defined above:"
"- We can refine the instruction as follows, using `build_instruction_reflection_prompt_no_input` function defined above:"
]
},
{
@@ -405,7 +405,7 @@
"source": [
"entry = json_data[2]\n",
"\n",
"system_prompt, prompt = instr_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
"system_prompt, prompt = build_instruction_reflection_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
"output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n",
"\n",
"print(output)"
@@ -430,7 +430,7 @@
"source": [
"import re\n",
"\n",
"def extract_ins(text, no_input=True):\n",
"def extract_instruction_segment(text, no_input=True):\n",
" if '[New Instruction]' in text:\n",
" pattern = r'(\\[New Instruction\\])(.*?)(\\[End\\]|\\[New Answer\\]|New Answer:)'\n",
" else:\n",
@@ -445,7 +445,7 @@
" return seg_ins\n",
"\n",
"\n",
"def extract_oup(text, no_input=True):\n",
"def extract_output_segment(text, no_input=True):\n",
" if '[New Answer]' in text:\n",
" pattern = r'(\\[New Answer\\])(.*?)(\\[End\\]|$)'\n",
" else:\n",
@@ -462,8 +462,8 @@
"def extract_instruction(text):\n",
" if text == '':\n",
" return []\n",
" seg_ins = extract_ins(text, no_input=True)\n",
" seg_oup = extract_oup(text, no_input=True)\n",
" seg_ins = extract_instruction_segment(text, no_input=True)\n",
" seg_oup = extract_output_segment(text, no_input=True)\n",
" return [seg_ins, seg_oup]"
]
},
@@ -561,7 +561,7 @@
"metadata": {},
"outputs": [],
"source": [
"def res_gen_prompt_no_input(ins, outp):\n",
"def build_response_reflection_prompt_no_input(ins, outp):\n",
"\n",
" sys_prompt = \"You are a helpful, precise but picky assistant for checking the quality of the answer to a given instruction.\"\n",
" prompt_template = \"[Instruction]\\n{ins}\\n\\n[The Start of Answer]\\n{outp}\\n\\n[The End of Answer]\\n\\n[System]\\n{criteria}\\n\\n\"\n",
@@ -574,7 +574,7 @@
" return sys_prompt, prompt\n",
"\n",
"\n",
"def res_gen_prompt_input(ins, inp, outp):\n",
"def build_response_reflection_prompt_with_input(ins, inp, outp):\n",
"\n",
" sys_prompt = \"You are a helpful and precise assistant for checking the quality of the answer to a given instruction and its input.\"\n",
" prompt_template = \"[Instruction]\\n{ins}\\n\\n[The Start of Input]\\n{inp}\\n\\n[The End of Input]\\n\\n[The Start of Answer]\\n{outp}\\n\\n[The End of Answer]\\n\\n[System]\\n{criteria}\\n\\n\"\n",
@@ -626,7 +626,7 @@
"source": [
"entry = json_data[2]\n",
"\n",
"system_prompt, prompt = res_gen_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
"system_prompt, prompt = build_response_reflection_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
"output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n",
"\n",
"print(output)"
@@ -750,7 +750,7 @@
" for entry in tqdm(json_data):\n",
" \n",
" if not entry[\"input\"]:\n",
" system_prompt, prompt = instr_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
" system_prompt, prompt = build_instruction_reflection_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
" output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n",
" new_instr, new_outp = extract_instruction(output)\n",
" new_entry = {\"instruction\": new_instr, \"input\": \"\", \"output\": new_outp}\n",
@@ -906,7 +906,7 @@
" for entry in tqdm(json_data):\n",
" \n",
" if not entry[\"input\"]:\n",
" system_prompt, prompt = res_gen_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
" system_prompt, prompt = build_response_reflection_prompt_no_input(ins=entry[\"instruction\"], outp=entry[\"output\"])\n",
" output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n",
" new_response = extract_response(output)\n",
"\n",
@@ -917,7 +917,7 @@
" new_json_data.append(new_entry)\n",
"\n",
" else:\n",
" system_prompt, prompt = res_gen_prompt_input(ins=entry[\"instruction\"], inp=entry[\"input\"], outp=entry[\"output\"])\n",
" system_prompt, prompt = build_response_reflection_prompt_with_input(ins=entry[\"instruction\"], inp=entry[\"input\"], outp=entry[\"output\"])\n",
" output = run_chatgpt(prompt=prompt, client=client, system_prompt=system_prompt)\n",
" new_response = extract_response(output)\n",
"\n",