Fix Olmo3 YaRN RoPE implementation bug (#940)

* Olmo3 fix RoPE YaRN implementation

* Update cell outputs

* update olmo layer debugger

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>
This commit is contained in:
Gerardo Moreno
2026-01-03 16:59:57 -08:00
committed by GitHub
parent b26fa01381
commit 491fd58463
3 changed files with 157 additions and 42 deletions

View File

@@ -206,25 +206,60 @@
},
"outputs": [],
"source": [
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n",
"import math\n",
"\n",
"\n",
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, beta_fast=32.0, beta_slow=1.0, dtype=torch.float32):\n",
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (\n",
" theta_base ** (\n",
" torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n",
" / head_dim\n",
" if rope_type == \"yarn\":\n",
" # Compute YaRN-style frequency scaling (as per https://huggingface.co/papers/2309.00071)\n",
"\n",
" def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n",
" \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n",
" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n",
"\n",
" def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n",
" \"\"\"Find dimension range bounds based on rotations\"\"\"\n",
" low = find_correction_dim(low_rot, dim, base, max_position_embeddings)\n",
" high = find_correction_dim(high_rot, dim, base, max_position_embeddings)\n",
" low = math.floor(low)\n",
" high = math.ceil(high)\n",
" return max(low, 0), min(high, dim - 1)\n",
"\n",
" def linear_ramp_factor(min_val, max_val, dim):\n",
" if min_val == max_val:\n",
" max_val += 0.001 # Prevent singularity\n",
" linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)\n",
" ramp_func = torch.clamp(linear_func, 0, 1)\n",
" return ramp_func\n",
"\n",
" # Base frequencies\n",
" pos_freqs = theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)\n",
" inv_freq_extrapolation = 1.0 / pos_freqs # No scaling (extrapolation)\n",
" inv_freq_interpolation = 1.0 / (rope_factor * pos_freqs) # With scaling (interpolation)\n",
"\n",
" # Find the range where we blend between interpolation and extrapolation\n",
" low, high = find_correction_range(beta_fast, beta_slow, head_dim, theta_base, rope_orig_max)\n",
"\n",
" # Get n-dimensional rotational scaling corrected for extrapolation\n",
" inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).to(dtype=dtype)\n",
" inv_freq = (\n",
" inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n",
" + inv_freq_extrapolation * inv_freq_extrapolation_factor\n",
" )\n",
" else:\n",
" # Default RoPE\n",
" inv_freq = 1.0 / (\n",
" theta_base ** (\n",
" torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n",
" / head_dim\n",
" )\n",
" )\n",
" )\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Optional YaRN scaling\n",
" if rope_type == \"yarn\":\n",
" positions = positions / rope_factor\n",
" positions = torch.clamp(positions, max=rope_orig_max - 1)\n",
"\n",
" # Compute the base angles (shape: [context_length, head_dim // 2])\n",
" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n",
"\n",
@@ -642,6 +677,8 @@
" \"rope_type\": \"yarn\",\n",
" \"rope_factor\": 8.0,\n",
" \"rope_orig_max\": 8_192,\n",
" \"beta_fast\": 32.0,\n",
" \"beta_slow\": 1.0,\n",
" \"rms_norm_eps\": 1e-6,\n",
" \"dtype\": torch.bfloat16,\n",
" \"eos_token_id\": 100_257,\n",
@@ -727,6 +764,8 @@
" \"rope_type\": \"yarn\",\n",
" \"rope_factor\": 8.0,\n",
" \"rope_orig_max\": 8_192,\n",
" \"beta_fast\": 32.0,\n",
" \"beta_slow\": 1.0,\n",
" \"rms_norm_eps\": 1e-6,\n",
" \"dtype\": torch.bfloat16,\n",
" \"eos_token_id\": 100_257,\n",
@@ -810,9 +849,9 @@
{
"data": {
"text/plain": [
"tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n",
" [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n",
" [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n",
"tensor([[[ 0.3867, -0.6328, -0.2734, ..., 1.1484, 0.4258, 0.0400],\n",
" [ 1.2734, 0.0040, 0.5000, ..., 0.5625, -0.2383, 0.1855],\n",
" [ 0.5859, -0.0540, 0.7930, ..., 0.3262, -0.5430, -0.1494]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
@@ -1202,8 +1241,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Sure! Heres a brief introduction to large language models: \n",
"Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating text, learning from vast amounts of data, learning language, performing diverse tasks, assisting in many applications, and adapting various tasks.\n",
"Large language models are advanced AI systems trained on vast amounts of text to understand and generate human-like language. They can perform a wide range of tasks, from answering questions to writing essays or code. These models have transformed natural language processing and are now foundational in many modern AI applications.\n",
"\n",
"GPU memory used: 13.71 GB\n"
]

View File

@@ -206,25 +206,60 @@
},
"outputs": [],
"source": [
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, dtype=torch.float32):\n",
"import math\n",
"\n",
"\n",
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, attention_factor=1.0, rope_type=\"default\", rope_factor=1.0, rope_orig_max=8192, beta_fast=32.0, beta_slow=1.0, dtype=torch.float32):\n",
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n",
" # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (\n",
" theta_base ** (\n",
" torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n",
" / head_dim\n",
" if rope_type == \"yarn\":\n",
" # Compute YaRN-style frequency scaling (as per https://huggingface.co/papers/2309.00071)\n",
"\n",
" def find_correction_dim(num_rotations, dim, base, max_position_embeddings):\n",
" \"\"\"Inverse dimension formula to find the dimension based on the number of rotations\"\"\"\n",
" return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))\n",
"\n",
" def find_correction_range(low_rot, high_rot, dim, base, max_position_embeddings):\n",
" \"\"\"Find dimension range bounds based on rotations\"\"\"\n",
" low = find_correction_dim(low_rot, dim, base, max_position_embeddings)\n",
" high = find_correction_dim(high_rot, dim, base, max_position_embeddings)\n",
" low = math.floor(low)\n",
" high = math.ceil(high)\n",
" return max(low, 0), min(high, dim - 1)\n",
"\n",
" def linear_ramp_factor(min_val, max_val, dim):\n",
" if min_val == max_val:\n",
" max_val += 0.001 # Prevent singularity\n",
" linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)\n",
" ramp_func = torch.clamp(linear_func, 0, 1)\n",
" return ramp_func\n",
"\n",
" # Base frequencies\n",
" pos_freqs = theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)\n",
" inv_freq_extrapolation = 1.0 / pos_freqs # No scaling (extrapolation)\n",
" inv_freq_interpolation = 1.0 / (rope_factor * pos_freqs) # With scaling (interpolation)\n",
"\n",
" # Find the range where we blend between interpolation and extrapolation\n",
" low, high = find_correction_range(beta_fast, beta_slow, head_dim, theta_base, rope_orig_max)\n",
"\n",
" # Get n-dimensional rotational scaling corrected for extrapolation\n",
" inv_freq_extrapolation_factor = 1 - linear_ramp_factor(low, high, head_dim // 2).to(dtype=dtype)\n",
" inv_freq = (\n",
" inv_freq_interpolation * (1 - inv_freq_extrapolation_factor)\n",
" + inv_freq_extrapolation * inv_freq_extrapolation_factor\n",
" )\n",
" else:\n",
" # Default RoPE\n",
" inv_freq = 1.0 / (\n",
" theta_base ** (\n",
" torch.arange(0, head_dim, 2, dtype=dtype)[: head_dim // 2].float()\n",
" / head_dim\n",
" )\n",
" )\n",
" )\n",
"\n",
" # Generate position indices\n",
" positions = torch.arange(context_length, dtype=dtype)\n",
"\n",
" # Optional YaRN scaling\n",
" if rope_type == \"yarn\":\n",
" positions = positions / rope_factor\n",
" positions = torch.clamp(positions, max=rope_orig_max - 1)\n",
"\n",
" # Compute the base angles (shape: [context_length, head_dim // 2])\n",
" angles = positions.unsqueeze(1) * inv_freq.unsqueeze(0)\n",
"\n",
@@ -539,6 +574,8 @@
" \"rope_type\": \"yarn\",\n",
" \"rope_factor\": 8.0,\n",
" \"rope_orig_max\": 8_192,\n",
" \"beta_fast\": 32.0,\n",
" \"beta_slow\": 1.0,\n",
" \"rms_norm_eps\": 1e-6,\n",
" \"dtype\": torch.bfloat16,\n",
" \"eos_token_id\": 100_257,\n",
@@ -624,6 +661,8 @@
" \"rope_type\": \"yarn\",\n",
" \"rope_factor\": 8.0,\n",
" \"rope_orig_max\": 8_192,\n",
" \"beta_fast\": 32.0,\n",
" \"beta_slow\": 1.0,\n",
" \"rms_norm_eps\": 1e-6,\n",
" \"dtype\": torch.bfloat16,\n",
" \"eos_token_id\": 100_257,\n",
@@ -707,9 +746,9 @@
{
"data": {
"text/plain": [
"tensor([[[ 0.3594, -0.6289, -0.2754, ..., 1.1016, 0.4219, 0.0381],\n",
" [ 1.1719, 0.0283, 0.6055, ..., 0.4863, -0.1953, 0.2246],\n",
" [ 0.4902, -0.0425, 0.6758, ..., 0.3730, -0.5781, -0.1670]]],\n",
"tensor([[[ 0.3867, -0.6328, -0.2734, ..., 1.1484, 0.4258, 0.0400],\n",
" [ 1.2734, 0.0040, 0.5000, ..., 0.5625, -0.2383, 0.1855],\n",
" [ 0.5859, -0.0540, 0.7930, ..., 0.3262, -0.5430, -0.1494]]],\n",
" dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)"
]
},
@@ -1021,7 +1060,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 18,
"id": "7b6df8bc-7308-468e-93ce-2d5529ea7866",
"metadata": {},
"outputs": [
@@ -1031,7 +1070,7 @@
"'<|im_start|>user\\nGive me a short intro to large language models in 3 sentences.\\n<|im_end|>\\n<|im_start|>assistant\\n'"
]
},
"execution_count": 24,
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
@@ -1057,7 +1096,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 19,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@@ -1083,7 +1122,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 20,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@@ -1093,10 +1132,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Sure! Heres a brief introduction to large language models: \n",
"Large models are advanced AI systems trained to process vast neural networks capable of understanding and generating human-like text, learning from vast data. \n",
"They excel at many tasks across many languages and adapt to various tasks. \n",
"They power modern applications widely used in NLP solutions.\n",
"Large language models are advanced AI systems trained on vast amounts of text to understand and generate human-like language. They can perform a wide range of tasks, from answering questions to writing essays or code. These models have transformed natural language processing and are now integral to many modern technologies.\n",
"\n",
"GPU memory used: 13.70 GB\n"
]

View File

@@ -43,10 +43,51 @@ def tiny_debug_config():
}
def yarn_debug_config():
return {
"vocab_size": 257,
"context_length": 8,
"emb_dim": 32,
"n_heads": 4,
"n_layers": 2,
"hidden_dim": 64,
"head_dim": 8,
"qk_norm": True,
"n_kv_heads": 2,
"sliding_window": 4,
"layer_types": ["full_attention", "full_attention"],
"dtype": torch.float32,
"query_pre_attn_scalar": 256,
"attention_bias": False,
"rms_norm_eps": 1e-6,
"rope_base": 500_000.0,
"rope_attention_factor": 1.2079441541679836,
"rope_type": "yarn",
"rope_factor": 8.0,
"rope_orig_max": 8192,
"beta_fast": 32.0,
"beta_slow": 1.0,
"rope_local_base": 10_000.0,
}
def _hf_config_from_dict(cfg):
if Olmo3Config is None:
raise ImportError("transformers is required for the Olmo-3 debugger.")
rope_type = cfg.get("rope_type", "default")
rope_scaling = {"rope_type": rope_type}
if rope_type == "yarn":
rope_scaling.update(
{
"attention_factor": cfg.get("rope_attention_factor", 1.0),
"beta_fast": cfg.get("beta_fast", 32.0),
"beta_slow": cfg.get("beta_slow", 1.0),
"factor": cfg.get("rope_factor", 1.0),
"original_max_position_embeddings": cfg.get("rope_orig_max", 8192),
}
)
return Olmo3Config(
vocab_size=cfg["vocab_size"],
max_position_embeddings=cfg["context_length"],
@@ -64,7 +105,7 @@ def _hf_config_from_dict(cfg):
attn_implementation="eager",
torch_dtype=cfg.get("dtype", torch.float32),
query_pre_attn_scalar=cfg.get("query_pre_attn_scalar", 256),
rope_scaling={"rope_type": cfg.get("rope_type", "default")},
rope_scaling=rope_scaling,
qk_norm=cfg.get("qk_norm", False),
rms_norm_eps=cfg.get("rms_norm_eps", 1e-5),
)
@@ -231,7 +272,7 @@ if __name__ == "__main__":
raise SystemExit("transformers is not installed; install it to run the debugger.")
nb_imports = load_notebook_defs()
cfg = tiny_debug_config()
cfg = yarn_debug_config()
ours_model, hf_model = build_olmo3_pair(nb_imports, cfg)
torch.manual_seed(0)