From 80d4732456878238fcfe8e2495d696ea7d5b10f6 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Mon, 18 Aug 2025 18:58:46 -0500 Subject: [PATCH] add HF equivalency tests for standalone nbs (#774) * add HF equivalency tests for standalone nbs * update * update * update * update --- .github/workflows/basic-tests-linux-uv.yml | 6 +- .github/workflows/basic-tests-macos-uv.yml | 6 +- .github/workflows/basic-tests-old-pytorch.yml | 1 - .github/workflows/basic-tests-pip.yml | 2 - .github/workflows/basic-tests-pixi.yml | 1 - .github/workflows/basic-tests-pytorch-rc.yml | 2 - .../workflows/basic-tests-windows-uv-pip.yml | 5 +- .../basic-tests-windows-uv.yml.disabled | 1 - ch05/07_gpt_to_llama/tests/test_llama32_nb.py | 116 ++++++++++++++++ .../{tests.py => tests_rope_and_parts.py} | 0 ch05/11_qwen3/tests/test_qwen3_nb.py | 122 +++++++++++++++++ .../{test_gemma3.py => test_gemma3_nb.py} | 91 ++----------- pkg/llms_from_scratch/qwen3.py | 2 +- pkg/llms_from_scratch/utils.py | 124 ++++++++++++++++++ pyproject.toml | 1 + 15 files changed, 389 insertions(+), 91 deletions(-) create mode 100644 ch05/07_gpt_to_llama/tests/test_llama32_nb.py rename ch05/07_gpt_to_llama/tests/{tests.py => tests_rope_and_parts.py} (100%) create mode 100644 ch05/11_qwen3/tests/test_qwen3_nb.py rename ch05/12_gemma3/tests/{test_gemma3.py => test_gemma3_nb.py} (54%) create mode 100644 pkg/llms_from_scratch/utils.py diff --git a/.github/workflows/basic-tests-linux-uv.yml b/.github/workflows/basic-tests-linux-uv.yml index b11db40..f02f937 100644 --- a/.github/workflows/basic-tests-linux-uv.yml +++ b/.github/workflows/basic-tests-linux-uv.yml @@ -51,8 +51,10 @@ jobs: pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch04/03_kv-cache/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py - pytest --ruff ch05/12_gemma3/tests/test_gemma3.py + pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py + pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py + pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py + pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks (uv) diff --git a/.github/workflows/basic-tests-macos-uv.yml b/.github/workflows/basic-tests-macos-uv.yml index 454cc35..a3f052e 100644 --- a/.github/workflows/basic-tests-macos-uv.yml +++ b/.github/workflows/basic-tests-macos-uv.yml @@ -50,8 +50,10 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py - pytest --ruff ch05/12_gemma3/tests/test_gemma3.py + pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py + pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py + pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py + pytest --ruff ch05/12_gemma3/tests/test_gemma3_nb.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks (uv) diff --git a/.github/workflows/basic-tests-old-pytorch.yml b/.github/workflows/basic-tests-old-pytorch.yml index 0bb0ac0..08fd634 100644 --- a/.github/workflows/basic-tests-old-pytorch.yml +++ b/.github/workflows/basic-tests-old-pytorch.yml @@ -47,7 +47,6 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-pip.yml b/.github/workflows/basic-tests-pip.yml index 01784d4..b1d74e3 100644 --- a/.github/workflows/basic-tests-pip.yml +++ b/.github/workflows/basic-tests-pip.yml @@ -41,7 +41,6 @@ jobs: source .venv/bin/activate pip install --upgrade pip pip install -r requirements.txt - pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt pip install pytest pytest-ruff nbval - name: Test Selected Python Scripts @@ -50,7 +49,6 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-pixi.yml b/.github/workflows/basic-tests-pixi.yml index a296d50..85eba28 100644 --- a/.github/workflows/basic-tests-pixi.yml +++ b/.github/workflows/basic-tests-pixi.yml @@ -50,7 +50,6 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-pytorch-rc.yml b/.github/workflows/basic-tests-pytorch-rc.yml index 4e8da01..536125c 100644 --- a/.github/workflows/basic-tests-pytorch-rc.yml +++ b/.github/workflows/basic-tests-pytorch-rc.yml @@ -33,7 +33,6 @@ jobs: run: | curl -LsSf https://astral.sh/uv/install.sh | sh uv sync --dev --python=3.10 # tests for backwards compatibility - uv pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt uv add pytest-ruff nbval uv pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu @@ -43,7 +42,6 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks diff --git a/.github/workflows/basic-tests-windows-uv-pip.yml b/.github/workflows/basic-tests-windows-uv-pip.yml index 5f0fe92..8836332 100644 --- a/.github/workflows/basic-tests-windows-uv-pip.yml +++ b/.github/workflows/basic-tests-windows-uv-pip.yml @@ -43,6 +43,7 @@ jobs: pip install tensorflow-io-gcs-filesystem==0.31.0 # Explicit for Windows pip install -r ch05/07_gpt_to_llama/tests/test-requirements-extra.txt pip install pytest-ruff nbval + pip install -e . - name: Run Python Tests shell: bash @@ -51,7 +52,9 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py + pytest --ruff ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py + pytest --ruff ch05/07_gpt_to_llama/tests/test_llama32_nb.py + pytest --ruff ch05/11_qwen3/tests/test_qwen3_nb.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Run Jupyter Notebook Tests diff --git a/.github/workflows/basic-tests-windows-uv.yml.disabled b/.github/workflows/basic-tests-windows-uv.yml.disabled index 306690a..d06ac8f 100644 --- a/.github/workflows/basic-tests-windows-uv.yml.disabled +++ b/.github/workflows/basic-tests-windows-uv.yml.disabled @@ -51,7 +51,6 @@ jobs: pytest --ruff setup/02_installing-python-libraries/tests.py pytest --ruff ch04/01_main-chapter-code/tests.py pytest --ruff ch05/01_main-chapter-code/tests.py - pytest --ruff ch05/07_gpt_to_llama/tests/tests.py pytest --ruff ch06/01_main-chapter-code/tests.py - name: Run Jupyter Notebook Tests diff --git a/ch05/07_gpt_to_llama/tests/test_llama32_nb.py b/ch05/07_gpt_to_llama/tests/test_llama32_nb.py new file mode 100644 index 0000000..234b84c --- /dev/null +++ b/ch05/07_gpt_to_llama/tests/test_llama32_nb.py @@ -0,0 +1,116 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +from pathlib import Path + +import pytest +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +transformers_installed = importlib.util.find_spec("transformers") is not None + + +@pytest.fixture +def nb_imports(): + nb_dir = Path(__file__).resolve().parents[1] + mod = import_definitions_from_notebook(nb_dir, "standalone-llama32.ipynb") + return mod + + +@pytest.fixture +def dummy_input(): + torch.manual_seed(123) + return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8 + + +@pytest.fixture +def dummy_cfg_base(): + return { + "vocab_size": 100, + "emb_dim": 32, # hidden_size + "hidden_dim": 64, # intermediate_size (FFN) + "n_layers": 2, + "n_heads": 4, + "head_dim": 8, + "n_kv_groups": 1, + "dtype": torch.float32, + "rope_base": 500_000.0, + "rope_freq": { + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_context_length": 8192, + }, + "context_length": 64, + } + + +@torch.inference_mode() +def test_dummy_llama3_forward(dummy_cfg_base, dummy_input, nb_imports): + torch.manual_seed(123) + model = nb_imports.Llama3Model(dummy_cfg_base) + out = model(dummy_input) + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]) + + +@torch.inference_mode() +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_llama3_base_equivalence_with_transformers(nb_imports): + from transformers.models.llama import LlamaConfig, LlamaForCausalLM + cfg = { + "vocab_size": 257, + "context_length": 8192, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "n_kv_groups": 2, + "rope_base": 500_000.0, + "rope_freq": { + "factor": 32.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_context_length": 8192, + }, + "dtype": torch.float32, + } + + ours = nb_imports.Llama3Model(cfg) + + hf_cfg = LlamaConfig( + vocab_size=cfg["vocab_size"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_key_value_heads=cfg["n_kv_groups"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + max_position_embeddings=cfg["context_length"], + rms_norm_eps=1e-5, + attention_bias=False, + rope_theta=cfg["rope_base"], + tie_word_embeddings=False, + attn_implementation="eager", + torch_dtype=torch.float32, + rope_scaling={ + "type": "llama3", + "factor": cfg["rope_freq"]["factor"], + "low_freq_factor": cfg["rope_freq"]["low_freq_factor"], + "high_freq_factor": cfg["rope_freq"]["high_freq_factor"], + "original_max_position_embeddings": cfg["rope_freq"]["original_context_length"], + }, + ) + theirs = LlamaForCausalLM(hf_cfg) + + hf_state = theirs.state_dict() + nb_imports.load_weights_into_llama(ours, {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}, hf_state) + + x = torch.randint(0, cfg["vocab_size"], (2, 8), dtype=torch.long) + ours_logits = ours(x) + theirs_logits = theirs(x).logits.to(ours_logits.dtype) + + torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5) diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py similarity index 100% rename from ch05/07_gpt_to_llama/tests/tests.py rename to ch05/07_gpt_to_llama/tests/tests_rope_and_parts.py diff --git a/ch05/11_qwen3/tests/test_qwen3_nb.py b/ch05/11_qwen3/tests/test_qwen3_nb.py new file mode 100644 index 0000000..2b7ecce --- /dev/null +++ b/ch05/11_qwen3/tests/test_qwen3_nb.py @@ -0,0 +1,122 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +from pathlib import Path + +import pytest +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +transformers_installed = importlib.util.find_spec("transformers") is not None + + +@pytest.fixture +def nb_imports(): + nb_dir = Path(__file__).resolve().parents[1] + mod = import_definitions_from_notebook(nb_dir, "standalone-qwen3.ipynb") + return mod + + +@pytest.fixture +def dummy_input(): + torch.manual_seed(123) + return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8 + + +@pytest.fixture +def dummy_cfg_base(): + return { + "vocab_size": 100, + "emb_dim": 32, + "hidden_dim": 64, + "n_layers": 2, + "n_heads": 4, + "head_dim": 8, + "n_kv_groups": 1, + "qk_norm": False, + "dtype": torch.float32, + "rope_base": 10000, + "context_length": 64, + "num_experts": 0, + } + + +@pytest.fixture +def dummy_cfg_moe(dummy_cfg_base): + cfg = dummy_cfg_base.copy() + cfg.update({ + "num_experts": 4, + "num_experts_per_tok": 2, + "moe_intermediate_size": 64, + }) + return cfg + + +@torch.inference_mode() +def test_dummy_qwen3_forward(dummy_cfg_base, dummy_input, nb_imports): + torch.manual_seed(123) + model = nb_imports.Qwen3Model(dummy_cfg_base) + out = model(dummy_input) + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \ + f"Expected shape (1, seq_len, vocab_size), got {out.shape}" + + +@torch.inference_mode() +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_qwen3_base_equivalence_with_transformers(nb_imports): + from transformers import Qwen3Config, Qwen3ForCausalLM + + # Tiny config so the test is fast + cfg = { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_groups": 2, + "rope_base": 1_000_000.0, + "rope_local_base": 10_000.0, + "sliding_window": 4, + "layer_types": ["full_attention", "full_attention"], + "dtype": torch.float32, + "query_pre_attn_scalar": 256, + } + model = nb_imports.Qwen3Model(cfg) + + hf_cfg = Qwen3Config( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_groups"], + rope_theta=cfg["rope_base"], + rope_local_base_freq=cfg["rope_local_base"], + layer_types=cfg["layer_types"], + sliding_window=cfg["sliding_window"], + tie_word_embeddings=False, + attn_implementation="eager", + torch_dtype=torch.float32, + query_pre_attn_scalar=cfg["query_pre_attn_scalar"], + rope_scaling={"rope_type": "default"}, + ) + hf_model = Qwen3ForCausalLM(hf_cfg) + + hf_state = hf_model.state_dict() + param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]} + nb_imports.load_weights_into_qwen(model, param_config, hf_state) + + x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long) + ours_logits = model(x) + theirs_logits = hf_model(x).logits + torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5) diff --git a/ch05/12_gemma3/tests/test_gemma3.py b/ch05/12_gemma3/tests/test_gemma3_nb.py similarity index 54% rename from ch05/12_gemma3/tests/test_gemma3.py rename to ch05/12_gemma3/tests/test_gemma3_nb.py index 129cbc9..dd928b9 100644 --- a/ch05/12_gemma3/tests/test_gemma3.py +++ b/ch05/12_gemma3/tests/test_gemma3_nb.py @@ -4,77 +4,21 @@ # Code: https://github.com/rasbt/LLMs-from-scratch import importlib -import types -import re from pathlib import Path -import nbformat import pytest import torch +from llms_from_scratch.utils import import_definitions_from_notebook + + transformers_installed = importlib.util.find_spec("transformers") is not None -def _extract_defs_and_classes_from_code(src): - lines = src.splitlines() - kept = [] - i = 0 - while i < len(lines): - line = lines[i] - stripped = line.lstrip() - # Keep decorators attached to the next def/class - if stripped.startswith("@"): - # Look ahead: if the next non-empty line starts with def/class, keep decorator - j = i + 1 - while j < len(lines) and not lines[j].strip(): - j += 1 - if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")): - kept.append(line) - i += 1 - continue - if stripped.startswith("def ") or stripped.startswith("class "): - kept.append(line) - # capture until we leave the indentation block - base_indent = len(line) - len(stripped) - i += 1 - while i < len(lines): - nxt = lines[i] - if nxt.strip() == "": - kept.append(nxt) - i += 1 - continue - indent = len(nxt) - len(nxt.lstrip()) - if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")): - break - kept.append(nxt) - i += 1 - continue - i += 1 - code = "\n".join(kept) - code = re.sub(r"def\s+load_weights_into_gemma\s*\(\s*Gemma3Model\s*,", - "def load_weights_into_gemma(model,", - code) - return code - - -def import_definitions_from_notebook(nb_dir_or_path, notebook_name): - nb_path = Path(nb_dir_or_path) - if nb_path.is_dir(): - nb_file = nb_path / notebook_name - else: - nb_file = nb_path - if not nb_file.exists(): - raise FileNotFoundError(f"Notebook not found: {nb_file}") - - nb = nbformat.read(nb_file, as_version=4) - pieces = ["import torch", "import torch.nn as nn"] - for cell in nb.cells: - if cell.cell_type == "code": - pieces.append(_extract_defs_and_classes_from_code(cell.source)) - src = "\n\n".join(pieces) - - mod = types.ModuleType("gemma3_defs") - exec(src, mod.__dict__) +@pytest.fixture +def nb_imports(): + nb_dir = Path(__file__).resolve().parents[1] + mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb") return mod @@ -106,25 +50,16 @@ def dummy_cfg_base(): @torch.inference_mode() -def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input): - nb_dir = Path(__file__).resolve().parents[1] - mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb") - Gemma3Model = mod.Gemma3Model - +def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports): torch.manual_seed(123) - model = Gemma3Model(dummy_cfg_base) + model = nb_imports.Gemma3Model(dummy_cfg_base) out = model(dummy_input) - assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), f"Expected shape (1, seq_len, vocab_size), got {out.shape}" + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]) @torch.inference_mode() @pytest.mark.skipif(not transformers_installed, reason="transformers not installed") -def test_gemma3_base_equivalence_with_transformers(): - nb_dir = Path(__file__).resolve().parents[1] - mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb") - Gemma3Model = mod.Gemma3Model - load_weights_into_gemma = mod.load_weights_into_gemma - +def test_gemma3_base_equivalence_with_transformers(nb_imports): from transformers import Gemma3TextConfig, Gemma3ForCausalLM # Tiny config so the test is fast @@ -145,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(): "dtype": torch.float32, "query_pre_attn_scalar": 256, } - model = Gemma3Model(cfg) + model = nb_imports.Gemma3Model(cfg) hf_cfg = Gemma3TextConfig( vocab_size=cfg["vocab_size"], @@ -170,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(): hf_state = hf_model.state_dict() param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]} - load_weights_into_gemma(model, param_config, hf_state) + nb_imports.load_weights_into_gemma(model, param_config, hf_state) x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long) ours_logits = model(x) diff --git a/pkg/llms_from_scratch/qwen3.py b/pkg/llms_from_scratch/qwen3.py index dd43645..475e972 100644 --- a/pkg/llms_from_scratch/qwen3.py +++ b/pkg/llms_from_scratch/qwen3.py @@ -116,7 +116,7 @@ QWEN3_CONFIG_30B_A3B = { "dtype": torch.bfloat16, "num_experts": 128, "num_experts_per_tok": 8, - "moe_intermediate_size": 768, + "moe_intermediate_size": 768, } diff --git a/pkg/llms_from_scratch/utils.py b/pkg/llms_from_scratch/utils.py new file mode 100644 index 0000000..466ca4c --- /dev/null +++ b/pkg/llms_from_scratch/utils.py @@ -0,0 +1,124 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# Internal utility functions (not intended for public use) + +import ast +import re +import types +from pathlib import Path + +import nbformat + + +def _extract_imports(src: str): + out = [] + try: + tree = ast.parse(src) + except SyntaxError: + return out + for node in tree.body: + if isinstance(node, ast.Import): + parts = [] + for n in node.names: + parts.append(f"{n.name} as {n.asname}" if n.asname else n.name) + out.append("import " + ", ".join(parts)) + elif isinstance(node, ast.ImportFrom): + module = node.module or "" + parts = [] + for n in node.names: + parts.append(f"{n.name} as {n.asname}" if n.asname else n.name) + level = "." * node.level if getattr(node, "level", 0) else "" + out.append(f"from {level}{module} import " + ", ".join(parts)) + return out + + +def _extract_defs_and_classes_from_code(src): + lines = src.splitlines() + kept = [] + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.lstrip() + if stripped.startswith("@"): + j = i + 1 + while j < len(lines) and not lines[j].strip(): + j += 1 + if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")): + kept.append(line) + i += 1 + continue + if stripped.startswith("def ") or stripped.startswith("class "): + kept.append(line) + base_indent = len(line) - len(stripped) + i += 1 + while i < len(lines): + nxt = lines[i] + if nxt.strip() == "": + kept.append(nxt) + i += 1 + continue + indent = len(nxt) - len(nxt.lstrip()) + if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")): + break + kept.append(nxt) + i += 1 + continue + i += 1 + + code = "\n".join(kept) + + # General rule: + # replace functions defined like `def load_weights_into_xxx(ClassName, ...` + # with `def load_weights_into_xxx(model, ...` + code = re.sub( + r"(def\s+load_weights_into_\w+\s*\()\s*\w+\s*,", + r"\1model,", + code + ) + return code + + +def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extra_globals=None): + nb_path = Path(nb_dir_or_path) + if notebook_name is not None: + nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path + else: + nb_file = nb_path + + if not nb_file.exists(): + raise FileNotFoundError(f"Notebook not found: {nb_file}") + + nb = nbformat.read(nb_file, as_version=4) + + import_lines = [] + seen = set() + for cell in nb.cells: + if cell.cell_type == "code": + for line in _extract_imports(cell.source): + if line not in seen: + import_lines.append(line) + seen.add(line) + + for required in ("import torch", "import torch.nn as nn"): + if required not in seen: + import_lines.append(required) + seen.add(required) + + pieces = [] + for cell in nb.cells: + if cell.cell_type == "code": + pieces.append(_extract_defs_and_classes_from_code(cell.source)) + + src = "\n\n".join(import_lines + pieces) + + mod_name = nb_file.stem.replace("-", "_").replace(" ", "_") or "notebook_defs" + mod = types.ModuleType(mod_name) + + if extra_globals: + mod.__dict__.update(extra_globals) + + exec(src, mod.__dict__) + return mod diff --git a/pyproject.toml b/pyproject.toml index d1e95fc..cdafa4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = [ "llms-from-scratch", "twine>=6.1.0", "tokenizers>=0.21.1", + "safetensors>=0.6.2", ] [tool.ruff]