mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Improve RoPE (#799)
This commit is contained in:
committed by
GitHub
parent
d87d91b23c
commit
70edd53809
@@ -71,7 +71,7 @@ def dummy_cfg_base():
|
||||
"n_kv_groups": 1,
|
||||
"qk_norm": False,
|
||||
"dtype": torch.float32,
|
||||
"rope_base": 10000,
|
||||
"rope_base": 1000000,
|
||||
"context_length": 64,
|
||||
"num_experts": 0,
|
||||
}
|
||||
@@ -143,18 +143,21 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
|
||||
def test_rope():
|
||||
@pytest.mark.parametrize("context_len", [1024, 8192, 40960])
|
||||
def test_rope(context_len):
|
||||
|
||||
from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
|
||||
from transformers.models.qwen3.modeling_qwen3 import (
|
||||
Qwen3RotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
)
|
||||
|
||||
# Settings
|
||||
batch_size = 1
|
||||
context_len = 8192
|
||||
num_heads = 4
|
||||
head_dim = 16
|
||||
rope_theta = 1_000_000
|
||||
|
||||
# Instantiate RoPE parameters
|
||||
# Instantiate RoPE parameters (our implementation)
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=head_dim,
|
||||
theta_base=rope_theta,
|
||||
@@ -166,7 +169,7 @@ def test_rope():
|
||||
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
|
||||
keys = torch.randn(batch_size, num_heads, context_len, head_dim)
|
||||
|
||||
# Apply rotary position embeddings
|
||||
# Apply rotary embeddings with our implementation
|
||||
queries_rot = apply_rope(queries, cos, sin)
|
||||
keys_rot = apply_rope(keys, cos, sin)
|
||||
|
||||
@@ -176,7 +179,7 @@ def test_rope():
|
||||
factor = 1.0
|
||||
dim: int = head_dim
|
||||
rope_theta = 1_000_000
|
||||
max_position_embeddings: int = 8192
|
||||
max_position_embeddings = context_len
|
||||
hidden_size = head_dim * num_heads
|
||||
num_attention_heads = num_heads
|
||||
|
||||
@@ -187,10 +190,17 @@ def test_rope():
|
||||
ref_cos, ref_sin = rot_emb(queries, position_ids)
|
||||
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
|
||||
|
||||
torch.testing.assert_close(sin, ref_sin.squeeze(0))
|
||||
torch.testing.assert_close(cos, ref_cos.squeeze(0))
|
||||
torch.testing.assert_close(keys_rot, ref_keys_rot)
|
||||
torch.testing.assert_close(queries_rot, ref_queries_rot)
|
||||
# torch.testing.assert_close(sin, ref_sin.squeeze(0), rtol=1e-5, atol=1e-6)
|
||||
# torch.testing.assert_close(cos, ref_cos.squeeze(0), rtol=1e-5, atol=1e-6)
|
||||
|
||||
# torch.testing.assert_close(keys_rot, ref_keys_rot, rtol=1e-5, atol=1e-6)A
|
||||
# torch.testing.assert_close(queries_rot, ref_queries_rot, rtol=1e-5, atol=1e-6)
|
||||
|
||||
assert torch.equal(sin, ref_sin.squeeze(0))
|
||||
assert torch.equal(cos, ref_cos.squeeze(0))
|
||||
|
||||
assert torch.equal(keys_rot, ref_keys_rot)
|
||||
assert torch.equal(queries_rot, ref_queries_rot)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
||||
Reference in New Issue
Block a user