Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2026-04-10 12:33:42 +00:00)
Llama3Fast (#593)

* Llama3Fast
* Update pkg/llms_from_scratch/tests/test_llama3.py

Committed by: GitHub
Parent: 4128a91c1d
Commit: 2dc2df593a
@@ -67,7 +67,10 @@ class Llama3Model(nn.Module):
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

         # Reusable utilities
-        self.register_buffer("mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool())
+        self.register_buffer(
+            "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
+            persistent=False
+        )

         if cfg["orig_context_length"] != cfg["context_length"]:
             cfg["rope_base"] = rescale_theta(
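The functional change in this hunk is the added `persistent=False`, which keeps the causal mask buffer out of the model's `state_dict`; since the mask is rebuilt in `__init__`, it does not need to be stored in checkpoints. A minimal sketch of the effect (the `MaskDemo` class and buffer names are invented for this example):

import torch
import torch.nn as nn

class MaskDemo(nn.Module):
    # Toy module illustrating persistent vs. non-persistent buffers
    def __init__(self, context_length=4):
        super().__init__()
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        self.register_buffer("mask_saved", mask)                        # included in state_dict
        self.register_buffer("mask_not_saved", mask, persistent=False)  # rebuilt in __init__, excluded from state_dict

demo = MaskDemo()
print("mask_saved" in demo.state_dict())      # True
print("mask_not_saved" in demo.state_dict())  # False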
@@ -86,7 +89,6 @@ class Llama3Model(nn.Module):
         self.cfg = cfg

     def forward(self, in_idx):
-        # Forward pass
         tok_embeds = self.tok_emb(in_idx)
         x = tok_embeds

@@ -143,9 +145,7 @@ class FeedForward(nn.Module):

 class GroupedQueryAttention(nn.Module):
     def __init__(
-            self, d_in, d_out, num_heads,
-            num_kv_groups,
-            dtype=None
+            self, d_in, d_out, num_heads, num_kv_groups, dtype=None
     ):
         super().__init__()
         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@@ -375,3 +375,136 @@ def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
     else:
         # If the token is not found, return the original text
         return text
+
+
+######################################################################
+# Llama 3 fast (alternative code geared towards efficiency)
+######################################################################
+
+class GroupedQueryAttentionFast(nn.Module):
+    """
+    Drop-in replacement for GroupedQueryAttention but using PyTorch's
+    scaled_dot_product_attention, which uses FlashAttention if run
+    on an Ampere GPU (like A100) or newer and uses float16/bfloat16 or lower.
+    """
+    def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
+        super().__init__()
+        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
+        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
+
+        self.d_out = d_out
+        self.num_heads = num_heads
+        self.head_dim = d_out // num_heads
+        self.num_kv_groups = num_kv_groups
+        self.group_size = num_heads // num_kv_groups
+
+        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
+        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
+        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
+        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
+
+    def forward(self, x, cos, sin):
+        b, num_tokens, _ = x.shape
+
+        # Project to queries, keys, values
+        q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
+        k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
+        v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
+
+        # Apply Rotary Positional Embedding
+        q = apply_rope(q, cos, sin)
+        k = apply_rope(k, cos, sin)
+
+        # Expand key/value groups to full head count
+        k = k.repeat_interleave(self.group_size, dim=1)
+        v = v.repeat_interleave(self.group_size, dim=1)
+
+        # Efficient scaled dot-product attention
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            q, k, v,
+            is_causal=True  # Enables Flash/FlexAttention kernels
+        )
+
+        # Combine heads and project
+        attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out)
+        return self.out_proj(attn_output)
+
+
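A quick way to exercise `GroupedQueryAttentionFast` in isolation is a small forward pass. The sketch below assumes it runs in the context of this module (so `compute_rope_params` is in scope) and that passing `freq_config=None` selects the plain RoPE frequencies; all sizes are toy values:

import torch

torch.manual_seed(123)

d_in = d_out = 64
num_heads, num_kv_groups = 8, 4
num_tokens = 10

attn = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)

cos, sin = compute_rope_params(
    head_dim=d_out // num_heads,
    theta_base=500_000.0,
    context_length=num_tokens,
    freq_config=None,  # assumption: None means no extended-context frequency scaling
)

x = torch.randn(2, num_tokens, d_in)
out = attn(x, cos, sin)
print(out.shape)  # torch.Size([2, 10, 64])

Note that `is_causal=True` takes over the role of the explicit boolean `mask` buffer registered by `Llama3Model`, so the fast attention variant does not need a mask argument.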
+class TransformerBlockFast(nn.Module):
+    """
+    Same as original TransformerBlock but uses
+    GroupedQueryAttentionFast instead of GroupedQueryAttention.
+    """
+    def __init__(self, cfg):
+        super().__init__()
+        self.att = GroupedQueryAttentionFast(
+            d_in=cfg["emb_dim"],
+            d_out=cfg["emb_dim"],
+            num_heads=cfg["n_heads"],
+            num_kv_groups=cfg["n_kv_groups"],
+            dtype=cfg["dtype"]
+        )
+        self.ff = FeedForward(cfg)
+        self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
+        self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
+
+    def forward(self, x, cos, sin):
+        # Shortcut connection for attention block
+        shortcut = x
+        x = self.norm1(x)
+        x = self.att(x, cos, sin)  # Shape [batch_size, num_tokens, emb_size]
+        x = x + shortcut  # Add the original input back
+
+        # Shortcut connection for feed-forward block
+        shortcut = x
+        x = self.norm2(x)
+        x = self.ff(x)
+        x = x + shortcut  # Add the original input back
+
+        return x
+
+
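Because the shortcut connections add the block input back to its output, `TransformerBlockFast` preserves the shape of its input. A toy-sized sketch, assuming the existing `FeedForward` class reads a `"hidden_dim"` key from the config and that `freq_config=None` is accepted:

import torch

cfg = {
    "emb_dim": 32,
    "n_heads": 4,
    "n_kv_groups": 2,
    "hidden_dim": 64,      # assumption: key consumed by FeedForward
    "dtype": torch.float32,
}

block = TransformerBlockFast(cfg)
cos, sin = compute_rope_params(
    head_dim=cfg["emb_dim"] // cfg["n_heads"],
    theta_base=500_000.0,
    context_length=16,
    freq_config=None,
)

x = torch.randn(2, 16, cfg["emb_dim"])
out = block(x, cos, sin)
print(out.shape)  # torch.Size([2, 16, 32]), same as the input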
+class Llama3ModelFast(nn.Module):
+    """
+    Same as original Llama3Model but uses TransformerBlockFast
+    instead of TransformerBlock, which in turn uses
+    GroupedQueryAttentionFast instead of GroupedQueryAttention.
+    """
+    def __init__(self, cfg):
+        super().__init__()
+
+        # Main model parameters
+        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
+
+        self.trf_blocks = nn.ModuleList(  # ModuleList since Sequential can only accept one input, and we need `x, cos, sin`
+            [TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
+        )
+
+        self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
+        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
+
+        if cfg["orig_context_length"] != cfg["context_length"]:
+            cfg["rope_base"] = rescale_theta(
+                cfg["rope_base"],
+                cfg["orig_context_length"],
+                cfg["context_length"]
+            )
+        cos, sin = compute_rope_params(
+            head_dim=cfg["emb_dim"] // cfg["n_heads"],
+            theta_base=cfg["rope_base"],
+            context_length=cfg["context_length"],
+            freq_config=cfg["rope_freq"]
+        )
+        self.register_buffer("cos", cos, persistent=False)
+        self.register_buffer("sin", sin, persistent=False)
+        self.cfg = cfg
+
+    def forward(self, in_idx):
+        tok_embeds = self.tok_emb(in_idx)
+        x = tok_embeds
+
+        for block in self.trf_blocks:
+            x = block(x, self.cos, self.sin)
+        x = self.final_norm(x)
+        logits = self.out_head(x.to(self.cfg["dtype"]))
+        return logits
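Putting it together, `Llama3ModelFast` can be smoke-tested with a deliberately tiny config. All values below are illustrative; `"hidden_dim"` is assumed to be the key the existing `FeedForward` class reads, and `rope_freq=None` is assumed to disable the extended-context frequency scaling:

import torch

TOY_CFG = {
    "vocab_size": 100,
    "context_length": 64,
    "orig_context_length": 64,   # equal to context_length, so rope_base is not rescaled
    "emb_dim": 32,
    "n_heads": 4,
    "n_kv_groups": 2,
    "n_layers": 2,
    "hidden_dim": 64,            # assumption: key consumed by FeedForward
    "rope_base": 500_000.0,
    "rope_freq": None,           # assumption: None means plain RoPE
    "dtype": torch.float32,
}

torch.manual_seed(123)
model = Llama3ModelFast(TOY_CFG)

token_ids = torch.randint(0, TOY_CFG["vocab_size"], (1, 8))  # batch of 1, 8 token IDs
logits = model(token_ids)
print(logits.shape)  # torch.Size([1, 8, 100]): one logit vector per token over the vocabulary

On hardware without FlashAttention support, `scaled_dot_product_attention` falls back to a standard attention implementation, so the fast classes remain usable on CPU as well.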