Readability and code quality improvements (#959)

* Consistent dataset naming

* Consistent section headers
This commit is contained in:
Sebastian Raschka
2026-02-17 19:44:56 -05:00
committed by GitHub
parent 7b1f740f74
commit be5e2a3331
48 changed files with 419 additions and 297 deletions

View File

@@ -101,6 +101,7 @@
"id": "0d824183-145c-4865-89e1-1f0d0a338f19"
},
"source": [
" \n",
"## 5.1 Evaluating generative text models"
]
},
@@ -121,6 +122,7 @@
"id": "bdc1cf3f-82d8-46c7-9ecc-58979ce87cdd"
},
"source": [
" \n",
"### 5.1.1 Using GPT to generate text"
]
},
@@ -253,14 +255,6 @@
"- The next chapters on finetuning LLMs will also introduce additional ways to measure model quality"
]
},
{
"cell_type": "markdown",
"id": "955f9e1a-7bf7-40d8-b1fa-eacabdee8d8e",
"metadata": {},
"source": [
"<br>"
]
},
{
"cell_type": "markdown",
"id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98",
@@ -268,6 +262,7 @@
"id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98"
},
"source": [
"&nbsp;\n",
"### 5.1.2 Calculating the text generation loss: cross-entropy and perplexity"
]
},
@@ -763,6 +758,7 @@
"id": "2ec6c217-e429-40c7-ad71-5d0a9da8e487"
},
"source": [
"&nbsp;\n",
"### 5.1.3 Calculating the training and validation set losses"
]
},
@@ -1220,6 +1216,7 @@
"id": "b9339f8d-00cb-4206-af67-58c32bd72055"
},
"source": [
"&nbsp;\n",
"## 5.2 Training an LLM"
]
},
@@ -1490,6 +1487,7 @@
"id": "699f45fc-bf78-42f2-bd24-2355db41b28f"
},
"source": [
"&nbsp;\n",
"## 5.3 Decoding strategies to control randomness"
]
},
@@ -1558,6 +1556,7 @@
"id": "4bb6f380-a798-4fd9-825c-17b7cd29a994",
"metadata": {},
"source": [
"&nbsp;\n",
"### 5.3.1 Temperature scaling"
]
},
@@ -1837,6 +1836,7 @@
"id": "c6e4873e-07e4-4abb-85df-bdaedcc1a6f7",
"metadata": {},
"source": [
"&nbsp;\n",
"### 5.3.2 Top-k sampling"
]
},
@@ -1957,6 +1957,7 @@
"id": "56056503-a15d-4315-a3ff-46647a4c7c45",
"metadata": {},
"source": [
"&nbsp;\n",
"### 5.3.3 Modifying the text generation function"
]
},
@@ -2054,6 +2055,7 @@
"id": "4e2002ca-f4c1-48af-9e0a-88bfc163ba0b",
"metadata": {},
"source": [
"&nbsp;\n",
"## 5.4 Loading and saving model weights in PyTorch"
]
},
@@ -2164,6 +2166,7 @@
"id": "4194350e-0409-4a63-8ffd-d3a896509032",
"metadata": {},
"source": [
"&nbsp;\n",
"## 5.5 Loading pretrained weights from OpenAI"
]
},
@@ -2615,6 +2618,7 @@
"id": "f2a66474-230d-4180-a8ff-843e04f1f1c4",
"metadata": {},
"source": [
"&nbsp;\n",
"## Summary and takeaways"
]
},

View File

