Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2026-04-10 12:33:42 +00:00)
Improve RoPE (#799)
Committed by GitHub. Parent: d87d91b23c. Commit: 70edd53809.
@@ -71,7 +71,7 @@ def dummy_cfg_base():
         "n_kv_groups": 1,
         "qk_norm": False,
         "dtype": torch.float32,
-        "rope_base": 10000,
+        "rope_base": 1000000,
         "context_length": 64,
         "num_experts": 0,
     }
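Note: the only functional change in this hunk is the test config's rope_base, raised from 10000 to 1000000 to match Qwen3's long-context setting. For orientation, here is a minimal sketch of how the base enters the standard RoPE frequency computation; the helper name rope_inv_freq is ours, not the repo's, and compute_rope_params may differ in detail.

import torch

def rope_inv_freq(head_dim: int, theta_base: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: theta_base ** (-2i / head_dim)
    # for i = 0, 1, ..., head_dim/2 - 1. A larger base stretches the
    # longest wavelengths, which is why long-context models use
    # 1_000_000 instead of the original 10_000.
    i = torch.arange(0, head_dim, 2, dtype=torch.float32)
    return 1.0 / (theta_base ** (i / head_dim))

print(rope_inv_freq(16, 10_000)[-1])     # lowest frequency with base 10k
print(rope_inv_freq(16, 1_000_000)[-1])  # lower still with base 1M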
@@ -143,18 +143,21 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):
 
 
 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_rope():
+@pytest.mark.parametrize("context_len", [1024, 8192, 40960])
+def test_rope(context_len):
 
-    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
+    from transformers.models.qwen3.modeling_qwen3 import (
+        Qwen3RotaryEmbedding,
+        apply_rotary_pos_emb,
+    )
 
     # Settings
     batch_size = 1
-    context_len = 8192
     num_heads = 4
     head_dim = 16
     rope_theta = 1_000_000
 
-    # Instantiate RoPE parameters
+    # Instantiate RoPE parameters (our implementation)
     cos, sin = compute_rope_params(
         head_dim=head_dim,
         theta_base=rope_theta,
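Note: instead of hard-coding context_len = 8192, the test is now parametrized over three context lengths (1024, 8192, 40960), so pytest runs it once per length. For orientation, a sketch of what compute_rope_params plausibly precomputes, under the assumption that it follows the standard cos/sin-table formulation; the real function in this repo may differ in detail.

import torch

def compute_rope_params_sketch(head_dim, theta_base, context_length):
    # Precompute per-position rotation angles, then their cos/sin tables.
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(context_length, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]   # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)      # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

# One table pair per parametrized length, matching the test settings:
for context_len in [1024, 8192, 40960]:
    cos, sin = compute_rope_params_sketch(head_dim=16, theta_base=1_000_000,
                                          context_length=context_len)
    print(cos.shape)  # torch.Size([context_len, 16])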
@@ -166,7 +169,7 @@ def test_rope():
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)
 
-    # Apply rotary position embeddings
+    # Apply rotary embeddings with our implementation
     queries_rot = apply_rope(queries, cos, sin)
     keys_rot = apply_rope(keys, cos, sin)
 
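Note: apply_rope is exercised here on random queries and keys. A sketch of the rotate-half formulation used by Llama/Qwen-style RoPE, which apply_rope presumably implements; this is an assumption, not the repo's verbatim code.

import torch

def apply_rope_sketch(x, cos, sin):
    # x: (batch, num_heads, seq_len, head_dim)
    # cos, sin: (seq_len, head_dim) tables from the precompute step
    x1, x2 = x.chunk(2, dim=-1)             # split the head dimension in half
    rotated = torch.cat((-x2, x1), dim=-1)  # "rotate half": (x1, x2) -> (-x2, x1)
    return x * cos + rotated * sin          # elementwise rotation by position angle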
@@ -176,7 +179,7 @@ def test_rope():
         factor = 1.0
         dim: int = head_dim
         rope_theta = 1_000_000
-        max_position_embeddings: int = 8192
+        max_position_embeddings = context_len
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads
 
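Note: replacing the hard-coded max_position_embeddings: int = 8192 with context_len matters because the Hugging Face rotary embedding sizes its position tables from this attribute; with the new 40960 case, a fixed 8192 would be too small. A hypothetical reconstruction of the duck-typed config the test appears to build, with the attribute set inferred from the visible hunk; the real class in the test may define more fields.

head_dim, num_heads, context_len = 16, 4, 8192  # example values from the test

class RoPEConfig:
    # Minimal stand-in for a HF Qwen3 config, exposing only the
    # attributes visible in this hunk.
    factor = 1.0
    dim: int = head_dim
    rope_theta = 1_000_000
    max_position_embeddings = context_len  # must cover the largest parametrized length
    hidden_size = head_dim * num_heads
    num_attention_heads = num_heads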
@@ -187,10 +190,17 @@ def test_rope():
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
 
-    torch.testing.assert_close(sin, ref_sin.squeeze(0))
-    torch.testing.assert_close(cos, ref_cos.squeeze(0))
-    torch.testing.assert_close(keys_rot, ref_keys_rot)
-    torch.testing.assert_close(queries_rot, ref_queries_rot)
+    # torch.testing.assert_close(sin, ref_sin.squeeze(0), rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(cos, ref_cos.squeeze(0), rtol=1e-5, atol=1e-6)
+
+    # torch.testing.assert_close(keys_rot, ref_keys_rot, rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(queries_rot, ref_queries_rot, rtol=1e-5, atol=1e-6)
+
+    assert torch.equal(sin, ref_sin.squeeze(0))
+    assert torch.equal(cos, ref_cos.squeeze(0))
+
+    assert torch.equal(keys_rot, ref_keys_rot)
+    assert torch.equal(queries_rot, ref_queries_rot)
 
 
 @pytest.fixture(scope="session")
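Note: the main behavioral change to the assertions is that the tolerance-based torch.testing.assert_close checks are kept only as comments, and the test now demands bitwise-identical outputs via torch.equal, a strictly stronger check. A small self-contained illustration of the difference:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 1e-7  # tiny float drift

# Passes: the difference is within rtol=1e-5 / atol=1e-6.
torch.testing.assert_close(a, b, rtol=1e-5, atol=1e-6)

# Fails the stricter check: torch.equal requires identical shape,
# dtype, and bit-for-bit equal values.
print(torch.equal(a, b))  # False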