From 491fd584632f0f4d19cb4aabea4e933fe1f31dc9 Mon Sep 17 00:00:00 2001 From: Gerardo Moreno Date: Sat, 3 Jan 2026 16:59:57 -0800 Subject: [PATCH] Fix Olmo3 YaRN RoPE implementation bug (#940) * Olmo3 fix RoPE YaRN implementation * Update cell outputs * update olmo layer debugger --------- Co-authored-by: rasbt --- .../standalone-olmo3-plus-kv-cache.ipynb | 72 ++++++++++++---- ch05/13_olmo3/standalone-olmo3.ipynb | 82 +++++++++++++------ ch05/13_olmo3/tests/olmo3_layer_debugger.py | 45 +++++++++- 3 files changed, 157 insertions(+), 42 deletions(-) diff --git a/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb b/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb index 3296fef..e9cfa07 100644 --- a/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb +++ b/ch05/13_olmo3/standalone-olmo3-plus-kv-cache.ipynb @@ -206,25 +206,60 @@ }, "outputs": [], "source": [ - "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n", + "import math\n", + "\n", + "\n", + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, beta_fast=32.0, beta_slow=1.0, dtype=torch.float32):\n", " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", - " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (\n", - " theta_base ** (\n", - " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", - " / head_dim\n", + " if rope_type == \"yarn\":\n", + " # Compute YaRN-style frequency scaling (as per https://huggingface.co/papers/2309.00071)\n", + "\n", + " def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n", + " \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n", + " return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n", + "\n", + " def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n", + " \"\"\"Find dimension range bounds based on rotations\"\"\"\n", + " low = find_correction_dim(low_rot, dim, base, max_position_embeddings)\n", + " high = find_correction_dim(high_rot, dim, base, max_position_embeddings)\n", + " low = math.floor(low)\n", + " high = math.ceil(high)\n", + " return max(low, 0), min(high, dim - 1)\n", + "\n", + " def linear_ramp_factor(min_val, max_val, dim):\n", + " if min_val == max_val:\n", + " max_val += 0.001 # Prevent singularity\n", + " linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)\n", + " ramp_func = torch.clamp(linear_func, 0, 1)\n", + " return ramp_func\n", + "\n", + " # Base frequencies\n", + " pos_freqs = theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)\n", + " inv_freq_extrapolation = 1.0 / pos_freqs # No scaling (extrapolation)\n", + " inv_freq_interpolation = 1.0 / (rope_factor * pos_freqs) # With scaling (interpolation)\n", + "\n", + " # Find the range where we blend between interpolation and extrapolation\n", + " low, high = find_correction_range(beta_fast, beta_slow, head_dim, theta_base, rope_orig_max)\n", + "\n", + " # Get n-dimensional rotational scaling corrected for extrapolation\n", + " inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).to(dtype=dtype)\n", + " inv_freq = (\n", + " inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n", + " + inv_freq_extrapolation * inv_freq_extrapolation_factor\n", + " )\n", + " else:\n", + " # Default RoPE\n", + " inv_freq = 1.0 / (\n", + " theta_base ** (\n", + " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", + " / head_dim\n", + " )\n", " )\n", - " )\n", "\n", " # Generate position indices\n", " positions = torch.arange(context_length, dtype=dtype)\n", "\n", - " # Optional YaRN scaling\n", - " if rope_type == \"yarn\":\n", - " positions = positions / rope_factor\n", - " positions = torch.clamp(positions, max=rope_orig_max - 1)\n", - "\n", " # Compute the base angles (shape: [context_length, head_dim // 2])\n", " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", "\n", @@ -642,6 +677,8 @@ " \"rope_type\": \"yarn\",\n", " \"rope_factor\": 8.0,\n", " \"rope_orig_max\": 8_192,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", " \"rms_norm_eps\": 1e-6,\n", " \"dtype\": torch.bfloat16,\n", " \"eos_token_id\": 100_257,\n", @@ -727,6 +764,8 @@ " \"rope_type\": \"yarn\",\n", " \"rope_factor\": 8.0,\n", " \"rope_orig_max\": 8_192,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", " \"rms_norm_eps\": 1e-6,\n", " \"dtype\": torch.bfloat16,\n", " \"eos_token_id\": 100_257,\n", @@ -810,9 +849,9 @@ { "data": { "text/plain": [ - "tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n", - " [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n", - " [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n", + "tensor([[[ 0.3867, -0.6328, -0.2734, ..., 1.1484, 0.4258, 0.0400],\n", + " [ 1.2734, 0.0040, 0.5000, ..., 0.5625, -0.2383, 0.1855],\n", + " [ 0.5859, -0.0540, 0.7930, ..., 0.3262, -0.5430, -0.1494]]],\n", " dtype=torch.bfloat16, grad_fn=)" ] }, @@ -1202,8 +1241,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Sure! Here’s a brief introduction to large language models: \n", - "Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating text, learning from vast amounts of data, learning language, performing diverse tasks, assisting in many applications, and adapting various tasks.\n", + "Large language models are advanced AI systems trained on vast amounts of text to understand and generate human-like language. They can perform a wide range of tasks, from answering questions to writing essays or code. These models have transformed natural language processing and are now foundational in many modern AI applications.\n", "\n", "GPU memory used: 13.71 GB\n" ] diff --git a/ch05/13_olmo3/standalone-olmo3.ipynb b/ch05/13_olmo3/standalone-olmo3.ipynb index 2c7f42f..ed8885f 100644 --- a/ch05/13_olmo3/standalone-olmo3.ipynb +++ b/ch05/13_olmo3/standalone-olmo3.ipynb @@ -206,25 +206,60 @@ }, "outputs": [], "source": [ - "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n", + "import math\n", + "\n", + "\n", + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, beta_fast=32.0, beta_slow=1.0, dtype=torch.float32):\n", " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", - " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (\n", - " theta_base ** (\n", - " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", - " / head_dim\n", + " if rope_type == \"yarn\":\n", + " # Compute YaRN-style frequency scaling (as per https://huggingface.co/papers/2309.00071)\n", + "\n", + " def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n", + " \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n", + " return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n", + "\n", + " def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n", + " \"\"\"Find dimension range bounds based on rotations\"\"\"\n", + " low = find_correction_dim(low_rot, dim, base, max_position_embeddings)\n", + " high = find_correction_dim(high_rot, dim, base, max_position_embeddings)\n", + " low = math.floor(low)\n", + " high = math.ceil(high)\n", + " return max(low, 0), min(high, dim - 1)\n", + "\n", + " def linear_ramp_factor(min_val, max_val, dim):\n", + " if min_val == max_val:\n", + " max_val += 0.001 # Prevent singularity\n", + " linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)\n", + " ramp_func = torch.clamp(linear_func, 0, 1)\n", + " return ramp_func\n", + "\n", + " # Base frequencies\n", + " pos_freqs = theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)\n", + " inv_freq_extrapolation = 1.0 / pos_freqs # No scaling (extrapolation)\n", + " inv_freq_interpolation = 1.0 / (rope_factor * pos_freqs) # With scaling (interpolation)\n", + "\n", + " # Find the range where we blend between interpolation and extrapolation\n", + " low, high = find_correction_range(beta_fast, beta_slow, head_dim, theta_base, rope_orig_max)\n", + "\n", + " # Get n-dimensional rotational scaling corrected for extrapolation\n", + " inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).to(dtype=dtype)\n", + " inv_freq = (\n", + " inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n", + " + inv_freq_extrapolation * inv_freq_extrapolation_factor\n", + " )\n", + " else:\n", + " # Default RoPE\n", + " inv_freq = 1.0 / (\n", + " theta_base ** (\n", + " torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n", + " / head_dim\n", + " )\n", " )\n", - " )\n", "\n", " # Generate position indices\n", " positions = torch.arange(context_length, dtype=dtype)\n", "\n", - " # Optional YaRN scaling\n", - " if rope_type == \"yarn\":\n", - " positions = positions / rope_factor\n", - " positions = torch.clamp(positions, max=rope_orig_max - 1)\n", - "\n", " # Compute the base angles (shape: [context_length, head_dim // 2])\n", " angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n", "\n", @@ -539,6 +574,8 @@ " \"rope_type\": \"yarn\",\n", " \"rope_factor\": 8.0,\n", " \"rope_orig_max\": 8_192,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", " \"rms_norm_eps\": 1e-6,\n", " \"dtype\": torch.bfloat16,\n", " \"eos_token_id\": 100_257,\n", @@ -624,6 +661,8 @@ " \"rope_type\": \"yarn\",\n", " \"rope_factor\": 8.0,\n", " \"rope_orig_max\": 8_192,\n", + " \"beta_fast\": 32.0,\n", + " \"beta_slow\": 1.0,\n", " \"rms_norm_eps\": 1e-6,\n", " \"dtype\": torch.bfloat16,\n", " \"eos_token_id\": 100_257,\n", @@ -707,9 +746,9 @@ { "data": { "text/plain": [ - "tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n", - " [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n", - " [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n", + "tensor([[[ 0.3867, -0.6328, -0.2734, ..., 1.1484, 0.4258, 0.0400],\n", + " [ 1.2734, 0.0040, 0.5000, ..., 0.5625, -0.2383, 0.1855],\n", + " [ 0.5859, -0.0540, 0.7930, ..., 0.3262, -0.5430, -0.1494]]],\n", " dtype=torch.bfloat16, grad_fn=)" ] }, @@ -1021,7 +1060,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 18, "id": "7b6df8bc-7308-468e-93ce-2d5529ea7866", "metadata": {}, "outputs": [ @@ -1031,7 +1070,7 @@ "'<|im_start|>user\\nGive me a short intro to large language models in 3 sentences.\\n<|im_end|>\\n<|im_start|>assistant\\n'" ] }, - "execution_count": 24, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -1057,7 +1096,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 19, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": { "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" @@ -1083,7 +1122,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 20, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": { "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" @@ -1093,10 +1132,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Sure! Here’s a brief introduction to large language models: \n", - "Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating human-like text, learning from vast data. \n", - "They excel at many tasks across many languages and adapt to various tasks. \n", - "They power modern applications widely used in NLP solutions.\n", + "Large language models are advanced AI systems trained on vast amounts of text to understand and generate human-like language. They can perform a wide range of tasks, from answering questions to writing essays or code. These models have transformed natural language processing and are now integral to many modern technologies.\n", "\n", "GPU memory used: 13.70 GB\n" ] diff --git a/ch05/13_olmo3/tests/olmo3_layer_debugger.py b/ch05/13_olmo3/tests/olmo3_layer_debugger.py index fe58ba3..7f4664b 100644 --- a/ch05/13_olmo3/tests/olmo3_layer_debugger.py +++ b/ch05/13_olmo3/tests/olmo3_layer_debugger.py @@ -43,10 +43,51 @@ def tiny_debug_config(): } +def yarn_debug_config(): + return { + "vocab_size": 257, + "context_length": 8, + "emb_dim": 32, + "n_heads": 4, + "n_layers": 2, + "hidden_dim": 64, + "head_dim": 8, + "qk_norm": True, + "n_kv_heads": 2, + "sliding_window": 4, + "layer_types": ["full_attention", "full_attention"], + "dtype": torch.float32, + "query_pre_attn_scalar": 256, + "attention_bias": False, + "rms_norm_eps": 1e-6, + "rope_base": 500_000.0, + "rope_attention_factor": 1.2079441541679836, + "rope_type": "yarn", + "rope_factor": 8.0, + "rope_orig_max": 8192, + "beta_fast": 32.0, + "beta_slow": 1.0, + "rope_local_base": 10_000.0, + } + + def _hf_config_from_dict(cfg): if Olmo3Config is None: raise ImportError("transformers is required for the Olmo-3 debugger.") + rope_type = cfg.get("rope_type", "default") + rope_scaling = {"rope_type": rope_type} + if rope_type == "yarn": + rope_scaling.update( + { + "attention_factor": cfg.get("rope_attention_factor", 1.0), + "beta_fast": cfg.get("beta_fast", 32.0), + "beta_slow": cfg.get("beta_slow", 1.0), + "factor": cfg.get("rope_factor", 1.0), + "original_max_position_embeddings": cfg.get("rope_orig_max", 8192), + } + ) + return Olmo3Config( vocab_size=cfg["vocab_size"], max_position_embeddings=cfg["context_length"], @@ -64,7 +105,7 @@ def _hf_config_from_dict(cfg): attn_implementation="eager", torch_dtype=cfg.get("dtype", torch.float32), query_pre_attn_scalar=cfg.get("query_pre_attn_scalar", 256), - rope_scaling={"rope_type": cfg.get("rope_type", "default")}, + rope_scaling=rope_scaling, qk_norm=cfg.get("qk_norm", False), rms_norm_eps=cfg.get("rms_norm_eps", 1e-5), ) @@ -231,7 +272,7 @@ if __name__ == "__main__": raise SystemExit("transformers is not installed; install it to run the debugger.") nb_imports = load_notebook_defs() - cfg = tiny_debug_config() + cfg = yarn_debug_config() ours_model, hf_model = build_olmo3_pair(nb_imports, cfg) torch.manual_seed(0)