Improve rope settings for llama3 (#380)

This commit is contained in:
Sebastian Raschka
2024-10-03 08:29:54 -05:00
committed by GitHub
parent 278a50a348
commit b993c2b25b
3 changed files with 57 additions and 24 deletions

View File

@@ -58,10 +58,10 @@ def set_seed():
torch.manual_seed(123)
def test_rope(notebook):
def test_rope_llama2(notebook):
# Settings
batch_size = 1
context_len = 5
context_len = 4096
num_heads = 4
head_dim = 16
@@ -76,19 +76,51 @@ def test_rope(notebook):
queries_rot = notebook.compute_rope(queries, cos, sin)
keys_rot = notebook.compute_rope(keys, cos, sin)
class RoPEConfig:
rope_type = "default"
rope_scaling = None
factor = 1.0
dim: int = head_dim
rope_theta = 10000
max_position_embeddings: int = 4096
hidden_size = head_dim * num_heads
num_attention_heads = num_heads
rot_emb = LlamaRotaryEmbedding(
dim=head_dim,
max_position_embeddings=context_len,
base=10_000
)
config = RoPEConfig()
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
ref_cos, ref_sin = rot_emb(queries, position_ids)
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
torch.testing.assert_close(sin, ref_sin.squeeze(0))
torch.testing.assert_close(cos, ref_cos.squeeze(0))
torch.testing.assert_close(keys_rot, ref_keys_rot)
torch.testing.assert_close(queries_rot, ref_queries_rot)
def test_rope_llama3(notebook):
# Settings
batch_size = 1
context_len = 8192
num_heads = 4
head_dim = 16
theta_base = 50_000
# Instantiate RoPE parameters
cos, sin = notebook.precompute_rope_params(
head_dim=head_dim,
context_length=context_len,
theta_base=theta_base
)
# Dummy query and key tensors
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)
# Apply rotary position embeddings
queries_rot = notebook.compute_rope(queries, cos, sin)
keys_rot = notebook.compute_rope(keys, cos, sin)
rot_emb = LlamaRotaryEmbedding(
dim=head_dim,
max_position_embeddings=context_len,
base=theta_base
)
rot_emb = LlamaRotaryEmbedding(config=config)
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
ref_cos, ref_sin = rot_emb(queries, position_ids)
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
@@ -108,7 +140,7 @@ def test_silu(notebook):
@pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
def test_rmsnorm(notebook):
example_batch = torch.randn(2, 3, 4)
rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)
rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))