{
"cells": [
{
"cell_type": "markdown",
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c",
"metadata": {
"id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52",
"metadata": {
"id": "efde77f2-6af3-4781-8597-89ecd3f41a52"
},
"source": [
"# Olmo 3 From Scratch (A Standalone Notebook)"
]
},
{
"cell_type": "markdown",
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d",
"metadata": {
"id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d"
},
"source": [
"- This notebook is purposefully minimal and focuses on the code to re-implement Olmo 3 7B and 32 models from Allen AI in pure PyTorch without relying on other external LLM libraries; Olmo 3 is interesting because it is currently the leading fully open-source model\n",
"- For more information, see the official [Olmo 3 announcement](https://allenai.org/blog/olmo3) and model cards:\n",
" - [Olmo-3-1025-7B](https://huggingface.co/allenai/Olmo-3-1025-7B) (base model)\n",
" - [Olmo-3-7B-Instruct](https://huggingface.co/allenai/Olmo-3-7B-Instruct)\n",
" - [Olmo-3-7B-Think](https://huggingface.co/allenai/Olmo-3-7B-Think)\n",
"- Note that there are also 32B versions, which are not listed above for brevity; you can find a complete list [here](https://huggingface.co/collections/allenai/olmo-3-post-training)\n",
"- Below is a side-by-side comparison with Qwen3 8B as a reference model; if you are interested in the Qwen3 0.6B standalone notebook, you can find it [here](../11_qwen3)\n",
"<br>\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/olmo3/olmo3.webp\">\n",
" \n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/olmo3/olmo3-pipeline.webp\">\n",
" \n",
" \n",
"- About the code:\n",
" - all code is my own code, mapping the Olmo 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7c201adb-747e-437b-9a62-442802941e01",
"metadata": {},
"outputs": [],
"source": [
"# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface_hub version: 0.35.0\n",
"tokenizers version: 0.22.1\n",
"torch version: 2.9.1+cu130\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
" \"huggingface_hub\", # to download pretrained weights\n",
" \"tokenizers\", # to implement the tokenizer\n",
" \"torch\", # to implement the model\n",
"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "07e96fbb-8e16-4f6d-835f-c6159321280b",
"metadata": {},
"source": [
"- Note that there are three model types, and each of the four model types comes in a 7B and 32B size:\n",
"1. Base (`Olmo-3-1025-7B` and `Olmo-3-1125-32B`)\n",
"2. Instruct (`Olmo-3-7B/32B-Think`)\n",
"3. Reasoning (`Olmo-3-32B/7B-Think`)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "70a90338-624a-4706-aa55-6b4358070194",
"metadata": {},
"outputs": [],
"source": [
"# Select which model to use\n",
"\n",
"# USE_MODEL = \"Olmo-3-1025-7B\"\n",
"# USE_MODEL = \"Olmo-3-1125-32B\"\n",
"USE_MODEL = \"Olmo-3-7B-Instruct\"\n",
"# USE_MODEL = \"Olmo-3-32B-Instruct\"\n",
"# USE_MODEL = \"Olmo-3-7B-Think\"\n",
"# USE_MODEL = \"Olmo-3-32B-Think\"\n",
"# USE_MODEL = \"Olmo-3-7B-RLZero-IF\""
]
},
{
"cell_type": "markdown",
"id": "1899ab4b-e1c2-4215-b3d1-ed00d52e4576",
"metadata": {},
"source": [
"- In addition to the checkpoints listed above, you can also use the intermediate checkpoints listed [here](https://huggingface.co/collections/allenai/olmo-3-post-training); since they all have the same architecture, they are all compatible with this notebook"
]
},
{
"cell_type": "markdown",
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d",
"metadata": {
"id": "653410a6-dd2b-4eb2-a722-23d9782e726d"
},
"source": [
" \n",
"# 1. Architecture code"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
"metadata": {
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
" self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
"\n",
" def forward(self, x):\n",
" x_fc1 = self.fc1(x)\n",
" x_fc2 = self.fc2(x)\n",
" x = nn.functional.silu(x_fc1) * x_fc2\n",
" return self.fc3(x)"
]
},
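{
"cell_type": "markdown",
"id": "3a4b5c6d-e7f8-4901-a112-234567890123",
"metadata": {},
"source": [
"- The `FeedForward` module above is the SwiGLU-style MLP that is also used in Llama- and Qwen-family models: $\\mathrm{FFN}(x) = W_3\\,\\big(\\mathrm{SiLU}(W_1 x) \\odot W_2 x\\big)$, where `fc1`, `fc2`, and `fc3` play the roles of $W_1$, $W_2$, and $W_3$"
]
},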
{
"cell_type": "code",
"execution_count": 5,
"id": "56715760-37e1-433e-89da-04864c139a9e",
"metadata": {},
"outputs": [],
"source": [
"class RMSNorm(nn.Module):\n",
" def __init__(self, emb_dim, eps=1e-6):\n",
" super().__init__()\n",
" self.eps = eps\n",
" self.weight = nn.Parameter(torch.ones(emb_dim))\n",
"\n",
" def forward(self, x):\n",
" input_dtype = x.dtype\n",
" x_f = x.float()\n",
" var = x_f.pow(2).mean(dim=-1, keepdim=True)\n",
" x_norm = x_f * torch.rsqrt(var + self.eps)\n",
" return (self.weight * x_norm).to(input_dtype)"
]
},
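{
"cell_type": "markdown",
"id": "4b5c6d7e-f809-4a12-b223-345678901234",
"metadata": {},
"source": [
"- `RMSNorm` rescales each vector by its root mean square, $\\mathrm{RMSNorm}(x) = w \\odot \\dfrac{x}{\\sqrt{\\tfrac{1}{d}\\sum_{i=1}^{d} x_i^2 + \\epsilon}}$; the statistics are computed in float32 and the result is cast back to the input dtype"
]
},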
{
"cell_type": "code",
"execution_count": 6,
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
"metadata": {
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"\n",
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, beta_fast=32.0, beta_slow=1.0, dtype=torch.float32):\n",
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" if rope_type == \"yarn\":\n",
" # Compute YaRN-style frequency scaling (as per https://huggingface.co/papers/2309.00071)\n",
"\n",
" def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n",
" \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n",
" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n",
"\n",
" def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n",
" \"\"\"Find dimension range bounds based on rotations\"\"\"\n",
" low = find_correction_dim(low_rot, dim, base, max_position_embeddings)\n",
" high = find_correction_dim(high_rot, dim, base, max_position_embeddings)\n",
" low = math.floor(low)\n",
" high = math.ceil(high)\n",
" return max(low, 0), min(high, dim - 1)\n",
"\n",
" def linear_ramp_factor(min_val, max_val, dim):\n",
" if min_val == max_val:\n",
" max_val += 0.001 # Prevent singularity\n",
" linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)\n",
" ramp_func = torch.clamp(linear_func, 0, 1)\n",
" return ramp_func\n",
"\n",
" # Base frequencies\n",
" pos_freqs = theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)\n",
" inv_freq_extrapolation = 1.0 / pos_freqs # No scaling (extrapolation)\n",
" inv_freq_interpolation = 1.0 / (rope_factor * pos_freqs) # With scaling (interpolation)\n",
"\n",
" # Find the range where we blend between interpolation and extrapolation\n",
" low, high = find_correction_range(beta_fast, beta_slow, head_dim, theta_base, rope_orig_max)\n",
"\n",
" # Get n-dimensional rotational scaling corrected for extrapolation\n",
" inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).to(dtype=dtype)\n",
" inv_freq = (\n",
" inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n",
" + inv_freq_extrapolation * inv_freq_extrapolation_factor\n",
" )\n",
" else:\n",
" # Default RoPE\n",
" inv_freq = 1.0 / (\n",
" theta_base ** (\n",
" torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n",
" / head_dim\n",
" )\n",
" )\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Compute the base angles (shape: [context_length, head_dim // 2])\n",
" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n",
"\n",
" # Expand to full head_dim (shape: [context_length, head_dim])\n",
" angles = torch.cat([angles, angles], dim=1)\n",
"\n",
" # Precompute sine and cosine\n",
" cos = torch.cos(angles) * attention_factor\n",
" sin = torch.sin(angles) * attention_factor\n",
"\n",
" return cos, sin\n",
"\n",
"\n",
"def apply_rope(x, cos, sin, offset=0):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
"\n",
" # Split x into first half and second half\n",
" x1 = x[..., : head_dim // 2] # First half\n",
" x2 = x[..., head_dim // 2 :] # Second half\n",
"\n",
" # Adjust sin and cos shapes\n",
" cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n",
" sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)\n",
"\n",
" # Apply the rotary transformation\n",
" rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n",
"\n",
" # It's ok to use lower-precision after applying cos and sin rotation\n",
" return x_rotated.to(dtype=x.dtype)"
]
},
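{
"cell_type": "markdown",
"id": "f1a2b3c4-90ab-4cde-8f01-23456789abcd",
"metadata": {},
"source": [
"- As an optional sanity check, the next cell builds RoPE tables with arbitrary toy settings (`head_dim=16`, `context_length=32`, chosen purely for illustration, not the Olmo 3 hyperparameters) and verifies the expected shapes and that the rotation preserves per-position vector norms when `attention_factor=1.0`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a1b2c3d-4e5f-4a6b-8c7d-9e0f1a2b3c4d",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check with toy settings (not the Olmo 3 hyperparameters)\n",
"test_cos, test_sin = compute_rope_params(head_dim=16, context_length=32)\n",
"print(test_cos.shape, test_sin.shape)  # torch.Size([32, 16]) each\n",
"\n",
"example = torch.randn(1, 2, 8, 16)  # (batch, heads, seq_len, head_dim)\n",
"rotated = apply_rope(example, test_cos, test_sin)\n",
"print(rotated.shape)  # unchanged shape\n",
"\n",
"# A pure rotation (attention_factor=1.0) preserves each position vector's norm\n",
"print(torch.allclose(example.norm(dim=-1), rotated.norm(dim=-1), atol=1e-5))"
]
},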
{
"cell_type": "code",
"execution_count": 7,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
},
"outputs": [],
"source": [
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(self, d_in, num_heads, num_kv_groups, head_dim, attention_bias=False, dtype=None, sliding_window=None, attn_type=\"full_attention\"):\n",
" super().__init__()\n",
" assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.num_kv_groups = num_kv_groups\n",
" self.group_size = num_heads // num_kv_groups\n",
"\n",
" self.head_dim = head_dim\n",
" self.d_out = num_heads * head_dim\n",
" self.attn_type = attn_type\n",
" self.sliding_window = sliding_window if attn_type == \"sliding_attention\" else None\n",
"\n",
" # Projections\n",
" self.W_query = nn.Linear(d_in, self.d_out, bias=attention_bias, dtype=dtype)\n",
" self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n",
" self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=attention_bias, dtype=dtype)\n",
" self.out_proj = nn.Linear(self.d_out, d_in, bias=attention_bias, dtype=dtype)\n",
"\n",
" # Olmo3-style RMSNorm over the flattened projections\n",
" self.q_norm = RMSNorm(self.d_out)\n",
" self.k_norm = RMSNorm(num_kv_groups * head_dim)\n",
"\n",
" def forward(self, x, mask, cos, sin, start_pos=0, cache=None):\n",
" b, num_tokens, _ = x.shape\n",
"\n",
" # Apply projections\n",
" queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)\n",
" keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
" values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)\n",
"\n",
" # Normalize q and k\n",
" queries = self.q_norm(queries)\n",
" keys_new = self.k_norm(keys)\n",
"\n",
" # Reshape to (b, heads, seq, head_dim)\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
" values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)\n",
"\n",
" # Cache unrotated K/V\n",
" prev_len = 0\n",
" if cache is not None:\n",
" prev_k, prev_v = cache\n",
" if prev_k is not None:\n",
" prev_len = prev_k.size(2)\n",
" keys_cat_raw = torch.cat([prev_k, keys_new], dim=2)\n",
" values_cat_raw = torch.cat([prev_v, values_new], dim=2)\n",
" else:\n",
" keys_cat_raw = keys_new\n",
" values_cat_raw = values_new\n",
" else:\n",
" keys_cat_raw = keys_new\n",
" values_cat_raw = values_new\n",
"\n",
" # Apply RoPE with offsets for cached tokens\n",
" queries = apply_rope(queries, cos, sin, offset=start_pos)\n",
" keys = apply_rope(keys_cat_raw, cos, sin, offset=start_pos - prev_len)\n",
"\n",
" # Expand KV groups to full head count\n",
" if self.group_size > 1:\n",
" keys = keys.repeat_interleave(self.group_size, dim=1)\n",
" values = values_cat_raw.repeat_interleave(self.group_size, dim=1)\n",
" else:\n",
" values = values_cat_raw\n",
"\n",
" # Scaling before the matmul seems to be a bit more stable for Olmo\n",
" scale = self.head_dim ** -0.5 # Python float\n",
" queries = queries * scale\n",
"\n",
" # Update cache with unrotated K/V\n",
" if cache is not None and cache[0] is not None:\n",
" next_cache = (\n",
" torch.cat([cache[0], keys_new], dim=2),\n",
" torch.cat([cache[1], values_new], dim=2),\n",
" )\n",
" else:\n",
" next_cache = (keys_new, values_new)\n",
"\n",
" # Attention\n",
" attn_scores = queries @ keys.transpose(2, 3)\n",
" if mask is not None:\n",
" attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
"\n",
" attn_weights = torch.softmax(attn_scores, dim=-1)\n",
" context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)\n",
" out = self.out_proj(context)\n",
"\n",
" return out, next_cache"
]
},
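{
"cell_type": "markdown",
"id": "a1b2c3d4-5e6f-4071-8293-a4b5c6d7e8f9",
"metadata": {},
"source": [
"- The toy snippet below (made-up tensor values, not model weights) illustrates how grouped-query attention shares key/value heads: each KV head is expanded via `repeat_interleave` so that it serves `group_size = num_heads // num_kv_groups` query heads; note that the 7B config further below uses as many KV heads as query heads (no grouping), while the 32B config uses 8 KV heads for 40 query heads"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5-6f70-4182-93a4-b5c6d7e8f901",
"metadata": {},
"outputs": [],
"source": [
"# Toy example: expanding 2 KV heads to line up with 4 query heads (group_size = 2)\n",
"toy_keys = torch.arange(2).view(1, 2, 1, 1)  # (batch, num_kv_groups, seq_len, head_dim)\n",
"expanded = toy_keys.repeat_interleave(2, dim=1)\n",
"print(expanded.shape)               # torch.Size([1, 4, 1, 1])\n",
"print(expanded.flatten().tolist())  # [0, 0, 1, 1] -> each KV head is repeated twice"
]
},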
{
"cell_type": "code",
"execution_count": 8,
"id": "13eb3430-0c06-4fe2-a005-217205eee21e",
"metadata": {},
"outputs": [],
"source": [
"class TransformerBlock(nn.Module):\n",
" def __init__(self, cfg, attn_type):\n",
" super().__init__()\n",
" self.attn_type = attn_type\n",
" self.sliding_window = cfg[\"sliding_window\"]\n",
" self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" num_heads=cfg[\"n_heads\"],\n",
" num_kv_groups=cfg[\"n_kv_heads\"],\n",
" head_dim=cfg[\"head_dim\"],\n",
" attention_bias=cfg[\"attention_bias\"],\n",
" dtype=cfg[\"dtype\"],\n",
" sliding_window=cfg[\"sliding_window\"],\n",
" attn_type=attn_type,\n",
" )\n",
" self.ff = FeedForward(cfg)\n",
" self.post_attention_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n",
" self.post_feedforward_layernorm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n",
"\n",
" def forward(self, x, mask_global, mask_local, cos, sin, start_pos=0, cache=None):\n",
" shortcut = x\n",
" if self.attn_type == \"sliding_attention\":\n",
" if cache is not None and isinstance(cache, tuple):\n",
" prev_k, _ = cache\n",
" prev_len = prev_k.size(2) if prev_k is not None else 0\n",
" else:\n",
" prev_len = 0\n",
" eff_kv_len = prev_len + x.size(1)\n",
" attn_mask = mask_local[..., -eff_kv_len:]\n",
" else:\n",
" attn_mask = mask_global\n",
"\n",
" x_attn, next_cache = self.att(x, attn_mask, cos, sin, start_pos=start_pos, cache=cache)\n",
" if next_cache is not None and self.attn_type == \"sliding_attention\":\n",
" k, v = next_cache\n",
" if k.size(2) > self.sliding_window:\n",
" k = k[:, :, -self.sliding_window:, :]\n",
" v = v[:, :, -self.sliding_window:, :]\n",
" next_cache = (k, v)\n",
"\n",
" x_attn = self.post_attention_layernorm(x_attn)\n",
" x = shortcut + x_attn\n",
"\n",
" shortcut = x\n",
" x_ffn = self.ff(x)\n",
" x_ffn = self.post_feedforward_layernorm(x_ffn)\n",
" x = shortcut + x_ffn\n",
" return x, next_cache"
]
},
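{
"cell_type": "markdown",
"id": "c3d4e5f6-7081-4293-a4b5-c6d7e8f90123",
"metadata": {},
"source": [
"- Note that, unlike the pre-norm blocks used in many other LLMs, the block above applies RMSNorm to the *output* of each sublayer before the residual addition: $x \\leftarrow x + \\mathrm{RMSNorm}(\\mathrm{Attention}(x))$ followed by $x \\leftarrow x + \\mathrm{RMSNorm}(\\mathrm{FFN}(x))$"
]
},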
{
"cell_type": "code",
"execution_count": 9,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
},
"outputs": [],
"source": [
"class Olmo3Model(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" assert cfg[\"layer_types\"] is not None and len(cfg[\"layer_types\"]) == cfg[\"n_layers\"]\n",
"\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
" self.blocks = nn.ModuleList([TransformerBlock(cfg, attn_type) for attn_type in cfg[\"layer_types\"]])\n",
" self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=cfg[\"rms_norm_eps\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
" self.cfg = cfg\n",
" self.current_pos = 0\n",
"\n",
" cos, sin = compute_rope_params(\n",
" head_dim=cfg[\"head_dim\"],\n",
" context_length=cfg[\"context_length\"],\n",
" theta_base=cfg[\"rope_base\"],\n",
" attention_factor=cfg[\"rope_attention_factor\"],\n",
" rope_type=cfg[\"rope_type\"],\n",
" rope_factor=cfg[\"rope_factor\"],\n",
" rope_orig_max=cfg[\"rope_orig_max\"],\n",
" dtype=torch.float32,\n",
" )\n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
"\n",
" def create_masks(self, cur_len, device, pos_start=0, pos_end=None):\n",
" if pos_end is None:\n",
" pos_end = cur_len\n",
" total_len = pos_end\n",
"\n",
" ones = torch.ones((total_len, total_len), dtype=torch.bool, device=device)\n",
" # mask_global_full (future is masked: j > i)\n",
" # j: 0 1 2 3 4 5 6 7\n",
" # i\n",
" # 0: 0 1 1 1 1 1 1 1\n",
" # 1: 0 0 1 1 1 1 1 1\n",
" # 2: 0 0 0 1 1 1 1 1\n",
" # 3: 0 0 0 0 1 1 1 1\n",
" # 4: 0 0 0 0 0 1 1 1\n",
" # 5: 0 0 0 0 0 0 1 1\n",
" # 6: 0 0 0 0 0 0 0 1\n",
" # 7: 0 0 0 0 0 0 0 0\n",
" mask_global_full = torch.triu(ones, diagonal=1)\n",
"\n",
" # far_past (too far back is masked: i - j >= sliding_window)\n",
" # where sliding_window = 4\n",
" # j: 0 1 2 3 4 5 6 7\n",
" # i\n",
" # 0: 0 0 0 0 0 0 0 0\n",
" # 1: 0 0 0 0 0 0 0 0\n",
" # 2: 0 0 0 0 0 0 0 0\n",
" # 3: 0 0 0 0 0 0 0 0\n",
" # 4: 1 0 0 0 0 0 0 0\n",
" # 5: 1 1 0 0 0 0 0 0\n",
" # 6: 1 1 1 0 0 0 0 0\n",
" # 7: 1 1 1 1 0 0 0 0\n",
" far_past_full = torch.triu(ones, diagonal=self.cfg[\"sliding_window\"]).T\n",
"\n",
" # Local (sliding_window) = future OR far-past\n",
" # mask_local\n",
" # j: 0 1 2 3 4 5 6 7\n",
" # i\n",
" # 0: 0 1 1 1 1 1 1 1\n",
" # 1: 0 0 1 1 1 1 1 1\n",
" # 2: 0 0 0 1 1 1 1 1\n",
" # 3: 0 0 0 0 1 1 1 1\n",
" # 4: 1 0 0 0 0 1 1 1\n",
" # 5: 1 1 0 0 0 0 1 1\n",
" # 6: 1 1 1 0 0 0 0 1\n",
" # 7: 1 1 1 1 0 0 0 0\n",
" mask_local_full = mask_global_full | far_past_full\n",
"\n",
" row_slice = slice(pos_start, pos_end)\n",
" mask_global = mask_global_full[row_slice, :pos_end][None, None, :, :]\n",
" mask_local = mask_local_full[row_slice, :pos_end][None, None, :, :]\n",
" return mask_global, mask_local\n",
"\n",
" def forward(self, input_ids, cache=None):\n",
" b, seq_len = input_ids.shape\n",
" x = self.tok_emb(input_ids)\n",
"\n",
" if cache is not None:\n",
" pos_start = self.current_pos\n",
" pos_end = pos_start + seq_len\n",
" self.current_pos = pos_end\n",
" mask_global, mask_local = self.create_masks(\n",
" cur_len=seq_len, device=x.device, pos_start=pos_start, pos_end=pos_end\n",
" )\n",
" else:\n",
" pos_start = 0\n",
" mask_global, mask_local = self.create_masks(\n",
" cur_len=seq_len, device=x.device, pos_start=0, pos_end=seq_len\n",
" )\n",
"\n",
" cos = self.cos\n",
" sin = self.sin\n",
"\n",
" for i, block in enumerate(self.blocks):\n",
" blk_cache = cache.get(i) if cache is not None else None\n",
" x, new_blk_cache = block(\n",
" x,\n",
" mask_global=mask_global,\n",
" mask_local=mask_local,\n",
" cos=cos,\n",
" sin=sin,\n",
" start_pos=pos_start,\n",
" cache=blk_cache,\n",
" )\n",
"\n",
" if cache is not None:\n",
" cache.update(i, new_blk_cache)\n",
"\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
" return logits\n",
"\n",
" def reset_kv_cache(self):\n",
" self.current_pos = 0"
]
},
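{
"cell_type": "markdown",
"id": "d4e5f607-8192-43a4-b5c6-d7e8f9012345",
"metadata": {},
"source": [
"- The small standalone sketch below reproduces, for the illustrative values `seq_len=8` and `sliding_window=4`, the two mask patterns drawn in the comments of `create_masks` above (a `1` marks a masked-out position)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5f60718-92a3-44b5-86d7-e8f901234567",
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of the causal (global) and sliding-window (local) masks\n",
"seq_len, window = 8, 4\n",
"ones = torch.ones((seq_len, seq_len), dtype=torch.bool)\n",
"\n",
"mask_global = torch.triu(ones, diagonal=1)      # future positions (j > i) are masked\n",
"far_past = torch.triu(ones, diagonal=window).T  # positions with i - j >= window are masked\n",
"mask_local = mask_global | far_past\n",
"\n",
"print(mask_global.int())\n",
"print(mask_local.int())"
]
},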
{
"cell_type": "code",
"execution_count": 10,
"id": "4f5271e8-ff28-4aaa-bbb2-f73582e6d228",
"metadata": {},
"outputs": [],
"source": [
"class KVCache:\n",
" def __init__(self, n_layers):\n",
" self.cache = [None] * n_layers\n",
"\n",
" def get(self, layer_idx):\n",
" return self.cache[layer_idx]\n",
"\n",
" def update(self, layer_idx, value):\n",
" self.cache[layer_idx] = value\n",
"\n",
" def get_all(self):\n",
" return self.cache\n",
"\n",
" def reset(self):\n",
" for i in range(len(self.cache)):\n",
" self.cache[i] = None"
]
},
{
"cell_type": "markdown",
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48",
"metadata": {
"id": "be2d201f-74ad-4d63-ab9c-601b00674a48"
},
"source": [
" \n",
"# 2. Initialize model"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
},
"outputs": [],
"source": [
"OLMO3_CONFIG_7B = {\n",
" \"vocab_size\": 100_278,\n",
" \"context_length\": 65_536,\n",
" \"emb_dim\": 4_096,\n",
" \"n_heads\": 32,\n",
" \"n_layers\": 32,\n",
" \"hidden_dim\": 11_008,\n",
" \"head_dim\": 128,\n",
" \"n_kv_heads\": 32,\n",
" \"attention_bias\": False,\n",
" \"attention_dropout\": 0.0,\n",
" \"sliding_window\": 4_096,\n",
" \"layer_types\": [\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" ],\n",
" \"rope_base\": 500_000.0,\n",
" \"rope_attention_factor\": 1.2079441541679836,\n",
" \"rope_type\": \"yarn\",\n",
" \"rope_factor\": 8.0,\n",
" \"rope_orig_max\": 8_192,\n",
" \"beta_fast\": 32.0,\n",
" \"beta_slow\": 1.0,\n",
" \"rms_norm_eps\": 1e-6,\n",
" \"dtype\": torch.bfloat16,\n",
" \"eos_token_id\": 100_257,\n",
" \"pad_token_id\": 100_277,\n",
"}\n",
"\n",
"OLMO3_CONFIG_32B = {\n",
" \"vocab_size\": 100_278,\n",
" \"context_length\": 65_536,\n",
" \"emb_dim\": 5_120,\n",
" \"n_heads\": 40,\n",
" \"n_layers\": 64,\n",
" \"hidden_dim\": 27_648,\n",
" \"head_dim\": 128,\n",
" \"n_kv_heads\": 8,\n",
" \"attention_bias\": False,\n",
" \"attention_dropout\": 0.0,\n",
" \"sliding_window\": 4_096,\n",
" \"layer_types\": [\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"sliding_attention\",\n",
" \"full_attention\",\n",
" ],\n",
" \"rope_base\": 500_000.0,\n",
" \"rope_attention_factor\": 1.2079441541679836,\n",
" \"rope_type\": \"yarn\",\n",
" \"rope_factor\": 8.0,\n",
" \"rope_orig_max\": 8_192,\n",
" \"beta_fast\": 32.0,\n",
" \"beta_slow\": 1.0,\n",
" \"rms_norm_eps\": 1e-6,\n",
" \"dtype\": torch.bfloat16,\n",
" \"eos_token_id\": 100_257,\n",
" \"pad_token_id\": 100_277,\n",
"}\n",
"\n",
"OLMO3_CONFIG = OLMO3_CONFIG_32B if \"32B\" in USE_MODEL else OLMO3_CONFIG_7B"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"model = Olmo3Model(OLMO3_CONFIG)"
]
},
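{
"cell_type": "markdown",
"id": "f6071829-a3b4-45c6-97d8-e9f012345678",
"metadata": {},
"source": [
"- Optionally, we can print the total parameter count and a rough estimate of the memory required to store the weights in bfloat16 (2 bytes per parameter)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07182930-b4c5-46d7-98e9-f01234567890",
"metadata": {},
"outputs": [],
"source": [
"# Optional: total parameter count and rough bfloat16 weight-memory estimate\n",
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"print(f\"Approx. bfloat16 weight memory: {total_params * 2 / 1024**3:.2f} GB\")"
]
},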
{
"cell_type": "code",
"execution_count": 13,
"id": "eaf86265-4e9d-4024-9ed0-99076944e304",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Olmo3Model(\n",
" (tok_emb): Embedding(100278, 4096)\n",
" (blocks): ModuleList(\n",
" (0-31): 32 x TransformerBlock(\n",
" (att): GroupedQueryAttention(\n",
" (W_query): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (W_key): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (W_value): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (out_proj): Linear(in_features=4096, out_features=4096, bias=False)\n",
" (q_norm): RMSNorm()\n",
" (k_norm): RMSNorm()\n",
" )\n",
" (ff): FeedForward(\n",
" (fc1): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (fc2): Linear(in_features=4096, out_features=11008, bias=False)\n",
" (fc3): Linear(in_features=11008, out_features=4096, bias=False)\n",
" )\n",
" (post_attention_layernorm): RMSNorm()\n",
" (post_feedforward_layernorm): RMSNorm()\n",
" )\n",
" )\n",
" (final_norm): RMSNorm()\n",
" (out_head): Linear(in_features=4096, out_features=100278, bias=False)\n",
")"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "markdown",
"id": "90aca91d-4bee-45ce-993a-4ec5393abe2b",
"metadata": {},
"source": [
"- A quick check that the forward pass works before continuing:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "adf0a6b7-b688-42c9-966e-c223d34db99f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[ 0.3867, -0.6328, -0.2734, ..., 1.1484, 0.4258, 0.0400],\n",
" [ 1.2734, 0.0040, 0.5000, ..., 0.5625, -0.2383, 0.1855],\n",
" [ 0.5859, -0.0540, 0.7930, ..., 0.3262, -0.5430, -0.1494]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model(torch.tensor([1, 2, 3]).unsqueeze(0))"
]
},
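{
"cell_type": "markdown",
"id": "18293a4b-c5d6-47e8-99f0-012345678901",
"metadata": {},
"source": [
"- As an optional additional check, the next cell verifies that feeding the same three tokens one at a time through the `KVCache` path yields (up to bfloat16 rounding) the same final-position logits as the cache-free forward pass above"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "293a4b5c-d6e7-48f9-a001-123456789012",
"metadata": {},
"outputs": [],
"source": [
"# Optional: compare the cache-free forward pass with an incremental, cached one\n",
"example_ids = torch.tensor([1, 2, 3]).unsqueeze(0)\n",
"\n",
"with torch.no_grad():\n",
"    logits_no_cache = model(example_ids)\n",
"\n",
"    cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
"    model.reset_kv_cache()\n",
"    for i in range(example_ids.size(1)):\n",
"        logits_cached = model(example_ids[:, i:i + 1], cache=cache)\n",
"\n",
"max_diff = (logits_no_cache[:, -1].float() - logits_cached[:, -1].float()).abs().max()\n",
"print(f\"Max. absolute difference in final-position logits: {max_diff:.4f}\")"
]
},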
{
"cell_type": "code",
"execution_count": 15,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/rasbt/jupyterlab/reasoning/.venv/lib/python3.12/site-packages/torch/cuda/__init__.py:283: UserWarning: \n",
" Found GPU0 NVIDIA GB10 which is of cuda capability 12.1.\n",
" Minimum and Maximum cuda capability supported by this version of PyTorch is\n",
" (8.0) - (12.0)\n",
" \n",
" warnings.warn(\n"
]
}
],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" device = torch.device(\"mps\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"\n",
"model.to(device);"
]
},
{
"cell_type": "markdown",
"id": "c172f89f-d301-439f-b809-46169e5f5945",
"metadata": {
"id": "c172f89f-d301-439f-b809-46169e5f5945"
},
"source": [
" \n",
"# 3. Load pretrained weights"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
},
"outputs": [],
"source": [
"def load_weights_into_olmo(model, param_config, params):\n",
" def assign(left, right, tensor_name=\"unknown\"):\n",
" if left.shape != right.shape:\n",
" raise ValueError(\n",
" f\"Shape mismatch in tensor '{tensor_name}'. \"\n",
" f\"Left: {left.shape}, Right: {right.shape}\"\n",
" )\n",
" \n",
" with torch.no_grad():\n",
" if isinstance(right, torch.Tensor):\n",
" left.copy_(right)\n",
" else:\n",
" left.copy_(torch.as_tensor(right, dtype=left.dtype, device=left.device))\n",
" \n",
" return left\n",
"\n",
" # Token embedding\n",
" if \"model.embed_tokens.weight\" in params:\n",
" model.tok_emb.weight = assign(\n",
" model.tok_emb.weight,\n",
" params[\"model.embed_tokens.weight\"],\n",
" \"model.embed_tokens.weight\",\n",
" )\n",
"\n",
" for l in range(param_config[\"n_layers\"]):\n",
" block = model.blocks[l]\n",
" att = block.att\n",
"\n",
" # Q, K, V projections\n",
" att.W_query.weight = assign(\n",
" att.W_query.weight,\n",
" params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.q_proj.weight\",\n",
" )\n",
" att.W_key.weight = assign(\n",
" att.W_key.weight,\n",
" params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.k_proj.weight\",\n",
" )\n",
" att.W_value.weight = assign(\n",
" att.W_value.weight,\n",
" params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.v_proj.weight\",\n",
" )\n",
"\n",
" # Output projection\n",
" att.out_proj.weight = assign(\n",
" att.out_proj.weight,\n",
" params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n",
" f\"model.layers.{l}.self_attn.o_proj.weight\",\n",
" )\n",
"\n",
" # QK norms\n",
" att.q_norm.weight = assign(\n",
" att.q_norm.weight,\n",
" params[f\"model.layers.{l}.self_attn.q_norm.weight\"],\n",
" f\"model.layers.{l}.self_attn.q_norm.weight\",\n",
" )\n",
" att.k_norm.weight = assign(\n",
" att.k_norm.weight,\n",
" params[f\"model.layers.{l}.self_attn.k_norm.weight\"],\n",
" f\"model.layers.{l}.self_attn.k_norm.weight\",\n",
" )\n",
"\n",
" # Feedforward weights\n",
" block.ff.fc1.weight = assign(\n",
" block.ff.fc1.weight,\n",
" params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.gate_proj.weight\",\n",
" )\n",
" block.ff.fc2.weight = assign(\n",
" block.ff.fc2.weight,\n",
" params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.up_proj.weight\",\n",
" )\n",
" block.ff.fc3.weight = assign(\n",
" block.ff.fc3.weight,\n",
" params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n",
" f\"model.layers.{l}.mlp.down_proj.weight\",\n",
" )\n",
"\n",
" # Post-attention and post norms\n",
" block.post_attention_layernorm.weight = assign(\n",
" block.post_attention_layernorm.weight,\n",
" params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n",
" f\"model.layers.{l}.post_attention_layernorm.weight\",\n",
" )\n",
" block.post_feedforward_layernorm.weight = assign(\n",
" block.post_feedforward_layernorm.weight,\n",
" params[f\"model.layers.{l}.post_feedforward_layernorm.weight\"],\n",
" f\"model.layers.{l}.post_feedforward_layernorm.weight\",\n",
" )\n",
"\n",
" # Final normalization and output head\n",
" if \"model.norm.weight\" in params:\n",
" model.final_norm.weight = assign(\n",
" model.final_norm.weight,\n",
" params[\"model.norm.weight\"],\n",
" \"model.norm.weight\",\n",
" )\n",
"\n",
" if \"lm_head.weight\" in params:\n",
" model.out_head.weight = assign(\n",
" model.out_head.weight,\n",
" params[\"lm_head.weight\"],\n",
" \"lm_head.weight\",\n",
" )\n",
" else:\n",
" model.out_head.weight = model.tok_emb.weight\n",
" print(\"Model uses weight tying.\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 17,
"referenced_widgets": [
"9881b6995c3f49dc89e6992fd9ab660b",
"17a3174e65c54476b2e0d1faf8f011ca",
"1bbf2e62c0754d1593beb4105a7f1ac1",
"b82112e1dec645d98aa1c1ba64abcb61",
"271e2bd6a35e4a8b92de8697f7c0be5f",
"90a79523187446dfa692723b2e5833a7",
"431ffb83b8c14bf182f0430e07ea6154",
"a8f1b72a33dd4b548de23fbd95e0da18",
"25cc36132d384189acfbecc59483134b",
"bfd06423ad544218968648016e731a46",
"d029630b63ff44cf807ade428d2eb421"
]
},
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d"
},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0fcdf72bf5b646d39bf4ed84faeb3302",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 14 files: 0%| | 0/14 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import json\n",
"import os\n",
"from pathlib import Path\n",
"from safetensors.torch import load_file\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"repo_id = f\"allenai/{USE_MODEL}\"\n",
"local_dir = Path(repo_id).parts[-1]\n",
"\n",
"repo_dir = snapshot_download(repo_id=repo_id, local_dir=local_dir)\n",
"index_path = os.path.join(repo_dir, \"model.safetensors.index.json\")\n",
"with open(index_path, \"r\") as f:\n",
" index = json.load(f)\n",
"\n",
"weights_dict = {}\n",
"for filename in sorted(set(index[\"weight_map\"].values())):\n",
" shard_path = os.path.join(repo_dir, filename)\n",
" shard = load_file(shard_path)\n",
" weights_dict.update(shard)\n",
"\n",
"load_weights_into_olmo(model, OLMO3_CONFIG, weights_dict)\n",
"model.to(device)\n",
"del weights_dict"
]
},
{
"cell_type": "markdown",
"id": "6b345491-3510-4397-92d3-cd0a3fa3deee",
"metadata": {},
"source": [
" \n",
"# 4. Load tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "b68ab489-48e5-471e-a814-56cda2d60f81",
"metadata": {},
"outputs": [],
"source": [
"from tokenizers import Tokenizer\n",
"from huggingface_hub import hf_hub_download\n",
"\n",
"\n",
"class OlmoTokenizer:\n",
" def __init__(self, tokenizer_file_path, eos_token_id, pad_token_id):\n",
" tok_file = Path(tokenizer_file_path)\n",
" self._tok = Tokenizer.from_file(str(tok_file))\n",
" eos_from_tok = (\n",
" self._tok.token_to_id(\"<|endoftext|>\")\n",
" or self._tok.token_to_id(\"<end_of_turn>\")\n",
" )\n",
" self.eos_token_id = eos_from_tok if eos_from_tok is not None else eos_token_id\n",
" pad_from_tok = (\n",
" self._tok.token_to_id(\"<|pad|>\")\n",
" or self._tok.token_to_id(\"<pad>\")\n",
" )\n",
" self.pad_token_id = pad_from_tok if pad_from_tok is not None else pad_token_id\n",
"\n",
" def encode(self, text):\n",
" return self._tok.encode(text).ids\n",
"\n",
" def decode(self, ids):\n",
" return self._tok.decode(ids, skip_special_tokens=False)\n",
"\n",
"\n",
"def apply_chat_template(user_text):\n",
" return (\n",
" \"<|im_start|>user\\n\"\n",
" f\"{user_text}\\n\"\n",
" \"<|im_end|>\\n\"\n",
" \"<|im_start|>assistant\\n\"\n",
" )\n",
"\n",
"\n",
"tokenizer_file_path = os.path.join(local_dir, \"tokenizer.json\")\n",
"if not os.path.exists(tokenizer_file_path):\n",
" try:\n",
" tokenizer_file_path = hf_hub_download(repo_id=repo_id, filename=\"tokenizer.json\", local_dir=local_dir)\n",
" except Exception as e:\n",
" print(f\"Warning: failed to download tokenizer.json: {e}\")\n",
" tokenizer_file_path = \"tokenizer.json\"\n",
"\n",
"tokenizer = OlmoTokenizer(\n",
" tokenizer_file_path=tokenizer_file_path,\n",
" eos_token_id=OLMO3_CONFIG[\"eos_token_id\"],\n",
" pad_token_id=OLMO3_CONFIG[\"pad_token_id\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<|im_start|>user\\nGive me a short intro to large language models in 3 sentences.\\n<|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = apply_chat_template(\"Give me a short intro to large language models in 3 sentences.\")\n",
"\n",
"input_token_ids = tokenizer.encode(prompt)\n",
"text = tokenizer.decode(input_token_ids)\n",
"text"
]
},
{
"cell_type": "markdown",
"id": "57d07df1-4401-4792-b549-7c4cc5632323",
"metadata": {
"id": "57d07df1-4401-4792-b549-7c4cc5632323"
},
"source": [
" \n",
"# 5. Generate text"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
},
"outputs": [],
"source": [
"def generate_text_basic_stream(model, token_ids, max_new_tokens, eos_token_id=None, context_size=None):\n",
"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
" model.reset_kv_cache()\n",
"\n",
" logits = model(token_ids, cache=cache)\n",
"\n",
" for _ in range(max_new_tokens):\n",
" next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
"\n",
" if (eos_token_id is not None\n",
" and torch.all(next_token == eos_token_id)):\n",
" break\n",
"\n",
" yield next_token\n",
"\n",
" token_ids = torch.cat([token_ids, next_token], dim=1)\n",
"\n",
" logits = model(next_token, cache=cache)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Large language models are advanced AI systems trained on vast amounts of text to understand and generate human-like language. They can perform a wide range of tasks, from answering questions to writing essays or code. These models have transformed natural language processing and are now foundational in many modern AI applications.\n",
"\n",
"GPU memory used: 13.71 GB\n"
]
}
],
"source": [
"input_token_ids_tensor = torch.tensor(input_token_ids, device=device).unsqueeze(0)\n",
"\n",
"\n",
"if torch.cuda.is_available():\n",
" torch.cuda.reset_peak_memory_stats()\n",
"\n",
"\n",
"for token in generate_text_basic_stream(\n",
" model=model,\n",
" token_ids=input_token_ids_tensor,\n",
" max_new_tokens=500,\n",
" eos_token_id=tokenizer.eos_token_id\n",
"):\n",
" token_id = token.squeeze(0).tolist()\n",
" print(\n",
" tokenizer.decode(token_id),\n",
" end=\"\",\n",
" flush=True\n",
" )\n",
"\n",
"if torch.cuda.is_available():\n",
" def calc_gpu_gb(x):\n",
" return f\"{x / 1024 / 1024 / 1024:.2f} GB\"\n",
" \n",
" print(f\"\\n\\nGPU memory used: {calc_gpu_gb(torch.cuda.max_memory_allocated())}\")"
]
},
{
"cell_type": "markdown",
"id": "549324d6-5c71-4147-ae21-2e67675faa3d",
"metadata": {
"id": "549324d6-5c71-4147-ae21-2e67675faa3d"
},
"source": [
" \n",
"# What's next?"
]
},
{
"cell_type": "markdown",
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c",
"metadata": {
"id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c"
},
"source": [
"- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n",
"\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}