mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Llama3Fast (#593)
* Llama3Fast * Update pkg/llms_from_scratch/tests/test_llama3.py
This commit is contained in:
committed by
GitHub
parent
4128a91c1d
commit
2dc2df593a
@@ -9,7 +9,9 @@ from llms_from_scratch.llama3 import (
|
||||
apply_rope,
|
||||
rescale_theta,
|
||||
LLAMA32_CONFIG_1B,
|
||||
Llama3Model
|
||||
GroupedQueryAttention,
|
||||
GroupedQueryAttentionFast,
|
||||
Llama3Model,
|
||||
)
|
||||
|
||||
import importlib
|
||||
@@ -117,13 +119,63 @@ def test_rescale():
|
||||
assert old_theta == 500_000.
|
||||
|
||||
|
||||
def test_grouped_query_attention_equivalence():
|
||||
torch.manual_seed(42)
|
||||
b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
|
||||
|
||||
x = torch.randn(b, t, d_in)
|
||||
cos, sin = compute_rope_params(
|
||||
head_dim=d_out // num_heads,
|
||||
theta_base=50_000,
|
||||
context_length=t,
|
||||
freq_config={
|
||||
"factor": 32.0,
|
||||
"low_freq_factor": 1.0,
|
||||
"high_freq_factor": 4.0,
|
||||
"original_context_length": t,
|
||||
}
|
||||
)
|
||||
|
||||
# Causal mask for the slow version
|
||||
mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)
|
||||
|
||||
attn1 = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)
|
||||
attn2 = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)
|
||||
|
||||
# Copy weights to make both models identical
|
||||
attn2.load_state_dict(attn1.state_dict())
|
||||
|
||||
# Run both
|
||||
y1 = attn1(x, mask, cos, sin)
|
||||
y2 = attn2(x, cos, sin)
|
||||
|
||||
# Compare outputs
|
||||
max_diff = (y1 - y2).abs().max().item()
|
||||
print(f"Max difference between slow and fast outputs: {max_diff:.4e}")
|
||||
assert torch.allclose(y1, y2, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def llama3_weights_path(tmp_path_factory):
|
||||
"""Creates and saves a deterministic Llama3 model for testing."""
|
||||
path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"
|
||||
|
||||
if not path.exists():
|
||||
torch.manual_seed(123)
|
||||
model = Llama3Model(LLAMA32_CONFIG_1B)
|
||||
torch.save(model.state_dict(), path)
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [Llama3Model])
|
||||
def test_gpt_model_variants(ModelClass):
|
||||
def test_gpt_model_variants(ModelClass, llama3_weights_path):
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(LLAMA32_CONFIG_1B)
|
||||
model.load_state_dict(torch.load(llama3_weights_path))
|
||||
model.eval()
|
||||
|
||||
start_context = "Hello, I am"
|
||||
start_context = "Llamas eat"
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
encoded = tokenizer.encode(start_context)
|
||||
@@ -137,11 +189,11 @@ def test_gpt_model_variants(ModelClass):
|
||||
out = generate_text_simple(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
max_new_tokens=10,
|
||||
max_new_tokens=5,
|
||||
context_size=LLAMA32_CONFIG_1B["context_length"]
|
||||
)
|
||||
print("Encoded output text:", out)
|
||||
expect = torch.tensor([
|
||||
[15496, 11, 314, 716, 78563, 89362, 19616, 115725, 114917,
|
||||
97198, 60342, 19108, 100752, 98969]
|
||||
[43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
|
||||
])
|
||||
assert torch.equal(expect, out)
|
||||
|
||||
Reference in New Issue
Block a user