Readability and code quality improvements (#959)

* Consistent dataset naming

* Consistent section headers

Author: Sebastian Raschka
Date: 2026-02-17 19:44:56 -05:00
Committed by: GitHub
Parent: 7b1f740f74
Commit: be5e2a3331
48 changed files with 419 additions and 297 deletions
View File

@@ -771,7 +771,7 @@
 }
 ],
 "source": [
-"def model_memory_size(model, input_dtype=torch.float32):\n",
+"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
 "    total_params = 0\n",
 "    total_grads = 0\n",
 "    for param in model.parameters():\n",
@@ -795,8 +795,8 @@
 "\n",
 "    return total_memory_gb\n",
 "\n",
-"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
-"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
+"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
+"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
 ]
 },
 {
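Editor's note: the two hunks above only show the head and tail of the renamed helper. As context for the rename, here is a minimal, self-contained sketch of what a calc_model_memory_size function along these lines computes. The body between the shown lines is reconstructed, and the toy model at the end is invented for illustration, so treat this as an approximation rather than the notebook's exact code.

import torch
import torch.nn as nn

def calc_model_memory_size(model, input_dtype=torch.float32):
    # Count parameter elements, and gradient elements for trainable parameters
    total_params = 0
    total_grads = 0
    for param in model.parameters():
        param_size = param.numel()
        total_params += param_size
        if param.requires_grad:
            total_grads += param_size

    # Buffers (e.g., precomputed caches) occupy memory but carry no gradients
    total_buffers = sum(buf.numel() for buf in model.buffers())

    # Bytes per element for the assumed storage dtype
    element_size = torch.tensor(0, dtype=input_dtype).element_size()
    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size

    # Convert bytes to gigabytes
    return total_memory_bytes / (1024**3)

# Example usage with a toy model (not the notebook's Gemma 3 model)
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))
print(f"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB")
print(f"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB")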
@@ -1120,7 +1120,7 @@
 "metadata": {},
 "source": [
 " \n",
-"# 4. Load tokenizer"
+"# 3. Load tokenizer"
 ]
 },
 {
@@ -1307,10 +1307,10 @@
 " )\n",
 "\n",
 "if torch.cuda.is_available():\n",
-"    def gpu_gb(x):\n",
+"    def calc_gpu_gb(x):\n",
 "        return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
 " \n",
-"    print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
+"    print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
 ]
 },
 {
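Editor's note: the second rename touches a one-line formatter around PyTorch's CUDA memory statistics. A short sketch of the usage pattern follows; the reset_peak_memory_stats call is an assumption about how the measurement is framed, not something shown in the diff.

import torch

def calc_gpu_gb(x):
    # Convert a raw byte count into a human-readable gigabyte string
    return f"{x / 1024 / 1024 / 1024:.2f} GB"

if torch.cuda.is_available():
    torch.cuda.reset_peak_memory_stats()  # assumed: start from a clean slate
    # ... run the model forward pass or generation here ...
    print(f"GPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}")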
@@ -1358,7 +1358,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.12.3"
+"version": "3.13.5"
 }
 },
 "nbformat": 4,

View File

@@ -695,7 +695,7 @@
 }
 ],
 "source": [
-"def model_memory_size(model, input_dtype=torch.float32):\n",
+"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
 "    total_params = 0\n",
 "    total_grads = 0\n",
 "    for param in model.parameters():\n",
@@ -719,8 +719,8 @@
 "\n",
 "    return total_memory_gb\n",
 "\n",
-"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
-"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
+"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
+"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
 ]
 },
 {
@@ -1005,7 +1005,7 @@
 "metadata": {},
 "source": [
 " \n",
-"# 4. Load tokenizer"
+"# 3. Load tokenizer"
 ]
 },
 {
@@ -1172,10 +1172,10 @@
 " )\n",
 "\n",
 "if torch.cuda.is_available():\n",
-"    def gpu_gb(x):\n",
+"    def calc_gpu_gb(x):\n",
 "        return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
 " \n",
-"    print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
+"    print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
 ]
 },
 {
@@ -1223,7 +1223,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.12.3"
+"version": "3.13.5"
 }
 },
 "nbformat": 4,

View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
 @pytest.fixture
-def nb_imports():
+def import_notebook_defs():
     nb_dir = Path(__file__).resolve().parents[1]
     mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3-plus-kvcache.ipynb")
     return mod
@@ -50,16 +50,16 @@ def dummy_cfg_base():
 @torch.inference_mode()
-def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
+def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
     torch.manual_seed(123)
-    model = nb_imports.Gemma3Model(dummy_cfg_base)
+    model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
     out = model(dummy_input)
     assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])


 @torch.inference_mode()
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_gemma3_base_equivalence_with_transformers(nb_imports):
+def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
     from transformers import Gemma3TextConfig, Gemma3ForCausalLM

     # Tiny config so the test is fast
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
         "dtype": torch.float32,
         "query_pre_attn_scalar": 256,
     }
-    model = nb_imports.Gemma3Model(cfg)
+    model = import_notebook_defs.Gemma3Model(cfg)

     hf_cfg = Gemma3TextConfig(
         vocab_size=cfg["vocab_size"],
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
     hf_state = hf_model.state_dict()

     param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
-    nb_imports.load_weights_into_gemma(model, param_config, hf_state)
+    import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)

     x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
     ours_logits = model(x)

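Editor's note on why this rename fans out across every test: pytest injects fixtures by matching parameter names, so renaming the nb_imports fixture to import_notebook_defs forces the same change in each test signature. A stripped-down sketch of the mechanism follows; the NotebookDefs class is a hypothetical stand-in for the module that the real fixture builds via import_definitions_from_notebook.

import pytest

@pytest.fixture
def import_notebook_defs():
    # The real fixture executes a notebook and returns its definitions as a
    # module; this stand-in is hypothetical and only illustrates injection.
    class NotebookDefs:
        value = 42
    return NotebookDefs

def test_uses_fixture(import_notebook_defs):
    # pytest resolves this argument by name against the fixture above,
    # which is why the fixture rename must touch every test signature
    assert import_notebook_defs.value == 42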
View File

@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
 @pytest.fixture
-def nb_imports():
+def import_notebook_defs():
     nb_dir = Path(__file__).resolve().parents[1]
     mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
     return mod
@@ -50,16 +50,16 @@ def dummy_cfg_base():
 @torch.inference_mode()
-def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
+def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
     torch.manual_seed(123)
-    model = nb_imports.Gemma3Model(dummy_cfg_base)
+    model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
     out = model(dummy_input)
     assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])


 @torch.inference_mode()
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_gemma3_base_equivalence_with_transformers(nb_imports):
+def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
     from transformers import Gemma3TextConfig, Gemma3ForCausalLM

     # Tiny config so the test is fast
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
         "dtype": torch.float32,
         "query_pre_attn_scalar": 256,
     }
-    model = nb_imports.Gemma3Model(cfg)
+    model = import_notebook_defs.Gemma3Model(cfg)

     hf_cfg = Gemma3TextConfig(
         vocab_size=cfg["vocab_size"],
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
     hf_state = hf_model.state_dict()

     param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
-    nb_imports.load_weights_into_gemma(model, param_config, hf_state)
+    import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)

     x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
     ours_logits = model(x)
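Editor's note: both test diffs stop at ours_logits = model(x); the comparison against the Hugging Face logits lies outside the shown hunks. Purely as an illustration of how such an equivalence test typically concludes (this is not the repository's code), the final step usually looks like a torch.testing.assert_close call:

import torch

# Stand-ins for the two implementations' outputs; in the real test these
# would come from model(x) and hf_model(x).logits (hypothetical here)
ours_logits = torch.randn(2, 8, 100)
theirs_logits = ours_logits + 1e-6 * torch.randn_like(ours_logits)

# Passes if the tensors match within floating-point tolerance
torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-4, atol=1e-4)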