diff --git a/pkg/llms_from_scratch/tests/test_qwen3.py b/pkg/llms_from_scratch/tests/test_qwen3.py
index 82956de..b64ea8e 100644
--- a/pkg/llms_from_scratch/tests/test_qwen3.py
+++ b/pkg/llms_from_scratch/tests/test_qwen3.py
@@ -71,7 +71,7 @@ def dummy_cfg_base():
         "n_kv_groups": 1,
         "qk_norm": False,
         "dtype": torch.float32,
-        "rope_base": 10000,
+        "rope_base": 1000000,
         "context_length": 64,
         "num_experts": 0,
     }
@@ -143,18 +143,21 @@ def test_qwen3_kvcache_equivalence(cfg_name, request):


 @pytest.mark.skipif(not transformers_installed, reason="transformers not installed")
-def test_rope():
+@pytest.mark.parametrize("context_len", [1024, 8192, 40960])
+def test_rope(context_len):

-    from transformers.models.qwen3.modeling_qwen3 import Qwen3RotaryEmbedding, apply_rotary_pos_emb
+    from transformers.models.qwen3.modeling_qwen3 import (
+        Qwen3RotaryEmbedding,
+        apply_rotary_pos_emb,
+    )

     # Settings
     batch_size = 1
-    context_len = 8192
     num_heads = 4
     head_dim = 16
     rope_theta = 1_000_000

-    # Instantiate RoPE parameters
+    # Instantiate RoPE parameters (our implementation)
     cos, sin = compute_rope_params(
         head_dim=head_dim,
         theta_base=rope_theta,
@@ -166,7 +169,7 @@ def test_rope():
     queries = torch.randn(batch_size, num_heads, context_len, head_dim)
     keys = torch.randn(batch_size, num_heads, context_len, head_dim)

-    # Apply rotary position embeddings
+    # Apply rotary embeddings with our implementation
     queries_rot = apply_rope(queries, cos, sin)
     keys_rot = apply_rope(keys, cos, sin)

@@ -176,7 +179,7 @@ def test_rope():
         factor = 1.0
         dim: int = head_dim
         rope_theta = 1_000_000
-        max_position_embeddings: int = 8192
+        max_position_embeddings = context_len
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads

@@ -187,10 +190,17 @@ def test_rope():
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

-    torch.testing.assert_close(sin, ref_sin.squeeze(0))
-    torch.testing.assert_close(cos, ref_cos.squeeze(0))
-    torch.testing.assert_close(keys_rot, ref_keys_rot)
-    torch.testing.assert_close(queries_rot, ref_queries_rot)
+    # torch.testing.assert_close(sin, ref_sin.squeeze(0), rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(cos, ref_cos.squeeze(0), rtol=1e-5, atol=1e-6)
+
+    # torch.testing.assert_close(keys_rot, ref_keys_rot, rtol=1e-5, atol=1e-6)
+    # torch.testing.assert_close(queries_rot, ref_queries_rot, rtol=1e-5, atol=1e-6)
+
+    assert torch.equal(sin, ref_sin.squeeze(0))
+    assert torch.equal(cos, ref_cos.squeeze(0))
+
+    assert torch.equal(keys_rot, ref_keys_rot)
+    assert torch.equal(queries_rot, ref_queries_rot)


 @pytest.fixture(scope="session")
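
For reference while reviewing: the parametrized test now compares our RoPE implementation against the Hugging Face reference at three context lengths and asserts bit-exact agreement instead of tolerance-based closeness. Below is a minimal standalone sketch of the same comparison at a single context length. The `llms_from_scratch.qwen3` import path, the `context_length` keyword, the `position_ids` construction, and the `rope_scaling` attribute are assumptions inferred from the surrounding code, not lines shown in the hunks above.

    # Standalone sketch of the equivalence check in test_rope (single context
    # length). Assumes compute_rope_params and apply_rope are importable from
    # llms_from_scratch.qwen3 and that transformers is installed.
    import torch
    from llms_from_scratch.qwen3 import apply_rope, compute_rope_params  # assumed path
    from transformers.models.qwen3.modeling_qwen3 import (
        Qwen3RotaryEmbedding,
        apply_rotary_pos_emb,
    )

    batch_size, num_heads, head_dim = 1, 4, 16
    context_len = 8192
    rope_theta = 1_000_000

    # Our implementation: precompute cos/sin tables, then rotate queries/keys.
    cos, sin = compute_rope_params(
        head_dim=head_dim,
        theta_base=rope_theta,
        context_length=context_len,  # assumed kwarg; not visible in the hunk
    )
    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
    queries_rot = apply_rope(queries, cos, sin)
    keys_rot = apply_rope(keys, cos, sin)

    # Hugging Face reference: a duck-typed config mirroring the test's RoPEConfig.
    class RoPEConfig:
        rope_scaling = None  # assumption: selects the "default" RoPE init
        rope_theta = 1_000_000
        max_position_embeddings = context_len
        hidden_size = head_dim * num_heads
        num_attention_heads = num_heads

    rot_emb = Qwen3RotaryEmbedding(config=RoPEConfig())
    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
    ref_cos, ref_sin = rot_emb(queries, position_ids)
    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

    # Like the updated test, assert bit-exact agreement between both paths.
    assert torch.equal(queries_rot, ref_queries_rot)
    assert torch.equal(keys_rot, ref_keys_rot)

With the parametrization in place, `pytest pkg/llms_from_scratch/tests/test_qwen3.py -k test_rope` runs the check three times, once per context length (1024, 8192, 40960).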