mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Readability and code quality improvements (#959)
* Consistent dataset naming * consistent section headers
This commit is contained in:
committed by
GitHub
parent
7b1f740f74
commit
be5e2a3331
@@ -771,7 +771,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def model_memory_size(model, input_dtype=torch.float32):\n",
|
||||
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
|
||||
" total_params = 0\n",
|
||||
" total_grads = 0\n",
|
||||
" for param in model.parameters():\n",
|
||||
@@ -795,8 +795,8 @@
|
||||
"\n",
|
||||
" return total_memory_gb\n",
|
||||
"\n",
|
||||
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
|
||||
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
|
||||
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
|
||||
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1120,7 +1120,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"# 4. Load tokenizer"
|
||||
"# 3. Load tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1307,10 +1307,10 @@
|
||||
" )\n",
|
||||
"\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" def gpu_gb(x):\n",
|
||||
" def calc_gpu_gb(x):\n",
|
||||
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
|
||||
" \n",
|
||||
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
|
||||
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1358,7 +1358,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -695,7 +695,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def model_memory_size(model, input_dtype=torch.float32):\n",
|
||||
"def calc_model_memory_size(model, input_dtype=torch.float32):\n",
|
||||
" total_params = 0\n",
|
||||
" total_grads = 0\n",
|
||||
" for param in model.parameters():\n",
|
||||
@@ -719,8 +719,8 @@
|
||||
"\n",
|
||||
" return total_memory_gb\n",
|
||||
"\n",
|
||||
"print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
|
||||
"print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
|
||||
"print(f\"float32 (PyTorch default): {calc_model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
|
||||
"print(f\"bfloat16: {calc_model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1005,7 +1005,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
" \n",
|
||||
"# 4. Load tokenizer"
|
||||
"# 3. Load tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1172,10 +1172,10 @@
|
||||
" )\n",
|
||||
"\n",
|
||||
"if torch.cuda.is_available():\n",
|
||||
" def gpu_gb(x):\n",
|
||||
" def calc_gpu_gb(x):\n",
|
||||
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
|
||||
" \n",
|
||||
" print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")"
|
||||
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1223,7 +1223,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
"version": "3.13.5"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nb_imports():
|
||||
def import_notebook_defs():
|
||||
nb_dir = Path(__file__).resolve().parents[1]
|
||||
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3-plus-kvcache.ipynb")
|
||||
return mod
|
||||
@@ -50,16 +50,16 @@ def dummy_cfg_base():
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
|
||||
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
|
||||
torch.manual_seed(123)
|
||||
model = nb_imports.Gemma3Model(dummy_cfg_base)
|
||||
model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
|
||||
out = model(dummy_input)
|
||||
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
|
||||
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
|
||||
|
||||
# Tiny config so the test is fast
|
||||
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
"dtype": torch.float32,
|
||||
"query_pre_attn_scalar": 256,
|
||||
}
|
||||
model = nb_imports.Gemma3Model(cfg)
|
||||
model = import_notebook_defs.Gemma3Model(cfg)
|
||||
|
||||
hf_cfg = Gemma3TextConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
|
||||
hf_state = hf_model.state_dict()
|
||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
|
||||
import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)
|
||||
|
||||
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
||||
ours_logits = model(x)
|
||||
|
||||
@@ -16,7 +16,7 @@ transformers_installed = importlib.util.find_spec("transformers") is not None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nb_imports():
|
||||
def import_notebook_defs():
|
||||
nb_dir = Path(__file__).resolve().parents[1]
|
||||
mod = import_definitions_from_notebook(nb_dir, "standalone-gemma3.ipynb")
|
||||
return mod
|
||||
@@ -50,16 +50,16 @@ def dummy_cfg_base():
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, nb_imports):
|
||||
def test_dummy_gemma3_forward(dummy_cfg_base, dummy_input, import_notebook_defs):
|
||||
torch.manual_seed(123)
|
||||
model = nb_imports.Gemma3Model(dummy_cfg_base)
|
||||
model = import_notebook_defs.Gemma3Model(dummy_cfg_base)
|
||||
out = model(dummy_input)
|
||||
assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"])
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
def test_gemma3_base_equivalence_with_transformers(import_notebook_defs):
|
||||
from transformers import Gemma3TextConfig, Gemma3ForCausalLM
|
||||
|
||||
# Tiny config so the test is fast
|
||||
@@ -80,7 +80,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
"dtype": torch.float32,
|
||||
"query_pre_attn_scalar": 256,
|
||||
}
|
||||
model = nb_imports.Gemma3Model(cfg)
|
||||
model = import_notebook_defs.Gemma3Model(cfg)
|
||||
|
||||
hf_cfg = Gemma3TextConfig(
|
||||
vocab_size=cfg["vocab_size"],
|
||||
@@ -105,7 +105,7 @@ def test_gemma3_base_equivalence_with_transformers(nb_imports):
|
||||
|
||||
hf_state = hf_model.state_dict()
|
||||
param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]}
|
||||
nb_imports.load_weights_into_gemma(model, param_config, hf_state)
|
||||
import_notebook_defs.load_weights_into_gemma(model, param_config, hf_state)
|
||||
|
||||
x = torch.randint(0, cfg["vocab_size"], (2, cfg["context_length"]), dtype=torch.long)
|
||||
ours_logits = model(x)
|
||||
|
||||
Reference in New Issue
Block a user