@@ -62,7 +62,8 @@
"id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
"metadata": {},
"source": [
"# Exercise 5.1: Temperature-scaled softmax scores and sampling probabilities"
"&nbsp;\n",
"## Exercise 5.1: Temperature-scaled softmax scores and sampling probabilities"
]
},
{
@@ -239,7 +240,8 @@
"id": "b510ffb0-adca-4d64-8a12-38c4646fd736",
"metadata": {},
"source": [
"# Exercise 5.2: Different temperature and top-k settings"
"&nbsp;\n",
"## Exercise 5.2: Different temperature and top-k settings"
]
},
{
@@ -258,7 +260,8 @@
"id": "3f35425d-529d-4179-a1c4-63cb8b25b156",
"metadata": {},
"source": [
"# Exercise 5.3: Deterministic behavior in the decoding functions"
"&nbsp;\n",
"## Exercise 5.3: Deterministic behavior in the decoding functions"
]
},
{
@@ -425,7 +428,8 @@
"id": "6d0480e5-fb4e-41f8-a161-7ac980d71d47",
"metadata": {},
"source": [
"# Exercise 5.4: Continued pretraining"
"&nbsp;\n",
"## Exercise 5.4: Continued pretraining"
]
},
{
@@ -598,7 +602,8 @@
"id": "3384e788-f5a1-407c-8dd1-87959b75026d",
"metadata": {},
"source": [
"# Exercise 5.5: Training and validation set losses of the pretrained model"
"&nbsp;\n",
"## Exercise 5.5: Training and validation set losses of the pretrained model"
]
},
{
@@ -874,7 +879,8 @@
"id": "3a76a1e0-9635-480a-9391-3bda7aea402d",
"metadata": {},
"source": [
"# Exercise 5.6: Trying larger models"
"&nbsp;\n",
"## Exercise 5.6: Trying larger models"
]
},
{
@@ -1028,7 +1034,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -977,7 +977,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -1001,8 +1001,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -1659,7 +1659,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -1050,7 +1050,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -1074,8 +1074,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -2120,7 +2120,7 @@
},
"source": [
"&nbsp;\n",
"# Llama 3.1 8B"
"# 6. Llama 3.1 8B"
]
},
{
@@ -2460,7 +2460,7 @@
},
"source": [
"&nbsp;\n",
"# Llama 3.2 1B"
"# 7. Llama 3.2 1B"
]
},
{

View File

@@ -492,7 +492,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -516,8 +516,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb")
return mod
@@ -51,16 +51,16 @@ def dummy_cfg_base():
@torch.inference_mode()
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Llama3Model(dummy_cfg_base)
model = import_notebook_defs.Llama3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_llama3_base_equivalence_with_transformers(nb_imports):
def test_llama3_base_equivalence_with_transformers(import_notebook_defs):
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
cfg = {
"vocab_size": 257,
@@ -80,7 +80,7 @@ def test_llama3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
}
ours = nb_imports.Llama3Model(cfg)
ours = import_notebook_defs.Llama3Model(cfg)
hf_cfg = LlamaConfig(
vocab_size=cfg["vocab_size"],
@@ -107,7 +107,7 @@ def test_llama3_base_equivalence_with_transformers(nb_imports):
theirs = LlamaForCausalLM(hf_cfg)
hf_state = theirs.state_dict()
nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
import_notebook_defs.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long)
ours_logits = ours(x)

View File

@@ -681,7 +681,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -705,8 +705,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -725,7 +725,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1223,7 +1223,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -623,7 +623,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -647,8 +647,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -879,7 +879,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1016,7 +1016,7 @@
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
"# 4. Generate text"
]
},
{

View File

@@ -734,7 +734,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -758,8 +758,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -977,7 +977,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1131,7 +1131,7 @@
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
"# 4. Generate text"
]
},
{
@@ -1253,7 +1253,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -676,7 +676,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -700,8 +700,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -731,7 +731,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1064,7 +1064,7 @@
},
"source": [
"&nbsp;\n",
"# 5. Generate text"
"# 4. Generate text"
]
},
{

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3-plus-kvcache.ipynb")
return mod
@@ -58,9 +58,9 @@ def dummy_cfg_moe(dummy_cfg_base):
@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Qwen3Model(dummy_cfg_base)
model = import_notebook_defs.Qwen3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -68,7 +68,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers(nb_imports):
def test_qwen3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Qwen3Config, Qwen3ForCausalLM
# Tiny config so the test is fast
@@ -89,7 +89,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Qwen3Model(cfg)
model = import_notebook_defs.Qwen3Model(cfg)
hf_cfg = Qwen3Config(
vocab_size=cfg["vocab_size"],
@@ -114,7 +114,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_qwen(model, param_config, hf_state)
import_notebook_defs.load_weights_into_qwen(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3.ipynb")
return mod
@@ -58,9 +58,9 @@ def dummy_cfg_moe(dummy_cfg_base):
@torch.inference_mode()
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Qwen3Model(dummy_cfg_base)
model = import_notebook_defs.Qwen3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -68,7 +68,7 @@ def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_qwen3_base_equivalence_with_transformers(nb_imports):
def test_qwen3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Qwen3Config, Qwen3ForCausalLM
# Tiny config so the test is fast
@@ -89,7 +89,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Qwen3Model(cfg)
model = import_notebook_defs.Qwen3Model(cfg)
hf_cfg = Qwen3Config(
vocab_size=cfg["vocab_size"],
@@ -114,7 +114,7 @@ def test_qwen3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_qwen(model, param_config, hf_state)
import_notebook_defs.load_weights_into_qwen(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -771,7 +771,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -795,8 +795,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -1120,7 +1120,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1307,10 +1307,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1358,7 +1358,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -695,7 +695,7 @@
}
],
"source": [
"def model_memory_size(model, input_dtype=torch.float32):\n",
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
" total_params = 0\n",
" total_grads = 0\n",
" for param in model.parameters():\n",
@@ -719,8 +719,8 @@
"\n",
" return total_memory_gb\n",
"\n",
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
]
},
{
@@ -1005,7 +1005,7 @@
"metadata": {},
"source": [
"&nbsp;\n",
"# 4. Load tokenizer"
"# 3. Load tokenizer"
]
},
{
@@ -1172,10 +1172,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1223,7 +1223,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3-plus-kvcache.ipynb")
return mod
@@ -50,16 +50,16 @@ def dummy_cfg_base():
@torch.inference_mode()
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Gemma3Model(dummy_cfg_base)
model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_gemma3_base_equivalence_with_transformers(nb_imports):
def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
# Tiny config so the test is fast
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Gemma3Model(cfg)
model = import_notebook_defs.Gemma3Model(cfg)
hf_cfg = Gemma3TextConfig(
vocab_size=cfg["vocab_size"],
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
return mod
@@ -50,16 +50,16 @@ def dummy_cfg_base():
@torch.inference_mode()
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Gemma3Model(dummy_cfg_base)
model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_gemma3_base_equivalence_with_transformers(nb_imports):
def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
# Tiny config so the test is fast
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
}
model = nb_imports.Gemma3Model(cfg)
model = import_notebook_defs.Gemma3Model(cfg)
hf_cfg = Gemma3TextConfig(
vocab_size=cfg["vocab_size"],
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
hf_state = hf_model.state_dict()
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
ours_logits = model(x)

View File

@@ -904,7 +904,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1269,10 +1269,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1320,7 +1320,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -801,7 +801,7 @@
},
"source": [
"&nbsp;\n",
"# 4. Load pretrained weights"
"# 3. Load pretrained weights"
]
},
{
@@ -1160,10 +1160,10 @@
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def gpu_gb(x):\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
@@ -1211,7 +1211,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.13.5"
}
},
"nbformat": 4,

View File

@@ -116,11 +116,11 @@ def load_notebook_defs(nb_name="standalone-olmo3.ipynb"):
return import_definitions_from_notebook(nb_dir, nb_name)
def build_olmo3_pair(nb_imports, cfg, hf_checkpoint=None):
def build_olmo3_pair(import_notebook_defs, cfg, hf_checkpoint=None):
if Olmo3ForCausalLM is None:
raise ImportError("transformers is required for the Olmo-3 debugger.")
ours = nb_imports.Olmo3Model(cfg)
ours = import_notebook_defs.Olmo3Model(cfg)
hf_cfg = _hf_config_from_dict(cfg)
if hf_checkpoint:
@@ -133,7 +133,7 @@ def build_olmo3_pair(nb_imports, cfg, hf_checkpoint=None):
hf_model = Olmo3ForCausalLM(hf_cfg)
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
nb_imports.load_weights_into_olmo(ours, param_config, hf_model.state_dict())
import_notebook_defs.load_weights_into_olmo(ours, param_config, hf_model.state_dict())
ours.eval()
hf_model.eval()
@@ -271,10 +271,10 @@ if __name__ == "__main__":
if not transformers_available:
raise SystemExit("transformers is not installed; install it to run the debugger.")
nb_imports = load_notebook_defs()
import_notebook_defs = load_notebook_defs()
cfg = yarn_debug_config()
ours_model, hf_model = build_olmo3_pair(nb_imports, cfg)
ours_model, hf_model = build_olmo3_pair(import_notebook_defs, cfg)
torch.manual_seed(0)
input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long)
diffs = layerwise_differences(ours_model, hf_model, input_ids)

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-olmo3-plus-kv-cache.ipynb")
return mod
@@ -55,9 +55,9 @@ def dummy_cfg_base():
}
@torch.inference_mode()
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Olmo3Model(dummy_cfg_base)
model = import_notebook_defs.Olmo3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -65,7 +65,7 @@ def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_olmo3_base_equivalence_with_transformers(nb_imports):
def test_olmo3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Olmo3Config, Olmo3ForCausalLM
# Tiny config so the test is fast
@@ -99,7 +99,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"rope_local_base": 10_000.0,
}
model = nb_imports.Olmo3Model(cfg)
model = import_notebook_defs.Olmo3Model(cfg)
hf_cfg = Olmo3Config(
vocab_size=cfg["vocab_size"],
@@ -129,7 +129,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"n_layers": cfg["n_layers"],
"hidden_dim": cfg["hidden_dim"],
}
nb_imports.load_weights_into_olmo(model, param_config, hf_state)
import_notebook_defs.load_weights_into_olmo(model, param_config, hf_state)
x = torch.randint(
0,

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
@pytest.fixture
def nb_imports():
def import_notebook_defs():
nb_dir = Path(__file__).resolve().parents[1]
mod = import_definitions_from_notebook(nb_dir, "standalone-olmo3.ipynb")
return mod
@@ -55,9 +55,9 @@ def dummy_cfg_base():
}
@torch.inference_mode()
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
torch.manual_seed(123)
model = nb_imports.Olmo3Model(dummy_cfg_base)
model = import_notebook_defs.Olmo3Model(dummy_cfg_base)
out = model(dummy_input)
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \
f"Expected shape (1, seq_len, vocab_size), got {out.shape}"
@@ -65,7 +65,7 @@ def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports):
@torch.inference_mode()
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_olmo3_base_equivalence_with_transformers(nb_imports):
def test_olmo3_base_equivalence_with_transformers(import_notebook_defs):
from transformers import Olmo3Config, Olmo3ForCausalLM
# Tiny config so the test is fast
@@ -99,7 +99,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"rope_local_base": 10_000.0,
}
model = nb_imports.Olmo3Model(cfg)
model = import_notebook_defs.Olmo3Model(cfg)
hf_cfg = Olmo3Config(
vocab_size=cfg["vocab_size"],
@@ -129,7 +129,7 @@ def test_olmo3_base_equivalence_with_transformers(nb_imports):
"n_layers": cfg["n_layers"],
"hidden_dim": cfg["hidden_dim"],
}
nb_imports.load_weights_into_olmo(model, param_config, hf_state)
import_notebook_defs.load_weights_into_olmo(model, param_config, hf_state)
x = torch.randint(
0,