mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
add HF equivalency tests for standalone nbs (#774)
* add HF equivalency tests for standalone nbs * update * update * update * update
This commit is contained in:
committed by
GitHub
parent
a6b883c9f9
commit
80d4732456
@@ -116,7 +116,7 @@ QWEN3_CONFIG_30B_A3B = {
|
||||
"dtype": torch.bfloat16,
|
||||
"num_experts": 128,
|
||||
"num_experts_per_tok": 8,
|
||||
"moe_intermediate_size": 768,
|
||||
"moe_intermediate_size": 768,
|
||||
}
|
||||
|
||||
|
||||
|
||||
124
pkg/llms_from_scratch/utils.py
Normal file
124
pkg/llms_from_scratch/utils.py
Normal file
@@ -0,0 +1,124 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
# Internal utility functions (not intended for public use)
|
||||
|
||||
import ast
|
||||
import re
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import nbformat
|
||||
|
||||
|
||||
def _extract_imports(src: str):
|
||||
out = []
|
||||
try:
|
||||
tree = ast.parse(src)
|
||||
except SyntaxError:
|
||||
return out
|
||||
for node in tree.body:
|
||||
if isinstance(node, ast.Import):
|
||||
parts = []
|
||||
for n in node.names:
|
||||
parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
|
||||
out.append("import " + ", ".join(parts))
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
module = node.module or ""
|
||||
parts = []
|
||||
for n in node.names:
|
||||
parts.append(f"{n.name} as {n.asname}" if n.asname else n.name)
|
||||
level = "." * node.level if getattr(node, "level", 0) else ""
|
||||
out.append(f"from {level}{module} import " + ", ".join(parts))
|
||||
return out
|
||||
|
||||
|
||||
def _extract_defs_and_classes_from_code(src: str) -> str:
    """Strip *src* down to its def/class blocks (plus attached decorators).

    Everything else — top-level statements, expressions, plain assignments —
    is dropped, so a notebook cell can be executed for its definitions only.
    Finally, any ``def load_weights_into_xxx(SomeName, ...`` signature has its
    first parameter renamed to ``model``.
    """
    lines = src.splitlines()
    kept = []
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.lstrip()
        if stripped.startswith("@"):
            # Keep a decorator line only when the next non-blank line opens a
            # def/class; stray decorators on other statements are dropped.
            # NOTE(review): a decorator whose call spans multiple lines would
            # lose its continuation lines here — assumed not to occur in the
            # notebooks this processes.
            j = i + 1
            while j < len(lines) and not lines[j].strip():
                j += 1
            if j < len(lines) and lines[j].lstrip().startswith(("def ", "class ")):
                kept.append(line)
            i += 1
            continue
        if stripped.startswith("def ") or stripped.startswith("class "):
            kept.append(line)
            # Indent of the def/class header; body lines must be deeper.
            base_indent = len(line) - len(stripped)
            i += 1
            # Consume the body: blank lines are kept verbatim, deeper-indented
            # lines belong to the body, and comment/decorator lines at or above
            # base indent are kept too (so a decorator of the *next* def is not
            # lost). Any other line at base indent ends the block.
            while i < len(lines):
                nxt = lines[i]
                if nxt.strip() == "":
                    kept.append(nxt)
                    i += 1
                    continue
                indent = len(nxt) - len(nxt.lstrip())
                if indent <= base_indent and not nxt.lstrip().startswith(("#", "@")):
                    break
                kept.append(nxt)
                i += 1
            continue
        # Non-definition top-level line: skip it.
        i += 1

    code = "\n".join(kept)

    # General rule:
    # replace functions defined like `def load_weights_into_xxx(ClassName, ...`
    # with `def load_weights_into_xxx(model, ...`
    code = re.sub(
        r"(def\s+load_weights_into_\w+\s*\()\s*\w+\s*,",
        r"\1model,",
        code
    )
    return code
|
||||
|
||||
|
||||
def import_definitions_from_notebook(nb_dir_or_path, notebook_name=None, *, extra_globals=None):
    """Execute a notebook's imports and def/class blocks inside a fresh module.

    The notebook's code cells are reduced to (a) their unique import
    statements and (b) their def/class definitions, then exec'd into a new
    synthetic module so tests can use notebook code like a regular import.

    Parameters:
        nb_dir_or_path: Directory containing the notebook (combined with
            *notebook_name*) or the full path to the ``.ipynb`` file.
        notebook_name: Notebook filename; only consulted when
            *nb_dir_or_path* is a directory.
        extra_globals: Optional mapping of names injected into the module
            namespace before execution.

    Returns:
        A ``types.ModuleType`` exposing the notebook's definitions.

    Raises:
        FileNotFoundError: If the resolved path is not an existing file.
    """
    nb_path = Path(nb_dir_or_path)
    if notebook_name is not None:
        nb_file = nb_path / notebook_name if nb_path.is_dir() else nb_path
    else:
        nb_file = nb_path

    # is_file() (rather than exists()) also rejects directories, which would
    # previously pass this check and fail later inside nbformat with a
    # confusing error.
    if not nb_file.is_file():
        raise FileNotFoundError(f"Notebook not found: {nb_file}")

    nb = nbformat.read(nb_file, as_version=4)

    # Collect unique import statements across all code cells, in first-seen order.
    import_lines = []
    seen = set()
    for cell in nb.cells:
        if cell.cell_type == "code":
            for line in _extract_imports(cell.source):
                if line not in seen:
                    import_lines.append(line)
                    seen.add(line)

    # Always make torch available, even if the notebook never imported it
    # at the top level — the extracted definitions are assumed to rely on it.
    for required in ("import torch", "import torch.nn as nn"):
        if required not in seen:
            import_lines.append(required)
            seen.add(required)

    # Keep only def/class blocks from each code cell.
    pieces = []
    for cell in nb.cells:
        if cell.cell_type == "code":
            pieces.append(_extract_defs_and_classes_from_code(cell.source))

    src = "\n\n".join(import_lines + pieces)

    # Synthesize a module to hold the executed definitions; derive a valid
    # module name from the notebook filename.
    mod_name = nb_file.stem.replace("-", "_").replace(" ", "_") or "notebook_defs"
    mod = types.ModuleType(mod_name)

    if extra_globals:
        mod.__dict__.update(extra_globals)

    # NOTE(security): exec runs arbitrary notebook code — only use on trusted,
    # repo-local notebooks.
    exec(src, mod.__dict__)
    return mod
|
||||
Reference in New Issue
Block a user