From bc6f3355267767b1342cb3b0c70a2f20fc520606 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sat, 22 Nov 2025 22:42:18 -0600 Subject: [PATCH] Olmo 3 from scratch (#914) * Olmo 3 from scratch * update * update * update --- .github/workflows/basic-tests-linux-uv.yml | 2 + .gitignore | 10 + README.md | 24 +- .../standalone-qwen3-moe-plus-kvcache.ipynb | 2 +- .../standalone-qwen3-plus-kvcache.ipynb | 2 +- ch05/11_qwen3/standalone-qwen3.ipynb | 2 +- .../standalone-gemma3-plus-kvcache.ipynb | 115 +- ch05/12_gemma3/standalone-gemma3.ipynb | 52 +- .../standalone-olmo3-plus-kv-cache.ipynb | 1290 +++++++++++++++++ ch05/13_olmo3/standalone-olmo3.ipynb | 1183 +++++++++++++++ ch05/13_olmo3/tests/olmo3_layer_debugger.py | 240 +++ ch05/13_olmo3/tests/test_olmo3_kvcache_nb.py | 142 ++ ch05/13_olmo3/tests/test_olmo3_nb.py | 142 ++ ch05/README.md | 15 +- 14 files changed, 3163 insertions(+), 58 deletions(-) create mode 100644 ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb create mode 100644 ch05/13_olmo3/standalone-olmo3.ipynb create mode 100644 ch05/13_olmo3/tests/olmo3_layer_debugger.py create mode 100644 ch05/13_olmo3/tests/test_olmo3_kvcache_nb.py create mode 100644 ch05/13_olmo3/tests/test_olmo3_nb.py diff --git a/.github/workflows/basic-tests-linux-uv.yml b/.github/workflows/basic-tests-linux-uv.yml index f050e04..1ef96b2 100644 --- a/.github/workflows/basic-tests-linux-uv.yml +++ b/.github/workflows/basic-tests-linux-uv.yml @@ -57,6 +57,8 @@ jobs: pytest ch05/11_qwen3/tests/test_qwen3_nb.py pytest ch05/12_gemma3/tests/test_gemma3_nb.py pytest ch05/12_gemma3/tests/test_gemma3_kv_nb.py + pytest ch05/13_olmo3/tests/test_olmo3_nb.py + pytest ch05/13_olmo3/tests/test_olmo3_kvcache_nb.py pytest ch06/01_main-chapter-code/tests.py - name: Validate Selected Jupyter Notebooks (uv) diff --git a/.gitignore b/.gitignore index dc41278..8cdcb67 100644 --- a/.gitignore +++ b/.gitignore @@ -70,6 +70,16 @@ ch05/11_qwen3/Qwen3-8B ch05/11_qwen3/Qwen3-8B-Base ch05/11_qwen3/Qwen3-32B ch05/11_qwen3/Qwen3-32B-Base +ch05/12_gemma3/gemma-3-270M-it +ch05/12_gemma3/gemma-3-270M +ch05/13_olmo3/Olmo-3-1025-7B +ch05/13_olmo3/Olmo-3-1125-32B +ch05/13_olmo3/Olmo-3-7B-Instruct +ch05/13_olmo3/Olmo-3-32B-Instruct +ch05/13_olmo3/Olmo-3-7B-Think +ch05/13_olmo3/Olmo-3-32B-Think +ch05/13_olmo3/Olmo-3-7B-RLZero-IF +ch05/13_olmo3/Olmo-3-32B-RLZero-IF ch06/01_main-chapter-code/gpt2 ch06/02_bonus_additional-experiments/gpt2 diff --git a/README.md b/README.md index db1ddff..66c1873 100644 --- a/README.md +++ b/README.md @@ -179,19 +179,19 @@ Several folders contain optional materials as a bonus for interested readers: - [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning) - [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface) - [Converting GPT to Llama](ch05/07_gpt_to_llama) - - [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb) - - [Qwen3 Dense and Mixture-of-Experts (MoE) From Scratch](ch05/11_qwen3/) - - [Gemma 3 From Scratch](ch05/12_gemma3/) - - [Memory-Efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb) - - [Extending the Tiktoken BPE Tokenizer With New Tokens](ch05/09_extending-tokenizers/extend-tiktoken.ipynb) + - [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb) + - [Extending the Tiktoken BPE Tokenizer with New Tokens](ch05/09_extending-tokenizers/extend-tiktoken.ipynb) - [PyTorch Performance Tips for Faster LLM Training](ch05/10_llm-training-speed) - -- **Chapter 6: Finetuning for Classification** - - [Additional Experiments Finetuning Different Layers and Using Larger Models](ch06/02_bonus_additional-experiments) - - [Finetuning Different Models on 50k IMDb Movie Review Dataset](ch06/03_bonus_imdb-classification) - - [Building a User Interface to Interact With the GPT-Based Spam Classifier](ch06/04_user_interface) - -- **Chapter 7: Finetuning to Follow Instructions** + - [LLM Architectures](ch05/#llm-architectures-from-scratch) + - [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb) + - [Qwen3 Dense and Mixture-of-Experts (MoE) From Scratch](ch05/11_qwen3/) + - [Gemma 3 From Scratch](ch05/12_gemma3/) + - [Olmo 3 From Scratch](ch05/13_olmo3/) +- **Chapter 6: Finetuning for classification** + - [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments) + - [Finetuning different models on 50k IMDb movie review dataset](ch06/03_bonus_imdb-classification) + - [Building a User Interface to Interact With the GPT-based Spam Classifier](ch06/04_user_interface) +- **Chapter 7: Finetuning to follow instructions** - [Dataset Utilities for Finding Near Duplicates and Creating Passive Voice Entries](ch07/02_dataset-utilities) - [Evaluating Instruction Responses Using the OpenAI API and Ollama](ch07/03_model-evaluation) - [Generating a Dataset for Instruction Finetuning](ch07/05_dataset-generation/llama3-ollama.ipynb) diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb index 56e29e7..871a085 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb @@ -1223,7 +1223,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb index b79af45..bbf86ad 100644 --- a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb @@ -1253,7 +1253,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/ch05/11_qwen3/standalone-qwen3.ipynb b/ch05/11_qwen3/standalone-qwen3.ipynb index 156854f..80c0e3f 100644 --- a/ch05/11_qwen3/standalone-qwen3.ipynb +++ b/ch05/11_qwen3/standalone-qwen3.ipynb @@ -1179,7 +1179,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb b/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb index a496dc0..a90783e 100644 --- a/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb +++ b/ch05/12_gemma3/standalone-gemma3-plus-kvcache.ipynb @@ -78,9 +78,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "huggingface_hub version: 0.34.4\n", - "tokenizers version: 0.21.4\n", - "torch version: 2.8.0\n" + "huggingface_hub version: 0.35.0\n", + "tokenizers version: 0.22.1\n", + "torch version: 2.9.0+cu130\n" ] } ], @@ -700,9 +700,9 @@ { "data": { "text/plain": [ - "tensor([[[ 0.7500, 0.1060, 0.4844, ..., 0.9414, 0.3984, -0.2324],\n", - " [-0.3438, -0.0549, 0.8984, ..., -0.2402, 0.4570, 0.8242],\n", - " [-0.2676, -0.3281, 0.4121, ..., 0.8711, -0.9648, 0.9844]]],\n", + "tensor([[[ 0.7500, 0.1011, 0.4863, ..., 0.9414, 0.3984, -0.2285],\n", + " [-0.3398, -0.0564, 0.9023, ..., -0.2480, 0.4551, 0.8203],\n", + " [-0.2695, -0.3242, 0.4121, ..., 0.8672, -0.9688, 0.9844]]],\n", " dtype=torch.bfloat16, grad_fn=)" ] }, @@ -806,7 +806,20 @@ "metadata": { "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n", + " Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n", + " Minimum and Maximum cuda capability supported by this version of PyTorch is\n", + " (8.0) - (12.0)\n", + " \n", + " warnings.warn(\n" + ] + } + ], "source": [ "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", @@ -1038,6 +1051,20 @@ "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" }, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3396c08eab3f4cf980023483b969a337", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model.safetensors: 0%| | 0.00/536M [00:00\")[-1]\n", "):\n", " token_id = token.squeeze(0).tolist()\n", @@ -1248,7 +1304,13 @@ " tokenizer.decode(token_id),\n", " end=\"\",\n", " flush=True\n", - " )" + " )\n", + "\n", + "if torch.cuda.is_available():\n", + " def gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + " \n", + " print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")" ] }, { @@ -1269,7 +1331,6 @@ "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" }, "source": [ - "- Check out the [README.md](./README.md), to use this model via the `llms_from_scratch` package\n", "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", "\n", "" @@ -1297,7 +1358,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/ch05/12_gemma3/standalone-gemma3.ipynb b/ch05/12_gemma3/standalone-gemma3.ipynb index 5c45e20..6e5d870 100644 --- a/ch05/12_gemma3/standalone-gemma3.ipynb +++ b/ch05/12_gemma3/standalone-gemma3.ipynb @@ -41,7 +41,6 @@ "source": [ "- This notebook is purposefully minimal and focuses on the code to re-implement Gemma 3 270M in pure PyTorch without relying on other external LLM libraries\n", "- For more information, see the official [Gemma 3 270M model card](https://huggingface.co/google/gemma-3-270m)\n", - "\n", "- Below is a side-by-side comparison with Qwen3 0.6B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)\n", "
\n", "\n", @@ -78,9 +77,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "huggingface_hub version: 0.34.4\n", - "tokenizers version: 0.21.4\n", - "torch version: 2.8.0\n" + "huggingface_hub version: 0.35.0\n", + "tokenizers version: 0.22.1\n", + "torch version: 2.9.0+cu130\n" ] } ], @@ -628,9 +627,9 @@ { "data": { "text/plain": [ - "tensor([[[ 0.7500, 0.1060, 0.4844, ..., 0.9414, 0.3984, -0.2324],\n", - " [-0.3438, -0.0549, 0.8984, ..., -0.2402, 0.4570, 0.8242],\n", - " [-0.2676, -0.3281, 0.4121, ..., 0.8711, -0.9648, 0.9844]]],\n", + "tensor([[[ 0.7500, 0.1011, 0.4863, ..., 0.9414, 0.3984, -0.2285],\n", + " [-0.3398, -0.0564, 0.9023, ..., -0.2480, 0.4551, 0.8203],\n", + " [-0.2695, -0.3242, 0.4121, ..., 0.8672, -0.9688, 0.9844]]],\n", " dtype=torch.bfloat16, grad_fn=)" ] }, @@ -731,7 +730,20 @@ "metadata": { "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n", + " Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n", + " Minimum and Maximum cuda capability supported by this version of PyTorch is\n", + " (8.0) - (12.0)\n", + " \n", + " warnings.warn(\n" + ] + } + ], "source": [ "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", @@ -1095,7 +1107,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": { "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" @@ -1121,7 +1133,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": { "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" @@ -1131,7 +1143,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within language, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n" + "Large language models (LLMs) are sophisticated artificial intelligence systems that can understand, generate, and manipulate human language. They are trained on massive amounts of text data to learn patterns and relationships within that data, enabling them to perform a wide range of tasks, from writing articles and answering questions to translating languages and summarizing information.\n", + "\n", + "\n", + "GPU memory used: 1.04 GB\n" ] } ], @@ -1139,6 +1154,10 @@ "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", "\n", "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "\n", "for token in generate_text_basic_stream(\n", " model=model,\n", " token_ids=input_token_ids_tensor,\n", @@ -1150,7 +1169,13 @@ " tokenizer.decode(token_id),\n", " end=\"\",\n", " flush=True\n", - " )" + " )\n", + "\n", + "if torch.cuda.is_available():\n", + " def gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + " \n", + " print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")" ] }, { @@ -1171,7 +1196,6 @@ "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" }, "source": [ - "- Check out the [README.md](./README.md), to use this model via the `llms_from_scratch` package\n", "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", "\n", "" @@ -1199,7 +1223,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb b/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb new file mode 100644 index 0000000..3296fef --- /dev/null +++ b/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb @@ -0,0 +1,1290 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", + "metadata": { + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", + "metadata": { + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52" + }, + "source": [ + "# Olmo 3 From Scratch (A Standalone Notebook)" + ] + }, + { + "cell_type": "markdown", + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", + "metadata": { + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d" + }, + "source": [ + "- This notebook is purposefully minimal and focuses on the code to re-implement Olmo 3 7B and 32 models from Allen AI in pure PyTorch without relying on other external LLM libraries; Olmo 3 is interesting because it is currently the leading fully open-source model\n", + "- For more information, see the official [Olmo 3 announcement](https://allenai.org/blog/olmo3) and model cards:\n", + " - [Olmo-3-1025-7B](https://huggingface.co/allenai/Olmo-3-1025-7B) (base model)\n", + " - [Olmo-3-7B-Instruct](https://huggingface.co/allenai/Olmo-3-7B-Instruct)\n", + " - [Olmo-3-7B-Think](https://huggingface.co/allenai/Olmo-3-7B-Think)\n", + "- Note that there are also 32B versions, which are not listed above for brevity; you can find a complete list [here](https://huggingface.co/collections/allenai/olmo-3-post-training)\n", + "- Below is a side-by-side comparison with Qwen3 8B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)\n", + "
\n", + "\n", + "\n", + " \n", + "\n", + " \n", + " \n", + "- About the code:\n", + " - all code is my own code, mapping the Olmo 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7c201adb-747e-437b-9a62-442802941e01", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface_hub version: 0.35.0\n", + "tokenizers version: 0.22.1\n", + "torch version: 2.9.1+cu130\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tokenizers\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "07e96fbb-8e16-4f6d-835f-c6159321280b", + "metadata": {}, + "source": [ + "- Note that there are three model types, and each of the four model types comes in a 7B and 32B size:\n", + "1. Base (`Olmo-3-1025-7B` and `Olmo-3-1125-32B`)\n", + "2. Instruct (`Olmo-3-7B/32B-Think`)\n", + "3. Reasoning (`Olmo-3-32B/7B-Think`)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "70a90338-624a-4706-aa55-6b4358070194", + "metadata": {}, + "outputs": [], + "source": [ + "# Select which model to use\n", + "\n", + "# USE_MODEL = \"Olmo-3-1025-7B\"\n", + "# USE_MODEL = \"Olmo-3-1125-32B\"\n", + "USE_MODEL = \"Olmo-3-7B-Instruct\"\n", + "# USE_MODEL = \"Olmo-3-32B-Instruct\"\n", + "# USE_MODEL = \"Olmo-3-7B-Think\"\n", + "# USE_MODEL = \"Olmo-3-32B-Think\"\n", + "# USE_MODEL = \"Olmo-3-7B-RLZero-IF\"" + ] + }, + { + "cell_type": "markdown", + "id": "1899ab4b-e1c2-4215-b3d1-ed00d52e4576", + "metadata": {}, + "source": [ + "- In addition to the checkpoints listed above, you can also use the intermediate checkpoints listed [here](https://huggingface.co/collections/allenai/olmo-3-post-training); since they all have the same architecture, they are all compatible with this notebook" + ] + }, + { + "cell_type": "markdown", + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", + "metadata": { + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d" + }, + "source": [ + " \n", + "# 1. Architecture code" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "82076c21-9331-4dcd-b017-42b046cf1a60", + "metadata": { + "id": "82076c21-9331-4dcd-b017-42b046cf1a60" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = nn.functional.silu(x_fc1) * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "56715760-37e1-433e-89da-04864c139a9e", + "metadata": {}, + "outputs": [], + "source": [ + "class RMSNorm(nn.Module):\n", + " def __init__(self, emb_dim, eps=1e-6):\n", + " super().__init__()\n", + " self.eps = eps\n", + " self.weight = nn.Parameter(torch.ones(emb_dim))\n", + "\n", + " def forward(self, x):\n", + " input_dtype = x.dtype\n", + " x_f = x.float()\n", + " var = x_f.pow(2).mean(dim=-1, keepdim=True)\n", + " x_norm = x_f * torch.rsqrt(var + self.eps)\n", + " return (self.weight * x_norm).to(input_dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4b9a346f-5826-4083-9162-abd56afc03f0", + "metadata": { + "id": "4b9a346f-5826-4083-9162-abd56afc03f0" + }, + "outputs": [], + "source": [ + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", + "\n", + " # Compute the inverse frequencies\n", + " inv_freq = 1.0 / (\n", + " theta_base ** (\n", + " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", + " / head_dim\n", + " )\n", + " )\n", + "\n", + " # Generate position indices\n", + " positions = torch.arange(context_length, dtype=dtype)\n", + "\n", + " # Optional YaRN scaling\n", + " if rope_type == \"yarn\":\n", + " positions = positions / rope_factor\n", + " positions = torch.clamp(positions, max=rope_orig_max - 1)\n", + "\n", + " # Compute the base angles (shape: [context_length, head_dim // 2])\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", + "\n", + " # Expand to full head_dim (shape: [context_length, head_dim])\n", + " angles = torch.cat([angles, angles], dim=1)\n", + "\n", + " # Precompute sine and cosine\n", + " cos = torch.cos(angles) * attention_factor\n", + " sin = torch.sin(angles) * attention_factor\n", + "\n", + " return cos, sin\n", + "\n", + "\n", + "def apply_rope(x, cos, sin, offset=0):\n", + " # x: (batch_size, num_heads, seq_len, head_dim)\n", + " batch_size, num_heads, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " # Split x into first half and second half\n", + " x1 = x[..., : head_dim // 2] # First half\n", + " x2 = x[..., head_dim // 2 :] # Second half\n", + "\n", + " # Adjust sin and cos shapes\n", + " cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", + " sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " # Apply the rotary transformation\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x * cos) + (rotated * sin)\n", + "\n", + " # It's ok to use lower-precision after applying cos and sin rotation\n", + " return x_rotated.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", + "metadata": { + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" + }, + "outputs": [], + "source": [ + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(self, d_in, num_heads, num_kv_groups, head_dim, attention_bias=False, dtype=None, sliding_window=None, attn_type=\"full_attention\"):\n", + " super().__init__()\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.num_kv_groups = num_kv_groups\n", + " self.group_size = num_heads // num_kv_groups\n", + "\n", + " self.head_dim = head_dim\n", + " self.d_out = num_heads * head_dim\n", + " self.attn_type = attn_type\n", + " self.sliding_window = sliding_window if attn_type == \"sliding_attention\" else None\n", + "\n", + " # Projections\n", + " self.W_query = nn.Linear(d_in, self.d_out, bias=attention_bias, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n", + " self.out_proj = nn.Linear(self.d_out, d_in, bias=attention_bias, dtype=dtype)\n", + "\n", + " # Olmo3-style RMSNorm over the flattened projections\n", + " self.q_norm = RMSNorm(self.d_out)\n", + " self.k_norm = RMSNorm(num_kv_groups * head_dim)\n", + "\n", + " def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n", + " b, num_tokens, _ = x.shape\n", + "\n", + " # Apply projections\n", + " queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n", + " keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n", + " values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n", + "\n", + " # Normalize q and k\n", + " queries = self.q_norm(queries)\n", + " keys_new = self.k_norm(keys)\n", + "\n", + " # Reshape to (b, heads, seq, head_dim)\n", + " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n", + " keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + " values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + "\n", + " # Cache unrotated K/V\n", + " prev_len = 0\n", + " if cache is not None:\n", + " prev_k, prev_v = cache\n", + " if prev_k is not None:\n", + " prev_len = prev_k.size(2)\n", + " keys_cat_raw = torch.cat([prev_k, keys_new], dim=2)\n", + " values_cat_raw = torch.cat([prev_v, values_new], dim=2)\n", + " else:\n", + " keys_cat_raw = keys_new\n", + " values_cat_raw = values_new\n", + " else:\n", + " keys_cat_raw = keys_new\n", + " values_cat_raw = values_new\n", + "\n", + " # Apply RoPE with offsets for cached tokens\n", + " queries = apply_rope(queries, cos, sin, offset=start_pos)\n", + " keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - prev_len)\n", + "\n", + " # Expand KV groups to full head count\n", + " if self.group_size > 1:\n", + " keys = keys.repeat_interleave(self.group_size, dim=1)\n", + " values = values_cat_raw.repeat_interleave(self.group_size, dim=1)\n", + " else:\n", + " values = values_cat_raw\n", + "\n", + " # Scaling before the matmul seems to be a bit more stable for Olmo\n", + " scale = self.head_dim ** -0.5 # Python float\n", + " queries = queries * scale\n", + "\n", + " # Update cache with unrotated K/V\n", + " if cache is not None and cache[0] is not None:\n", + " next_cache = (\n", + " torch.cat([cache[0], keys_new], dim=2),\n", + " torch.cat([cache[1], values_new], dim=2),\n", + " )\n", + " else:\n", + " next_cache = (keys_new, values_new)\n", + "\n", + " # Attention\n", + " attn_scores = queries @ keys.transpose(2, 3)\n", + " if mask is not None:\n", + " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", + "\n", + " attn_weights = torch.softmax(attn_scores, dim=-1)\n", + " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n", + " out = self.out_proj(context)\n", + "\n", + " return out, next_cache" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "13eb3430-0c06-4fe2-a005-217205eee21e", + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg, attn_type):\n", + " super().__init__()\n", + " self.attn_type = attn_type\n", + " self.sliding_window = cfg[\"sliding_window\"]\n", + " self.att = GroupedQueryAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " num_kv_groups=cfg[\"n_kv_heads\"],\n", + " head_dim=cfg[\"head_dim\"],\n", + " attention_bias=cfg[\"attention_bias\"],\n", + " dtype=cfg[\"dtype\"],\n", + " sliding_window=cfg[\"sliding_window\"],\n", + " attn_type=attn_type,\n", + " )\n", + " self.ff = FeedForward(cfg)\n", + " self.post_attention_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n", + " self.post_feedforward_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n", + "\n", + " def forward(self, x, mask_global, mask_local, cos, sin, start_pos=0, cache=None):\n", + " shortcut = x\n", + " if self.attn_type == \"sliding_attention\":\n", + " if cache is not None and isinstance(cache, tuple):\n", + " prev_k, _ = cache\n", + " prev_len = prev_k.size(2) if prev_k is not None else 0\n", + " else:\n", + " prev_len = 0\n", + " eff_kv_len = prev_len + x.size(1)\n", + " attn_mask = mask_local[..., -eff_kv_len:]\n", + " else:\n", + " attn_mask = mask_global\n", + "\n", + " x_attn, next_cache = self.att(x, attn_mask, cos, sin, start_pos=start_pos, cache=cache)\n", + " if next_cache is not None and self.attn_type == \"sliding_attention\":\n", + " k, v = next_cache\n", + " if k.size(2) > self.sliding_window:\n", + " k = k[:, :, -self.sliding_window:, :]\n", + " v = v[:, :, -self.sliding_window:, :]\n", + " next_cache = (k, v)\n", + "\n", + " x_attn = self.post_attention_layernorm(x_attn)\n", + " x = shortcut + x_attn\n", + "\n", + " shortcut = x\n", + " x_ffn = self.ff(x)\n", + " x_ffn = self.post_feedforward_layernorm(x_ffn)\n", + " x = shortcut + x_ffn\n", + " return x, next_cache" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", + "metadata": { + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" + }, + "outputs": [], + "source": [ + "class Olmo3Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " assert cfg[\"layer_types\"] is not None and len(cfg[\"layer_types\"]) == cfg[\"n_layers\"]\n", + "\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + " self.blocks = nn.ModuleList([TransformerBlock(cfg, attn_type) for attn_type in cfg[\"layer_types\"]])\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.cfg = cfg\n", + " self.current_pos = 0\n", + "\n", + " cos, sin = compute_rope_params(\n", + " head_dim=cfg[\"head_dim\"],\n", + " context_length=cfg[\"context_length\"],\n", + " theta_base=cfg[\"rope_base\"],\n", + " attention_factor=cfg[\"rope_attention_factor\"],\n", + " rope_type=cfg[\"rope_type\"],\n", + " rope_factor=cfg[\"rope_factor\"],\n", + " rope_orig_max=cfg[\"rope_orig_max\"],\n", + " dtype=torch.float32,\n", + " )\n", + " self.register_buffer(\"cos\", cos, persistent=False)\n", + " self.register_buffer(\"sin\", sin, persistent=False)\n", + "\n", + " def create_masks(self, cur_len, device, pos_start=0, pos_end=None):\n", + " if pos_end is None:\n", + " pos_end = cur_len\n", + " total_len = pos_end\n", + "\n", + " ones = torch.ones((total_len, total_len), dtype=torch.bool, device=device)\n", + " # mask_global_full (future is masked: j > i)\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 1 1 1 1 1 1 1\n", + " # 1: 0 0 1 1 1 1 1 1\n", + " # 2: 0 0 0 1 1 1 1 1\n", + " # 3: 0 0 0 0 1 1 1 1\n", + " # 4: 0 0 0 0 0 1 1 1\n", + " # 5: 0 0 0 0 0 0 1 1\n", + " # 6: 0 0 0 0 0 0 0 1\n", + " # 7: 0 0 0 0 0 0 0 0\n", + " mask_global_full = torch.triu(ones, diagonal=1)\n", + "\n", + " # far_past (too far back is masked: i - j >= sliding_window)\n", + " # where sliding_window = 4\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 0 0 0 0 0 0 0\n", + " # 1: 0 0 0 0 0 0 0 0\n", + " # 2: 0 0 0 0 0 0 0 0\n", + " # 3: 0 0 0 0 0 0 0 0\n", + " # 4: 1 0 0 0 0 0 0 0\n", + " # 5: 1 1 0 0 0 0 0 0\n", + " # 6: 1 1 1 0 0 0 0 0\n", + " # 7: 1 1 1 1 0 0 0 0\n", + " far_past_full = torch.triu(ones, diagonal=self.cfg[\"sliding_window\"]).T\n", + "\n", + " # Local (sliding_window) = future OR far-past\n", + " # mask_local\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 1 1 1 1 1 1 1\n", + " # 1: 0 0 1 1 1 1 1 1\n", + " # 2: 0 0 0 1 1 1 1 1\n", + " # 3: 0 0 0 0 1 1 1 1\n", + " # 4: 1 0 0 0 0 1 1 1\n", + " # 5: 1 1 0 0 0 0 1 1\n", + " # 6: 1 1 1 0 0 0 0 1\n", + " # 7: 1 1 1 1 0 0 0 0\n", + " mask_local_full = mask_global_full | far_past_full\n", + "\n", + " row_slice = slice(pos_start, pos_end)\n", + " mask_global = mask_global_full[row_slice, :pos_end][None, None, :, :]\n", + " mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]\n", + " return mask_global, mask_local\n", + "\n", + " def forward(self, input_ids, cache=None):\n", + " b, seq_len = input_ids.shape\n", + " x = self.tok_emb(input_ids)\n", + "\n", + " if cache is not None:\n", + " pos_start = self.current_pos\n", + " pos_end = pos_start + seq_len\n", + " self.current_pos = pos_end\n", + " mask_global, mask_local = self.create_masks(\n", + " cur_len=seq_len, device=x.device, pos_start=pos_start, pos_end=pos_end\n", + " )\n", + " else:\n", + " pos_start = 0\n", + " mask_global, mask_local = self.create_masks(\n", + " cur_len=seq_len, device=x.device, pos_start=0, pos_end=seq_len\n", + " )\n", + "\n", + " cos = self.cos\n", + " sin = self.sin\n", + "\n", + " for i, block in enumerate(self.blocks):\n", + " blk_cache = cache.get(i) if cache is not None else None\n", + " x, new_blk_cache = block(\n", + " x,\n", + " mask_global=mask_global,\n", + " mask_local=mask_local,\n", + " cos=cos,\n", + " sin=sin,\n", + " start_pos=pos_start,\n", + " cache=blk_cache,\n", + " )\n", + "\n", + " if cache is not None:\n", + " cache.update(i, new_blk_cache)\n", + "\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", + " return logits\n", + "\n", + " def reset_kv_cache(self):\n", + " self.current_pos = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "4f5271e8-ff28-4aaa-bbb2-f73582e6d228", + "metadata": {}, + "outputs": [], + "source": [ + "class KVCache:\n", + " def __init__(self, n_layers):\n", + " self.cache = [None] * n_layers\n", + "\n", + " def get(self, layer_idx):\n", + " return self.cache[layer_idx]\n", + "\n", + " def update(self, layer_idx, value):\n", + " self.cache[layer_idx] = value\n", + "\n", + " def get_all(self):\n", + " return self.cache\n", + "\n", + " def reset(self):\n", + " for i in range(len(self.cache)):\n", + " self.cache[i] = None" + ] + }, + { + "cell_type": "markdown", + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", + "metadata": { + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48" + }, + "source": [ + " \n", + "# 2. Initialize model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "caa142fa-b375-4e78-b392-2072ced666f3", + "metadata": { + "id": "caa142fa-b375-4e78-b392-2072ced666f3" + }, + "outputs": [], + "source": [ + "OLMO3_CONFIG_7B = {\n", + " \"vocab_size\": 100_278,\n", + " \"context_length\": 65_536,\n", + " \"emb_dim\": 4_096,\n", + " \"n_heads\": 32,\n", + " \"n_layers\": 32,\n", + " \"hidden_dim\": 11_008,\n", + " \"head_dim\": 128,\n", + " \"n_kv_heads\": 32,\n", + " \"attention_bias\": False,\n", + " \"attention_dropout\": 0.0,\n", + " \"sliding_window\": 4_096,\n", + " \"layer_types\": [\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " ],\n", + " \"rope_base\": 500_000.0,\n", + " \"rope_attention_factor\": 1.2079441541679836,\n", + " \"rope_type\": \"yarn\",\n", + " \"rope_factor\": 8.0,\n", + " \"rope_orig_max\": 8_192,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"dtype\": torch.bfloat16,\n", + " \"eos_token_id\": 100_257,\n", + " \"pad_token_id\": 100_277,\n", + "}\n", + "\n", + "OLMO3_CONFIG_32B = {\n", + " \"vocab_size\": 100_278,\n", + " \"context_length\": 65_536,\n", + " \"emb_dim\": 5_120,\n", + " \"n_heads\": 40,\n", + " \"n_layers\": 64,\n", + " \"hidden_dim\": 27_648,\n", + " \"head_dim\": 128,\n", + " \"n_kv_heads\": 8,\n", + " \"attention_bias\": False,\n", + " \"attention_dropout\": 0.0,\n", + " \"sliding_window\": 4_096,\n", + " \"layer_types\": [\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " ],\n", + " \"rope_base\": 500_000.0,\n", + " \"rope_attention_factor\": 1.2079441541679836,\n", + " \"rope_type\": \"yarn\",\n", + " \"rope_factor\": 8.0,\n", + " \"rope_orig_max\": 8_192,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"dtype\": torch.bfloat16,\n", + " \"eos_token_id\": 100_257,\n", + " \"pad_token_id\": 100_277,\n", + "}\n", + "\n", + "OLMO3_CONFIG = OLMO3_CONFIG_32B if \"32B\" in USE_MODEL else OLMO3_CONFIG_7B" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", + "metadata": { + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" + }, + "outputs": [], + "source": [ + "torch.manual_seed(123)\n", + "model = Olmo3Model(OLMO3_CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "eaf86265-4e9d-4024-9ed0-99076944e304", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Olmo3Model(\n", + " (tok_emb): Embedding(100278, 4096)\n", + " (blocks): ModuleList(\n", + " (0-31): 32 x TransformerBlock(\n", + " (att): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (W_value): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc3): Linear(in_features=11008, out_features=4096, bias=False)\n", + " )\n", + " (post_attention_layernorm): RMSNorm()\n", + " (post_feedforward_layernorm): RMSNorm()\n", + " )\n", + " )\n", + " (final_norm): RMSNorm()\n", + " (out_head): Linear(in_features=4096, out_features=100278, bias=False)\n", + ")" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b", + "metadata": {}, + "source": [ + "- A quick check that the forward pass works before continuing:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n", + " [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n", + " [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n", + " dtype=torch.bfloat16, grad_fn=)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(torch.tensor([1, 2, 3]).unsqueeze(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", + "metadata": { + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n", + " Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n", + " Minimum and Maximum cuda capability supported by this version of PyTorch is\n", + " (8.0) - (12.0)\n", + " \n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "c172f89f-d301-439f-b809-46169e5f5945", + "metadata": { + "id": "c172f89f-d301-439f-b809-46169e5f5945" + }, + "source": [ + " \n", + "# 4. Load pretrained weights" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "75166128-5899-4995-9b88-9672e135650e", + "metadata": { + "id": "75166128-5899-4995-9b88-9672e135650e" + }, + "outputs": [], + "source": [ + "def load_weights_into_olmo(model, param_config, params):\n", + " def assign(left, right, tensor_name=\"unknown\"):\n", + " if left.shape != right.shape:\n", + " raise ValueError(\n", + " f\"Shape mismatch in tensor '{tensor_name}'. \"\n", + " f\"Left: {left.shape}, Right: {right.shape}\"\n", + " )\n", + " \n", + " with torch.no_grad():\n", + " if isinstance(right, torch.Tensor):\n", + " left.copy_(right)\n", + " else:\n", + " left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n", + " \n", + " return left\n", + "\n", + " # Token embedding\n", + " if \"model.embed_tokens.weight\" in params:\n", + " model.tok_emb.weight = assign(\n", + " model.tok_emb.weight,\n", + " params[\"model.embed_tokens.weight\"],\n", + " \"model.embed_tokens.weight\",\n", + " )\n", + "\n", + " for l in range(param_config[\"n_layers\"]):\n", + " block = model.blocks[l]\n", + " att = block.att\n", + "\n", + " # Q, K, V projections\n", + " att.W_query.weight = assign(\n", + " att.W_query.weight,\n", + " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_proj.weight\",\n", + " )\n", + " att.W_key.weight = assign(\n", + " att.W_key.weight,\n", + " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_proj.weight\",\n", + " )\n", + " att.W_value.weight = assign(\n", + " att.W_value.weight,\n", + " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.v_proj.weight\",\n", + " )\n", + "\n", + " # Output projection\n", + " att.out_proj.weight = assign(\n", + " att.out_proj.weight,\n", + " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.o_proj.weight\",\n", + " )\n", + "\n", + " # QK norms\n", + " att.q_norm.weight = assign(\n", + " att.q_norm.weight,\n", + " params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_norm.weight\",\n", + " )\n", + " att.k_norm.weight = assign(\n", + " att.k_norm.weight,\n", + " params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_norm.weight\",\n", + " )\n", + "\n", + " # Feedforward weights\n", + " block.ff.fc1.weight = assign(\n", + " block.ff.fc1.weight,\n", + " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.gate_proj.weight\",\n", + " )\n", + " block.ff.fc2.weight = assign(\n", + " block.ff.fc2.weight,\n", + " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.up_proj.weight\",\n", + " )\n", + " block.ff.fc3.weight = assign(\n", + " block.ff.fc3.weight,\n", + " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.down_proj.weight\",\n", + " )\n", + "\n", + " # Post-attention and post norms\n", + " block.post_attention_layernorm.weight = assign(\n", + " block.post_attention_layernorm.weight,\n", + " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n", + " f\"model.layers.{l}.post_attention_layernorm.weight\",\n", + " )\n", + " block.post_feedforward_layernorm.weight = assign(\n", + " block.post_feedforward_layernorm.weight,\n", + " params[f\"model.layers.{l}.post_feedforward_layernorm.weight\"],\n", + " f\"model.layers.{l}.post_feedforward_layernorm.weight\",\n", + " )\n", + "\n", + " # Final normalization and output head\n", + " if \"model.norm.weight\" in params:\n", + " model.final_norm.weight = assign(\n", + " model.final_norm.weight,\n", + " params[\"model.norm.weight\"],\n", + " \"model.norm.weight\",\n", + " )\n", + "\n", + " if \"lm_head.weight\" in params:\n", + " model.out_head.weight = assign(\n", + " model.out_head.weight,\n", + " params[\"lm_head.weight\"],\n", + " \"lm_head.weight\",\n", + " )\n", + " else:\n", + " model.out_head.weight = model.tok_emb.weight\n", + " print(\"Model uses weight tying.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "9881b6995c3f49dc89e6992fd9ab660b", + "17a3174e65c54476b2e0d1faf8f011ca", + "1bbf2e62c0754d1593beb4105a7f1ac1", + "b82112e1dec645d98aa1c1ba64abcb61", + "271e2bd6a35e4a8b92de8697f7c0be5f", + "90a79523187446dfa692723b2e5833a7", + "431ffb83b8c14bf182f0430e07ea6154", + "a8f1b72a33dd4b548de23fbd95e0da18", + "25cc36132d384189acfbecc59483134b", + "bfd06423ad544218968648016e731a46", + "d029630b63ff44cf807ade428d2eb421" + ] + }, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0fcdf72bf5b646d39bf4ed84faeb3302", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 14 files: 0%| | 0/14 [00:00\")\n", + " or self._tok.token_to_id(\"\")\n", + " )\n", + " self.eos_token_id = eos_from_tok if eos_from_tok is not None else eos_token_id\n", + " pad_from_tok = (\n", + " self._tok.token_to_id(\"<|pad|>\")\n", + " or self._tok.token_to_id(\"\")\n", + " )\n", + " self.pad_token_id = pad_from_tok if pad_from_tok is not None else pad_token_id\n", + "\n", + " def encode(self, text):\n", + " return self._tok.encode(text).ids\n", + "\n", + " def decode(self, ids):\n", + " return self._tok.decode(ids, skip_special_tokens=False)\n", + "\n", + "\n", + "def apply_chat_template(user_text):\n", + " return (\n", + " \"<|im_start|>user\\n\"\n", + " f\"{user_text}\\n\"\n", + " \"<|im_end|>\\n\"\n", + " \"<|im_start|>assistant\\n\"\n", + " )\n", + "\n", + "\n", + "tokenizer_file_path = os.path.join(local_dir, \"tokenizer.json\")\n", + "if not os.path.exists(tokenizer_file_path):\n", + " try:\n", + " tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename=\"tokenizer.json\", local_dir=local_dir)\n", + " except Exception as e:\n", + " print(f\"Warning: failed to download tokenizer.json: {e}\")\n", + " tokenizer_file_path = \"tokenizer.json\"\n", + "\n", + "tokenizer = OlmoTokenizer(\n", + " tokenizer_file_path=tokenizer_file_path,\n", + " eos_token_id=OLMO3_CONFIG[\"eos_token_id\"],\n", + " pad_token_id=OLMO3_CONFIG[\"pad_token_id\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'<|im_start|>user\\nGive me a short intro to large language models in 3 sentences.\\n<|im_end|>\\n<|im_start|>assistant\\n'" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = apply_chat_template(\"Give me a short intro to large language models in 3 sentences.\")\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "text = tokenizer.decode(input_token_ids)\n", + "text" + ] + }, + { + "cell_type": "markdown", + "id": "57d07df1-4401-4792-b549-7c4cc5632323", + "metadata": { + "id": "57d07df1-4401-4792-b549-7c4cc5632323" + }, + "source": [ + " \n", + "# 5. Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", + "metadata": { + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" + }, + "outputs": [], + "source": [ + "def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):\n", + "\n", + " model.eval()\n", + " with torch.no_grad():\n", + " cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n", + " model.reset_kv_cache()\n", + "\n", + " logits = model(token_ids, cache=cache)\n", + "\n", + " for _ in range(max_new_tokens):\n", + " next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n", + "\n", + " if (eos_token_id is not None\n", + " and torch.all(next_token == eos_token_id)):\n", + " break\n", + "\n", + " yield next_token\n", + "\n", + " token_ids = torch.cat([token_ids, next_token], dim=1)\n", + "\n", + " logits = model(next_token, cache=cache)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", + "metadata": { + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sure! Here’s a brief introduction to large language models: \n", + "Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating text, learning from vast amounts of data, learning language, performing diverse tasks, assisting in many applications, and adapting various tasks.\n", + "\n", + "GPU memory used: 13.71 GB\n" + ] + } + ], + "source": [ + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=500,\n", + " eos_token_id=tokenizer.eos_token_id\n", + "):\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )\n", + "\n", + "if torch.cuda.is_available():\n", + " def gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + " \n", + " print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")" + ] + }, + { + "cell_type": "markdown", + "id": "549324d6-5c71-4147-ae21-2e67675faa3d", + "metadata": { + "id": "549324d6-5c71-4147-ae21-2e67675faa3d" + }, + "source": [ + " \n", + "# What's next?" + ] + }, + { + "cell_type": "markdown", + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", + "metadata": { + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" + }, + "source": [ + "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/13_olmo3/standalone-olmo3.ipynb b/ch05/13_olmo3/standalone-olmo3.ipynb new file mode 100644 index 0000000..2c7f42f --- /dev/null +++ b/ch05/13_olmo3/standalone-olmo3.ipynb @@ -0,0 +1,1183 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", + "metadata": { + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", + "metadata": { + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52" + }, + "source": [ + "# Olmo 3 From Scratch (A Standalone Notebook)" + ] + }, + { + "cell_type": "markdown", + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", + "metadata": { + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d" + }, + "source": [ + "- This notebook is purposefully minimal and focuses on the code to re-implement Olmo 3 7B and 32 models from Allen AI in pure PyTorch without relying on other external LLM libraries; Olmo 3 is interesting because it is currently the leading fully open-source model\n", + "- For more information, see the official [Olmo 3 announcement](https://allenai.org/blog/olmo3) and model cards:\n", + " - [Olmo-3-1025-7B](https://huggingface.co/allenai/Olmo-3-1025-7B) (base model)\n", + " - [Olmo-3-7B-Instruct](https://huggingface.co/allenai/Olmo-3-7B-Instruct)\n", + " - [Olmo-3-7B-Think](https://huggingface.co/allenai/Olmo-3-7B-Think)\n", + "- Note that there are also 32B versions, which are not listed above for brevity; you can find a complete list [here](https://huggingface.co/collections/allenai/olmo-3-post-training)\n", + "- Below is a side-by-side comparison with Qwen3 8B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)\n", + "
\n", + "\n", + "\n", + " \n", + "\n", + " \n", + " \n", + "- About the code:\n", + " - all code is my own code, mapping the Olmo 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7c201adb-747e-437b-9a62-442802941e01", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface_hub version: 0.35.0\n", + "tokenizers version: 0.22.1\n", + "torch version: 2.9.1+cu130\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tokenizers\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "07e96fbb-8e16-4f6d-835f-c6159321280b", + "metadata": {}, + "source": [ + "- Note that there are three model types, and each of the four model types comes in a 7B and 32B size:\n", + "1. Base (`Olmo-3-1025-7B` and `Olmo-3-1125-32B`)\n", + "2. Instruct (`Olmo-3-7B/32B-Think`)\n", + "3. Reasoning (`Olmo-3-32B/7B-Think`)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "70a90338-624a-4706-aa55-6b4358070194", + "metadata": {}, + "outputs": [], + "source": [ + "# Select which model to use\n", + "\n", + "# USE_MODEL = \"Olmo-3-1025-7B\"\n", + "# USE_MODEL = \"Olmo-3-1125-32B\"\n", + "USE_MODEL = \"Olmo-3-7B-Instruct\"\n", + "# USE_MODEL = \"Olmo-3-32B-Instruct\"\n", + "# USE_MODEL = \"Olmo-3-7B-Think\"\n", + "# USE_MODEL = \"Olmo-3-32B-Think\"\n", + "# USE_MODEL = \"Olmo-3-7B-RLZero-IF\"" + ] + }, + { + "cell_type": "markdown", + "id": "f258cb74-1c4e-4880-8772-3c85fb920811", + "metadata": {}, + "source": [ + "- In addition to the checkpoints listed above, you can also use the intermediate checkpoints listed [here](https://huggingface.co/collections/allenai/olmo-3-post-training); since they all have the same architecture, they are all compatible with this notebook" + ] + }, + { + "cell_type": "markdown", + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", + "metadata": { + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d" + }, + "source": [ + " \n", + "# 1. Architecture code" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "82076c21-9331-4dcd-b017-42b046cf1a60", + "metadata": { + "id": "82076c21-9331-4dcd-b017-42b046cf1a60" + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = nn.functional.silu(x_fc1) * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "56715760-37e1-433e-89da-04864c139a9e", + "metadata": {}, + "outputs": [], + "source": [ + "class RMSNorm(nn.Module):\n", + " def __init__(self, emb_dim, eps=1e-6):\n", + " super().__init__()\n", + " self.eps = eps\n", + " self.weight = nn.Parameter(torch.ones(emb_dim))\n", + "\n", + " def forward(self, x):\n", + " input_dtype = x.dtype\n", + " x_f = x.float()\n", + " var = x_f.pow(2).mean(dim=-1, keepdim=True)\n", + " x_norm = x_f * torch.rsqrt(var + self.eps)\n", + " return (self.weight * x_norm).to(input_dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4b9a346f-5826-4083-9162-abd56afc03f0", + "metadata": { + "id": "4b9a346f-5826-4083-9162-abd56afc03f0" + }, + "outputs": [], + "source": [ + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", + "\n", + " # Compute the inverse frequencies\n", + " inv_freq = 1.0 / (\n", + " theta_base ** (\n", + " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", + " / head_dim\n", + " )\n", + " )\n", + "\n", + " # Generate position indices\n", + " positions = torch.arange(context_length, dtype=dtype)\n", + "\n", + " # Optional YaRN scaling\n", + " if rope_type == \"yarn\":\n", + " positions = positions / rope_factor\n", + " positions = torch.clamp(positions, max=rope_orig_max - 1)\n", + "\n", + " # Compute the base angles (shape: [context_length, head_dim // 2])\n", + " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", + "\n", + " # Expand to full head_dim (shape: [context_length, head_dim])\n", + " angles = torch.cat([angles, angles], dim=1)\n", + "\n", + " # Precompute sine and cosine\n", + " cos = torch.cos(angles) * attention_factor\n", + " sin = torch.sin(angles) * attention_factor\n", + "\n", + " return cos, sin\n", + "\n", + "\n", + "def apply_rope(x, cos, sin):\n", + " # x: (batch_size, num_heads, seq_len, head_dim)\n", + " batch_size, num_heads, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " # Split x into first half and second half\n", + " x1 = x[..., : head_dim // 2] # First half\n", + " x2 = x[..., head_dim // 2 :] # Second half\n", + "\n", + " # Adjust sin and cos shapes\n", + " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", + " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " # Apply the rotary transformation\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x * cos) + (rotated * sin)\n", + "\n", + " # It's ok to use lower-precision after applying cos and sin rotation\n", + " return x_rotated.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", + "metadata": { + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" + }, + "outputs": [], + "source": [ + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(self, d_in, num_heads, num_kv_groups, head_dim, attention_bias=False, dtype=None, sliding_window=None, attn_type=\"full_attention\"):\n", + " super().__init__()\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.num_kv_groups = num_kv_groups\n", + " self.group_size = num_heads // num_kv_groups\n", + "\n", + " self.head_dim = head_dim\n", + " self.d_out = num_heads * head_dim\n", + " self.attn_type = attn_type\n", + " self.sliding_window = sliding_window if attn_type == \"sliding_attention\" else None\n", + "\n", + " # Projections\n", + " self.W_query = nn.Linear(d_in, self.d_out, bias=attention_bias, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n", + " self.out_proj = nn.Linear(self.d_out, d_in, bias=attention_bias, dtype=dtype)\n", + "\n", + " # Olmo3-style RMSNorm over the flattened projections\n", + " self.q_norm = RMSNorm(self.d_out)\n", + " self.k_norm = RMSNorm(num_kv_groups * head_dim)\n", + "\n", + " def forward(self, x, mask, cos, sin):\n", + " b, num_tokens, _ = x.shape\n", + "\n", + " # Apply projections\n", + " queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n", + " keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n", + " values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n", + "\n", + " # Normalize q and k\n", + " queries = self.q_norm(queries)\n", + " keys = self.k_norm(keys)\n", + "\n", + " # Reshape to (b, heads, seq, head_dim)\n", + " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n", + " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n", + "\n", + " # Apply RoPE\n", + " queries = apply_rope(queries, cos, sin)\n", + " keys = apply_rope(keys, cos, sin)\n", + "\n", + " # Expand KV groups to full head count\n", + " if self.group_size > 1:\n", + " keys = keys.repeat_interleave(self.group_size, dim=1)\n", + " values = values.repeat_interleave(self.group_size, dim=1)\n", + "\n", + " # Scaling before the matmul seems to be a bit more stable for Olmo\n", + " scale = self.head_dim ** -0.5 # Python float\n", + " queries = queries * scale\n", + " \n", + " # Attention\n", + " attn_scores = queries @ keys.transpose(2, 3)\n", + " if mask is not None:\n", + " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", + "\n", + " attn_weights = torch.softmax(attn_scores, dim=-1)\n", + " context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n", + " return self.out_proj(context)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", + "metadata": { + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" + }, + "outputs": [], + "source": [ + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg, attn_type):\n", + " super().__init__()\n", + " self.attn_type = attn_type\n", + " self.att = GroupedQueryAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " num_kv_groups=cfg[\"n_kv_heads\"],\n", + " head_dim=cfg[\"head_dim\"],\n", + " attention_bias=cfg[\"attention_bias\"],\n", + " dtype=cfg[\"dtype\"],\n", + " sliding_window=cfg[\"sliding_window\"],\n", + " attn_type=attn_type,\n", + " )\n", + " self.ff = FeedForward(cfg)\n", + " self.post_attention_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n", + " self.post_feedforward_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n", + "\n", + " def forward(self, x, mask_global, mask_local, cos, sin):\n", + " attn_mask = mask_local if self.attn_type == \"sliding_attention\" else mask_global\n", + "\n", + " shortcut = x\n", + " x_attn = self.att(x, attn_mask, cos, sin)\n", + " x_attn = self.post_attention_layernorm(x_attn)\n", + " x = shortcut + x_attn\n", + "\n", + " shortcut = x\n", + " x_ffn = self.ff(x)\n", + " x_ffn = self.post_feedforward_layernorm(x_ffn)\n", + " x = shortcut + x_ffn\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", + "metadata": { + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" + }, + "outputs": [], + "source": [ + "class Olmo3Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " assert cfg[\"layer_types\"] is not None and len(cfg[\"layer_types\"]) == cfg[\"n_layers\"]\n", + "\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + " self.blocks = nn.ModuleList([TransformerBlock(cfg, attn_type) for attn_type in cfg[\"layer_types\"]])\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + " self.cfg = cfg\n", + "\n", + " cos, sin = compute_rope_params(\n", + " head_dim=cfg[\"head_dim\"],\n", + " context_length=cfg[\"context_length\"],\n", + " theta_base=cfg[\"rope_base\"],\n", + " attention_factor=cfg[\"rope_attention_factor\"],\n", + " rope_type=cfg[\"rope_type\"],\n", + " rope_factor=cfg[\"rope_factor\"],\n", + " rope_orig_max=cfg[\"rope_orig_max\"],\n", + " dtype=torch.float32,\n", + " )\n", + " self.register_buffer(\"cos\", cos, persistent=False)\n", + " self.register_buffer(\"sin\", sin, persistent=False)\n", + "\n", + " def create_masks(self, seq_len, device):\n", + " ones = torch.ones((seq_len, seq_len), dtype=torch.bool, device=device)\n", + "\n", + " # mask_global (future is masked: j > i)\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 1 1 1 1 1 1 1\n", + " # 1: 0 0 1 1 1 1 1 1\n", + " # 2: 0 0 0 1 1 1 1 1\n", + " # 3: 0 0 0 0 1 1 1 1\n", + " # 4: 0 0 0 0 0 1 1 1\n", + " # 5: 0 0 0 0 0 0 1 1\n", + " # 6: 0 0 0 0 0 0 0 1\n", + " # 7: 0 0 0 0 0 0 0 0\n", + " mask_global = torch.triu(ones, diagonal=1)\n", + "\n", + " # far_past (too far back is masked: i - j >= sliding_window)\n", + " # where sliding_window = 4\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 0 0 0 0 0 0 0\n", + " # 1: 0 0 0 0 0 0 0 0\n", + " # 2: 0 0 0 0 0 0 0 0\n", + " # 3: 0 0 0 0 0 0 0 0\n", + " # 4: 1 0 0 0 0 0 0 0\n", + " # 5: 1 1 0 0 0 0 0 0\n", + " # 6: 1 1 1 0 0 0 0 0\n", + " # 7: 1 1 1 1 0 0 0 0\n", + " far_past = torch.triu(ones, diagonal=self.cfg[\"sliding_window\"]).T\n", + "\n", + " # Local (sliding_window) = future OR far-past\n", + " # mask_local\n", + " # j: 0 1 2 3 4 5 6 7\n", + " # i\n", + " # 0: 0 1 1 1 1 1 1 1\n", + " # 1: 0 0 1 1 1 1 1 1\n", + " # 2: 0 0 0 1 1 1 1 1\n", + " # 3: 0 0 0 0 1 1 1 1\n", + " # 4: 1 0 0 0 0 1 1 1\n", + " # 5: 1 1 0 0 0 0 1 1\n", + " # 6: 1 1 1 0 0 0 0 1\n", + " # 7: 1 1 1 1 0 0 0 0\n", + " mask_local = mask_global | far_past\n", + " return mask_global, mask_local\n", + "\n", + " def forward(self, input_ids):\n", + " b, seq_len = input_ids.shape\n", + " x = self.tok_emb(input_ids)\n", + " mask_global, mask_local = self.create_masks(seq_len, x.device)\n", + "\n", + " cos = self.cos[:seq_len, :].to(x.device)\n", + " sin = self.sin[:seq_len, :].to(x.device)\n", + "\n", + " for block in self.blocks:\n", + " x = block(x, mask_global, mask_local, cos, sin)\n", + "\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", + "metadata": { + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48" + }, + "source": [ + " \n", + "# 2. Initialize model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "caa142fa-b375-4e78-b392-2072ced666f3", + "metadata": { + "id": "caa142fa-b375-4e78-b392-2072ced666f3" + }, + "outputs": [], + "source": [ + "OLMO3_CONFIG_7B = {\n", + " \"vocab_size\": 100_278,\n", + " \"context_length\": 65_536,\n", + " \"emb_dim\": 4_096,\n", + " \"n_heads\": 32,\n", + " \"n_layers\": 32,\n", + " \"hidden_dim\": 11_008,\n", + " \"head_dim\": 128,\n", + " \"n_kv_heads\": 32,\n", + " \"attention_bias\": False,\n", + " \"attention_dropout\": 0.0,\n", + " \"sliding_window\": 4_096,\n", + " \"layer_types\": [\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " ],\n", + " \"rope_base\": 500_000.0,\n", + " \"rope_attention_factor\": 1.2079441541679836,\n", + " \"rope_type\": \"yarn\",\n", + " \"rope_factor\": 8.0,\n", + " \"rope_orig_max\": 8_192,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"dtype\": torch.bfloat16,\n", + " \"eos_token_id\": 100_257,\n", + " \"pad_token_id\": 100_277,\n", + "}\n", + "\n", + "OLMO3_CONFIG_32B = {\n", + " \"vocab_size\": 100_278,\n", + " \"context_length\": 65_536,\n", + " \"emb_dim\": 5_120,\n", + " \"n_heads\": 40,\n", + " \"n_layers\": 64,\n", + " \"hidden_dim\": 27_648,\n", + " \"head_dim\": 128,\n", + " \"n_kv_heads\": 8,\n", + " \"attention_bias\": False,\n", + " \"attention_dropout\": 0.0,\n", + " \"sliding_window\": 4_096,\n", + " \"layer_types\": [\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"sliding_attention\",\n", + " \"full_attention\",\n", + " ],\n", + " \"rope_base\": 500_000.0,\n", + " \"rope_attention_factor\": 1.2079441541679836,\n", + " \"rope_type\": \"yarn\",\n", + " \"rope_factor\": 8.0,\n", + " \"rope_orig_max\": 8_192,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"dtype\": torch.bfloat16,\n", + " \"eos_token_id\": 100_257,\n", + " \"pad_token_id\": 100_277,\n", + "}\n", + "\n", + "OLMO3_CONFIG = OLMO3_CONFIG_32B if \"32B\" in USE_MODEL else OLMO3_CONFIG_7B" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", + "metadata": { + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" + }, + "outputs": [], + "source": [ + "torch.manual_seed(123)\n", + "model = Olmo3Model(OLMO3_CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "eaf86265-4e9d-4024-9ed0-99076944e304", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Olmo3Model(\n", + " (tok_emb): Embedding(100278, 4096)\n", + " (blocks): ModuleList(\n", + " (0-31): 32 x TransformerBlock(\n", + " (att): GroupedQueryAttention(\n", + " (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (W_key): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (W_value): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n", + " (q_norm): RMSNorm()\n", + " (k_norm): RMSNorm()\n", + " )\n", + " (ff): FeedForward(\n", + " (fc1): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc2): Linear(in_features=4096, out_features=11008, bias=False)\n", + " (fc3): Linear(in_features=11008, out_features=4096, bias=False)\n", + " )\n", + " (post_attention_layernorm): RMSNorm()\n", + " (post_feedforward_layernorm): RMSNorm()\n", + " )\n", + " )\n", + " (final_norm): RMSNorm()\n", + " (out_head): Linear(in_features=4096, out_features=100278, bias=False)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "90aca91d-4bee-45ce-993a-4ec5393abe2b", + "metadata": {}, + "source": [ + "- A quick check that the forward pass works before continuing:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "adf0a6b7-b688-42c9-966e-c223d34db99f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n", + " [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n", + " [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n", + " dtype=torch.bfloat16, grad_fn=)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model(torch.tensor([1, 2, 3]).unsqueeze(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", + "metadata": { + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n", + " Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n", + " Minimum and Maximum cuda capability supported by this version of PyTorch is\n", + " (8.0) - (12.0)\n", + " \n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "c172f89f-d301-439f-b809-46169e5f5945", + "metadata": { + "id": "c172f89f-d301-439f-b809-46169e5f5945" + }, + "source": [ + " \n", + "# 4. Load pretrained weights" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "75166128-5899-4995-9b88-9672e135650e", + "metadata": { + "id": "75166128-5899-4995-9b88-9672e135650e" + }, + "outputs": [], + "source": [ + "def load_weights_into_olmo(model, param_config, params):\n", + " def assign(left, right, tensor_name=\"unknown\"):\n", + " if left.shape != right.shape:\n", + " raise ValueError(\n", + " f\"Shape mismatch in tensor '{tensor_name}'. \"\n", + " f\"Left: {left.shape}, Right: {right.shape}\"\n", + " )\n", + " \n", + " with torch.no_grad():\n", + " if isinstance(right, torch.Tensor):\n", + " left.copy_(right)\n", + " else:\n", + " left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n", + " \n", + " return left\n", + "\n", + " # Token embedding\n", + " if \"model.embed_tokens.weight\" in params:\n", + " model.tok_emb.weight = assign(\n", + " model.tok_emb.weight,\n", + " params[\"model.embed_tokens.weight\"],\n", + " \"model.embed_tokens.weight\",\n", + " )\n", + "\n", + " for l in range(param_config[\"n_layers\"]):\n", + " block = model.blocks[l]\n", + " att = block.att\n", + "\n", + " # Q, K, V projections\n", + " att.W_query.weight = assign(\n", + " att.W_query.weight,\n", + " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_proj.weight\",\n", + " )\n", + " att.W_key.weight = assign(\n", + " att.W_key.weight,\n", + " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_proj.weight\",\n", + " )\n", + " att.W_value.weight = assign(\n", + " att.W_value.weight,\n", + " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.v_proj.weight\",\n", + " )\n", + "\n", + " # Output projection\n", + " att.out_proj.weight = assign(\n", + " att.out_proj.weight,\n", + " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.o_proj.weight\",\n", + " )\n", + "\n", + " # QK norms\n", + " att.q_norm.weight = assign(\n", + " att.q_norm.weight,\n", + " params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_norm.weight\",\n", + " )\n", + " att.k_norm.weight = assign(\n", + " att.k_norm.weight,\n", + " params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_norm.weight\",\n", + " )\n", + "\n", + " # Feedforward weights\n", + " block.ff.fc1.weight = assign(\n", + " block.ff.fc1.weight,\n", + " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.gate_proj.weight\",\n", + " )\n", + " block.ff.fc2.weight = assign(\n", + " block.ff.fc2.weight,\n", + " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.up_proj.weight\",\n", + " )\n", + " block.ff.fc3.weight = assign(\n", + " block.ff.fc3.weight,\n", + " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.down_proj.weight\",\n", + " )\n", + "\n", + " # Post-attention and post norms\n", + " block.post_attention_layernorm.weight = assign(\n", + " block.post_attention_layernorm.weight,\n", + " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n", + " f\"model.layers.{l}.post_attention_layernorm.weight\",\n", + " )\n", + " block.post_feedforward_layernorm.weight = assign(\n", + " block.post_feedforward_layernorm.weight,\n", + " params[f\"model.layers.{l}.post_feedforward_layernorm.weight\"],\n", + " f\"model.layers.{l}.post_feedforward_layernorm.weight\",\n", + " )\n", + "\n", + " # Final normalization and output head\n", + " if \"model.norm.weight\" in params:\n", + " model.final_norm.weight = assign(\n", + " model.final_norm.weight,\n", + " params[\"model.norm.weight\"],\n", + " \"model.norm.weight\",\n", + " )\n", + "\n", + " if \"lm_head.weight\" in params:\n", + " model.out_head.weight = assign(\n", + " model.out_head.weight,\n", + " params[\"lm_head.weight\"],\n", + " \"lm_head.weight\",\n", + " )\n", + " else:\n", + " model.out_head.weight = model.tok_emb.weight\n", + " print(\"Model uses weight tying.\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17, + "referenced_widgets": [ + "9881b6995c3f49dc89e6992fd9ab660b", + "17a3174e65c54476b2e0d1faf8f011ca", + "1bbf2e62c0754d1593beb4105a7f1ac1", + "b82112e1dec645d98aa1c1ba64abcb61", + "271e2bd6a35e4a8b92de8697f7c0be5f", + "90a79523187446dfa692723b2e5833a7", + "431ffb83b8c14bf182f0430e07ea6154", + "a8f1b72a33dd4b548de23fbd95e0da18", + "25cc36132d384189acfbecc59483134b", + "bfd06423ad544218968648016e731a46", + "d029630b63ff44cf807ade428d2eb421" + ] + }, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ac81f5bc2063498b98e2c8956f0598be", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 14 files: 0%| | 0/14 [00:00\")\n", + " or self._tok.token_to_id(\"\")\n", + " )\n", + " self.eos_token_id = eos_from_tok if eos_from_tok is not None else eos_token_id\n", + " pad_from_tok = (\n", + " self._tok.token_to_id(\"<|pad|>\")\n", + " or self._tok.token_to_id(\"\")\n", + " )\n", + " self.pad_token_id = pad_from_tok if pad_from_tok is not None else pad_token_id\n", + "\n", + " def encode(self, text):\n", + " return self._tok.encode(text).ids\n", + "\n", + " def decode(self, ids):\n", + " return self._tok.decode(ids, skip_special_tokens=False)\n", + "\n", + "\n", + "def apply_chat_template(user_text):\n", + " return (\n", + " \"<|im_start|>user\\n\"\n", + " f\"{user_text}\\n\"\n", + " \"<|im_end|>\\n\"\n", + " \"<|im_start|>assistant\\n\"\n", + " )\n", + "\n", + "\n", + "tokenizer_file_path = os.path.join(local_dir, \"tokenizer.json\")\n", + "if not os.path.exists(tokenizer_file_path):\n", + " try:\n", + " tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename=\"tokenizer.json\", local_dir=local_dir)\n", + " except Exception as e:\n", + " print(f\"Warning: failed to download tokenizer.json: {e}\")\n", + " tokenizer_file_path = \"tokenizer.json\"\n", + "\n", + "tokenizer = OlmoTokenizer(\n", + " tokenizer_file_path=tokenizer_file_path,\n", + " eos_token_id=OLMO3_CONFIG[\"eos_token_id\"],\n", + " pad_token_id=OLMO3_CONFIG[\"pad_token_id\"],\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'<|im_start|>user\\nGive me a short intro to large language models in 3 sentences.\\n<|im_end|>\\n<|im_start|>assistant\\n'" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = apply_chat_template(\"Give me a short intro to large language models in 3 sentences.\")\n", + "\n", + "input_token_ids = tokenizer.encode(prompt)\n", + "text = tokenizer.decode(input_token_ids)\n", + "text" + ] + }, + { + "cell_type": "markdown", + "id": "57d07df1-4401-4792-b549-7c4cc5632323", + "metadata": { + "id": "57d07df1-4401-4792-b549-7c4cc5632323" + }, + "source": [ + " \n", + "# 5. Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", + "metadata": { + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" + }, + "outputs": [], + "source": [ + "def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None):\n", + "\n", + " model.eval()\n", + " with torch.no_grad():\n", + " for _ in range(max_new_tokens):\n", + " out = model(token_ids)[:, -1]\n", + " next_token = torch.argmax(out, dim=-1, keepdim=True)\n", + "\n", + " if (eos_token_id is not None\n", + " and torch.all(next_token == eos_token_id)):\n", + " break\n", + "\n", + " yield next_token\n", + " \n", + " token_ids = torch.cat([token_ids, next_token], dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", + "metadata": { + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sure! Here’s a brief introduction to large language models: \n", + "Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating human-like text, learning from vast data. \n", + "They excel at many tasks across many languages and adapt to various tasks. \n", + "They power modern applications widely used in NLP solutions.\n", + "\n", + "GPU memory used: 13.70 GB\n" + ] + } + ], + "source": [ + "input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n", + "\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.reset_peak_memory_stats()\n", + "\n", + "\n", + "for token in generate_text_basic_stream(\n", + " model=model,\n", + " token_ids=input_token_ids_tensor,\n", + " max_new_tokens=500,\n", + " eos_token_id=tokenizer.eos_token_id\n", + "):\n", + " token_id = token.squeeze(0).tolist()\n", + " print(\n", + " tokenizer.decode(token_id),\n", + " end=\"\",\n", + " flush=True\n", + " )\n", + "\n", + "if torch.cuda.is_available():\n", + " def gpu_gb(x):\n", + " return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n", + " \n", + " print(f\"\\n\\nGPU memory used: {gpu_gb(torch.cuda.max_memory_allocated())}\")" + ] + }, + { + "cell_type": "markdown", + "id": "549324d6-5c71-4147-ae21-2e67675faa3d", + "metadata": { + "id": "549324d6-5c71-4147-ae21-2e67675faa3d" + }, + "source": [ + " \n", + "# What's next?" + ] + }, + { + "cell_type": "markdown", + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", + "metadata": { + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" + }, + "source": [ + "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", + "\n", + "" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch05/13_olmo3/tests/olmo3_layer_debugger.py b/ch05/13_olmo3/tests/olmo3_layer_debugger.py new file mode 100644 index 0000000..fe58ba3 --- /dev/null +++ b/ch05/13_olmo3/tests/olmo3_layer_debugger.py @@ -0,0 +1,240 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +from pathlib import Path + +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + +try: + from transformers import Olmo3Config, Olmo3ForCausalLM +except ImportError: + Olmo3Config = None + Olmo3ForCausalLM = None + + +def tiny_debug_config(): + return { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_heads": 2, + "sliding_window": 4, + "layer_types": ["full_attention", "full_attention"], + "dtype": torch.float32, + "query_pre_attn_scalar": 256, + "attention_bias": False, + "rms_norm_eps": 1e-6, + "rope_base": 1_000_000.0, + "rope_attention_factor": 1.0, + "rope_type": "default", + "rope_factor": 1.0, + "rope_orig_max": 8, + "rope_local_base": 10_000.0, + } + + +def _hf_config_from_dict(cfg): + if Olmo3Config is None: + raise ImportError("transformers is required for the Olmo-3 debugger.") + + return Olmo3Config( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_heads"], + rope_theta=cfg["rope_base"], + rope_local_base_freq=cfg.get("rope_local_base", 10_000.0), + layer_types=cfg["layer_types"], + sliding_window=cfg["sliding_window"], + tie_word_embeddings=False, + attn_implementation="eager", + torch_dtype=cfg.get("dtype", torch.float32), + query_pre_attn_scalar=cfg.get("query_pre_attn_scalar", 256), + rope_scaling={"rope_type": cfg.get("rope_type", "default")}, + qk_norm=cfg.get("qk_norm", False), + rms_norm_eps=cfg.get("rms_norm_eps", 1e-5), + ) + + +def load_notebook_defs(nb_name="standalone-olmo3.ipynb"): + nb_dir = Path(__file__).resolve().parents[1] + return import_definitions_from_notebook(nb_dir, nb_name) + + +def build_olmo3_pair(nb_imports, cfg, hf_checkpoint=None): + if Olmo3ForCausalLM is None: + raise ImportError("transformers is required for the Olmo-3 debugger.") + + ours = nb_imports.Olmo3Model(cfg) + hf_cfg = _hf_config_from_dict(cfg) + + if hf_checkpoint: + hf_model = Olmo3ForCausalLM.from_pretrained( + hf_checkpoint, + torch_dtype=cfg.get("dtype", torch.float32), + attn_implementation="eager", + ) + else: + hf_model = Olmo3ForCausalLM(hf_cfg) + + param_config = {"n_layers": cfg["n_layers"], "hidden_dim": cfg["hidden_dim"]} + nb_imports.load_weights_into_olmo(ours, param_config, hf_model.state_dict()) + + ours.eval() + hf_model.eval() + return ours, hf_model + + +def _attach_debug_hooks(model, is_hf): + traces = {} + handles = [] + + def hook(name): + def _record(_, __, output): + traces[name] = output.detach().to(torch.float32).cpu() + return _record + + if is_hf: + core = model.model + handles.append(core.embed_tokens.register_forward_hook(hook("embedding"))) + for idx, layer in enumerate(core.layers): + handles.append(layer.register_forward_hook(hook(f"block_{idx}"))) + handles.append(core.norm.register_forward_hook(hook("final_norm"))) + handles.append(model.lm_head.register_forward_hook(hook("logits"))) + else: + handles.append(model.tok_emb.register_forward_hook(hook("embedding"))) + for idx, block in enumerate(model.blocks): + handles.append(block.register_forward_hook(hook(f"block_{idx}"))) + handles.append(model.final_norm.register_forward_hook(hook("final_norm"))) + handles.append(model.out_head.register_forward_hook(hook("logits"))) + + return traces, handles + + +def _layer_sort_key(name): + if name == "embedding": + return (0, 0) + if name.startswith("block_"): + idx = int(name.split("_")[1]) + return (1, idx) + if name == "final_norm": + return (2, 0) + if name == "logits": + return (3, 0) + return (4, name) + + +def layerwise_differences(ours, hf_model, input_ids, rtol=1e-5, atol=1e-5): + ours_traces, ours_handles = _attach_debug_hooks(ours, is_hf=False) + hf_traces, hf_handles = _attach_debug_hooks(hf_model, is_hf=True) + + try: + with torch.inference_mode(): + ours(input_ids) + hf_model(input_ids) + finally: + for h in ours_handles + hf_handles: + h.remove() + + layer_names = sorted(set(ours_traces) | set(hf_traces), key=_layer_sort_key) + results = [] + for name in layer_names: + ours_tensor = ours_traces.get(name) + hf_tensor = hf_traces.get(name) + + if ours_tensor is None or hf_tensor is None: + results.append( + { + "name": name, + "status": "missing", + "ours_shape": None if ours_tensor is None else tuple(ours_tensor.shape), + "hf_shape": None if hf_tensor is None else tuple(hf_tensor.shape), + "max_diff": None, + "mean_abs_diff": None, + } + ) + continue + + shapes_match = ours_tensor.shape == hf_tensor.shape + if not shapes_match: + results.append( + { + "name": name, + "status": "shape_mismatch", + "ours_shape": tuple(ours_tensor.shape), + "hf_shape": tuple(hf_tensor.shape), + "max_diff": None, + "mean_abs_diff": None, + } + ) + continue + + diff = (ours_tensor - hf_tensor).abs() + max_diff = float(diff.max().item()) + mean_diff = float(diff.mean().item()) + allclose = torch.allclose(ours_tensor, hf_tensor, rtol=rtol, atol=atol) + results.append( + { + "name": name, + "status": "ok" if allclose else "mismatch", + "ours_shape": tuple(ours_tensor.shape), + "hf_shape": tuple(hf_tensor.shape), + "max_diff": max_diff, + "mean_abs_diff": mean_diff, + } + ) + return results + + +def first_mismatch(differences): + for diff in differences: + if diff["status"] != "ok": + return diff + return None + + +def format_report(differences): + lines = [] + for diff in sorted(differences, key=lambda d: _layer_sort_key(d["name"])): + if diff["status"] == "ok": + lines.append(f"[OK] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}") + elif diff["status"] == "mismatch": + lines.append( + f"[DIFF] {diff['name']}: max={diff['max_diff']:.2e}, mean={diff['mean_abs_diff']:.2e}" + ) + elif diff["status"] == "shape_mismatch": + lines.append( + f"[SHAPE] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}" + ) + else: + lines.append(f"[MISSING] {diff['name']}: ours={diff['ours_shape']}, hf={diff['hf_shape']}") + return "\n".join(lines) + + +if __name__ == "__main__": + transformers_available = importlib.util.find_spec("transformers") is not None + if not transformers_available: + raise SystemExit("transformers is not installed; install it to run the debugger.") + + nb_imports = load_notebook_defs() + cfg = tiny_debug_config() + + ours_model, hf_model = build_olmo3_pair(nb_imports, cfg) + torch.manual_seed(0) + input_ids = torch.randint(0, cfg["vocab_size"], (1, cfg["context_length"]), dtype=torch.long) + diffs = layerwise_differences(ours_model, hf_model, input_ids) + print(format_report(diffs)) diff --git a/ch05/13_olmo3/tests/test_olmo3_kvcache_nb.py b/ch05/13_olmo3/tests/test_olmo3_kvcache_nb.py new file mode 100644 index 0000000..5675e0e --- /dev/null +++ b/ch05/13_olmo3/tests/test_olmo3_kvcache_nb.py @@ -0,0 +1,142 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +from pathlib import Path + +import pytest +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +transformers_installed = importlib.util.find_spec("transformers") is not None + + +@pytest.fixture +def nb_imports(): + nb_dir = Path(__file__).resolve().parents[1] + mod = import_definitions_from_notebook(nb_dir, "standalone-olmo3-plus-kv-cache.ipynb") + return mod + + +@pytest.fixture +def dummy_input(): + torch.manual_seed(123) + return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8 + + +@pytest.fixture +def dummy_cfg_base(): + return { + "vocab_size": 100, + "context_length": 64, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "n_kv_heads": 1, # 4 query heads, 1 KV groups -> group_size = 4 + "attention_bias": False, + "attention_dropout": 0.0, + "sliding_window": 4, + "layer_types": ["full_attention"] * 2, + + # RoPE config + "rope_base": 10_000.0, + "rope_attention_factor": 1.0, + "rope_type": "default", + "rope_factor": 1.0, + "rope_orig_max": 64, + "rms_norm_eps": 1e-6, + "dtype": torch.float32, + } + +@torch.inference_mode() +def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports): + torch.manual_seed(123) + model = nb_imports.Olmo3Model(dummy_cfg_base) + out = model(dummy_input) + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \ + f"Expected shape (1, seq_len, vocab_size), got {out.shape}" + + +@torch.inference_mode() +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_olmo3_base_equivalence_with_transformers(nb_imports): + from transformers import Olmo3Config, Olmo3ForCausalLM + + # Tiny config so the test is fast + cfg = { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_heads": 2, + "sliding_window": 4, + "layer_types": ["full_attention", "full_attention"], + "dtype": torch.float32, + "query_pre_attn_scalar": 256, + + # required by TransformerBlock + "attention_bias": False, + + # required by RMSNorm and RoPE setup in Olmo3Model + "rms_norm_eps": 1e-6, + "rope_base": 1_000_000.0, + "rope_attention_factor": 1.0, + "rope_type": "default", + "rope_factor": 1.0, + "rope_orig_max": 8, + + # extra HF-only stuff + "rope_local_base": 10_000.0, + } + + model = nb_imports.Olmo3Model(cfg) + + hf_cfg = Olmo3Config( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_heads"], + rope_theta=cfg["rope_base"], + rope_local_base_freq=cfg["rope_local_base"], + layer_types=cfg["layer_types"], + sliding_window=cfg["sliding_window"], + tie_word_embeddings=False, + attn_implementation="eager", + torch_dtype=torch.float32, + query_pre_attn_scalar=cfg["query_pre_attn_scalar"], + rope_scaling={"rope_type": "default"}, + qk_norm=cfg["qk_norm"], + rms_norm_eps=cfg["rms_norm_eps"], + ) + hf_model = Olmo3ForCausalLM(hf_cfg) + + hf_state = hf_model.state_dict() + param_config = { + "n_layers": cfg["n_layers"], + "hidden_dim": cfg["hidden_dim"], + } + nb_imports.load_weights_into_olmo(model, param_config, hf_state) + + x = torch.randint( + 0, + cfg["vocab_size"], + (2, cfg["context_length"]), + dtype=torch.long, + ) + ours_logits = model(x) + theirs_logits = hf_model(x).logits + torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5) diff --git a/ch05/13_olmo3/tests/test_olmo3_nb.py b/ch05/13_olmo3/tests/test_olmo3_nb.py new file mode 100644 index 0000000..fa528dc --- /dev/null +++ b/ch05/13_olmo3/tests/test_olmo3_nb.py @@ -0,0 +1,142 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +import importlib +from pathlib import Path + +import pytest +import torch + +from llms_from_scratch.utils import import_definitions_from_notebook + + +transformers_installed = importlib.util.find_spec("transformers") is not None + + +@pytest.fixture +def nb_imports(): + nb_dir = Path(__file__).resolve().parents[1] + mod = import_definitions_from_notebook(nb_dir, "standalone-olmo3.ipynb") + return mod + + +@pytest.fixture +def dummy_input(): + torch.manual_seed(123) + return torch.randint(0, 100, (1, 8)) # batch size 1, seq length 8 + + +@pytest.fixture +def dummy_cfg_base(): + return { + "vocab_size": 100, + "context_length": 64, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "n_kv_heads": 1, # 4 query heads, 1 KV groups -> group_size = 4 + "attention_bias": False, + "attention_dropout": 0.0, + "sliding_window": 4, + "layer_types": ["full_attention"] * 2, + + # RoPE config + "rope_base": 10_000.0, + "rope_attention_factor": 1.0, + "rope_type": "default", + "rope_factor": 1.0, + "rope_orig_max": 64, + "rms_norm_eps": 1e-6, + "dtype": torch.float32, + } + +@torch.inference_mode() +def test_dummy_olmo3_forward(dummy_cfg_base, dummy_input, nb_imports): + torch.manual_seed(123) + model = nb_imports.Olmo3Model(dummy_cfg_base) + out = model(dummy_input) + assert out.shape == (1, dummy_input.size(1), dummy_cfg_base["vocab_size"]), \ + f"Expected shape (1, seq_len, vocab_size), got {out.shape}" + + +@torch.inference_mode() +@pytest.mark.skipif(not transformers_installed, reason="transformers not installed") +def test_olmo3_base_equivalence_with_transformers(nb_imports): + from transformers import Olmo3Config, Olmo3ForCausalLM + + # Tiny config so the test is fast + cfg = { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_heads": 2, + "sliding_window": 4, + "layer_types": ["full_attention", "full_attention"], + "dtype": torch.float32, + "query_pre_attn_scalar": 256, + + # required by TransformerBlock + "attention_bias": False, + + # required by RMSNorm and RoPE setup in Olmo3Model + "rms_norm_eps": 1e-6, + "rope_base": 1_000_000.0, + "rope_attention_factor": 1.0, + "rope_type": "default", + "rope_factor": 1.0, + "rope_orig_max": 8, + + # extra HF-only stuff + "rope_local_base": 10_000.0, + } + + model = nb_imports.Olmo3Model(cfg) + + hf_cfg = Olmo3Config( + vocab_size=cfg["vocab_size"], + max_position_embeddings=cfg["context_length"], + hidden_size=cfg["emb_dim"], + num_attention_heads=cfg["n_heads"], + num_hidden_layers=cfg["n_layers"], + intermediate_size=cfg["hidden_dim"], + head_dim=cfg["head_dim"], + num_key_value_heads=cfg["n_kv_heads"], + rope_theta=cfg["rope_base"], + rope_local_base_freq=cfg["rope_local_base"], + layer_types=cfg["layer_types"], + sliding_window=cfg["sliding_window"], + tie_word_embeddings=False, + attn_implementation="eager", + torch_dtype=torch.float32, + query_pre_attn_scalar=cfg["query_pre_attn_scalar"], + rope_scaling={"rope_type": "default"}, + qk_norm=cfg["qk_norm"], + rms_norm_eps=cfg["rms_norm_eps"], + ) + hf_model = Olmo3ForCausalLM(hf_cfg) + + hf_state = hf_model.state_dict() + param_config = { + "n_layers": cfg["n_layers"], + "hidden_dim": cfg["hidden_dim"], + } + nb_imports.load_weights_into_olmo(model, param_config, hf_state) + + x = torch.randint( + 0, + cfg["vocab_size"], + (2, cfg["context_length"]), + dtype=torch.long, + ) + ours_logits = model(x) + theirs_logits = hf_model(x).logits + torch.testing.assert_close(ours_logits, theirs_logits, rtol=1e-5, atol=1e-5) diff --git a/ch05/README.md b/ch05/README.md index a6cefbc..ed7516d 100644 --- a/ch05/README.md +++ b/ch05/README.md @@ -13,14 +13,25 @@ - [04_learning_rate_schedulers](04_learning_rate_schedulers) contains code implementing a more sophisticated training function including learning rate schedulers and gradient clipping - [05_bonus_hparam_tuning](05_bonus_hparam_tuning) contains an optional hyperparameter tuning script - [06_user_interface](06_user_interface) implements an interactive user interface to interact with the pretrained LLM -- [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI - [08_memory_efficient_weight_loading](08_memory_efficient_weight_loading) contains a bonus notebook showing how to load model weights via PyTorch's `load_state_dict` method more efficiently - [09_extending-tokenizers](09_extending-tokenizers) contains a from-scratch implementation of the GPT-2 BPE tokenizer - [10_llm-training-speed](10_llm-training-speed) shows PyTorch performance tips to improve the LLM training speed + +  +## LLM Architectures From Scratch + + + +  + + +- [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI - [11_qwen3](11_qwen3) A from-scratch implementation of Qwen3 0.6B and Qwen3 30B-A3B (Mixture-of-Experts) including code to load the pretrained weights of the base, reasoning, and coding model variants - [12_gemma3](12_gemma3) A from-scratch implementation of Gemma 3 270M and alternative with KV cache, including code to load the pretrained weights +- [13_olmo3](13_olmo3) A from-scratch implementation of Olmo 3 7B and 32B (Base, Instruct, and Think variants) and alternative with KV cache, including code to load the pretrained weights - +  +## Code-Along Video for This Chapter