Improve RoPE (#799)

This commit is contained in:
Sebastian Raschka
2025-08-31 11:46:36 -05:00
committed by GitHub
parent d87d91b23c
commit 70edd53809

View File

@@ -71,7 +71,7 @@ def dummy_cfg_base():
"n_kv_groups": 1,
"qk_norm": False,
"dtype": torch.float32,
"rope_base": 10000,
"rope_base": 1000000,
"context_length": 64,
"num_experts": 0,
}
@@ -143,18 +143,21 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
def test_rope():
@pytest.mark.parametrize("context_len", [1024, 8192, 40960])
def test_rope(context_len):
from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
)
# Settings
batch_size = 1
context_len = 8192
num_heads = 4
head_dim = 16
rope_theta = 1_000_000
# Instantiate RoPE parameters
# Instantiate RoPE parameters (our implementation)
cos, sin = compute_rope_params(
head_dim=head_dim,
theta_base=rope_theta,
@@ -166,7 +169,7 @@ def test_rope():
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)
# Apply rotary position embeddings
# Apply rotary embeddings with our implementation
queries_rot = apply_rope(queries, cos, sin)
keys_rot = apply_rope(keys, cos, sin)
@@ -176,7 +179,7 @@ def test_rope():
factor = 1.0
dim: int = head_dim
rope_theta = 1_000_000
max_position_embeddings: int = 8192
max_position_embeddings = context_len
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
@@ -187,10 +190,17 @@ def test_rope():
ref_cos, ref_sin = rot_emb(queries, position_ids)
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
torch.testing.assert_close(sin, ref_sin.squeeze(0))
torch.testing.assert_close(cos, ref_cos.squeeze(0))
torch.testing.assert_close(keys_rot, ref_keys_rot)
torch.testing.assert_close(queries_rot, ref_queries_rot)
# torch.testing.assert_close(sin, ref_sin.squeeze(0), rtol=1e-5, atol=1e-6)
# torch.testing.assert_close(cos, ref_cos.squeeze(0), rtol=1e-5, atol=1e-6)
# torch.testing.assert_close(keys_rot, ref_keys_rot, rtol=1e-5, atol=1e-6)A
# torch.testing.assert_close(queries_rot, ref_queries_rot, rtol=1e-5, atol=1e-6)
assert torch.equal(sin, ref_sin.squeeze(0))
assert torch.equal(cos, ref_cos.squeeze(0))
assert torch.equal(keys_rot, ref_keys_rot)
assert torch.equal(queries_rot, ref_queries_rot)
@pytest.fixture(scope="session")