Llama 3 KV Cache (#685)

* Llama 3 KV Cache

* skip expensive tests on Gh actions

* Update __init__.py
This commit is contained in:
Sebastian Raschka
2025-06-21 10:55:20 -05:00
committed by GitHub
parent 9f6f514191
commit 9d62ca0598
7 changed files with 410 additions and 4 deletions

View File

@@ -124,6 +124,9 @@ from llms_from_scratch.llama3 import (
ChatFormat,
clean_text
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model
from llms_from_scratch.kv_cache.generate import generate_text_simple
```
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).

View File

@@ -0,0 +1,4 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

View File

@@ -0,0 +1,29 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import torch
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
model.eval()
ctx_len = context_size or model.cfg["context_length"]
with torch.no_grad():
if use_cache:
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx

View File

@@ -0,0 +1,317 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..llama3 import Llama3Tokenizer, ChatFormat, clean_text # noqa: F401
import torch
import torch.nn as nn
LLAMA32_CONFIG_1B = {
"vocab_size": 128_256, # Vocabulary size
"context_length": 131_072, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 2048, # Embedding dimension
"n_heads": 32, # Number of attention heads
"n_layers": 16, # Number of layers
"hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
"rope_base": 500_000.0, # The base in RoPE's "theta"
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
"rope_freq": { # RoPE frequency scaling
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_context_length": 8192,
}
}
LLAMA32_CONFIG_3B = {
"vocab_size": 128_256, # Vocabulary size
"context_length": 131_072, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 3072, # Embedding dimension
"n_heads": 24, # Number of attention heads
"n_layers": 28, # Number of layers
"hidden_dim": 8192, # Size of the intermediate dimension in FeedForward
"n_kv_groups": 8, # Key-Value groups for grouped-query attention
"rope_base": 500_000.0, # The base in RoPE's "theta"
"dtype": torch.bfloat16, # Lower-precision dtype to reduce memory usage
"rope_freq": { # RoPE frequency scaling
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_context_length": 8192,
}
}
class Llama3Model(nn.Module):
def __init__(self, cfg):
super().__init__()
# Main model parameters
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])
self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
)
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
# Reusuable utilities
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],
context_length=cfg["context_length"],
freq_config=cfg["rope_freq"]
)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg
def forward(self, in_idx, use_cache=False):
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds
for block in self.trf_blocks:
x = block(x, self.cos, self.sin, use_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = GroupedQueryAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
num_heads=cfg["n_heads"],
num_kv_groups=cfg["n_kv_groups"],
max_seq_len=cfg["context_length"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
def forward(self, x, cos, sin, use_cache=False):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = x + shortcut # Add the original input back
return x
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)
def forward(self, x):
x_fc1 = self.fc1(x)
x_fc2 = self.fc2(x)
x = nn.functional.silu(x_fc1) * x_fc2
return self.fc3(x)
class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, d_out, num_heads, num_kv_groups, dtype=None, max_seq_len=None, window_size=None
):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.num_kv_groups = num_kv_groups
self.group_size = num_heads // num_kv_groups
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
# For optional KV cache
self.max_seq_len = max_seq_len
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.cache_initialized = False
self.ptr = 0
def forward(self, x, cos, sin, use_cache=False):
b, num_tokens, d_in = x.shape
queries = self.W_query(x) # Shape: (b, num_tokens, d_out)
keys_new = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
values_new = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)
# Reshape queries, keys, and values
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
# Transpose keys, values, and queries
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
values_new = values_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
# For KV cache
pos_start = self.ptr
pos_end = pos_start + num_tokens
cos_slice = cos[pos_start:pos_end]
sin_slice = sin[pos_start:pos_end]
# Apply RoPE
keys_new = apply_rope(keys_new, cos_slice, sin_slice)
queries = apply_rope(queries, cos_slice, sin_slice)
# Expand keys and values to match the number of heads
# Shape: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
values_new = values_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
# For example, before repeat_interleave along dim=1 (query groups):
# [K1, K2]
# After repeat_interleave (each query group is repeated group_size times):
# [K1, K1, K2, K2]
# If we used regular repeat instead of repeat_interleave, we'd get:
# [K1, K2, K1, K2]
if use_cache:
if not self.cache_initialized:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
self.ptr = 0
self.cache_initialized = True
# In-place update
end = self.ptr + num_tokens
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
self.cache_v[:, :, self.ptr:end].copy_(values_new)
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
values = self.cache_v[:, :, max(0, end - self.window_size):end]
self.ptr = end
else:
keys, values = keys_new, values_new
# Compute scaled dot-product attention (aka self-attention) with a causal mask
# Shape: (b, num_heads, num_tokens, num_tokens)
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Create causal mask to fill attention scores
T_q = queries.shape[-2]
T_k = keys.shape[-2]
if not use_cache or T_q > 1:
causal_mask = torch.triu(
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
assert keys.shape[-1] == self.head_dim
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
def reset_cache(self):
if self.cache_k is not None:
self.cache_k.zero_()
self.cache_v.zero_()
self.ptr = 0
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
assert head_dim % 2 == 0, "Embedding dimension must be even"
# Compute the inverse frequencies
inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))
# Frequency adjustments
if freq_config is not None:
low_freq_wavelen = freq_config["original_context_length"] / freq_config["low_freq_factor"]
high_freq_wavelen = freq_config["original_context_length"] / freq_config["high_freq_factor"]
wavelen = 2 * torch.pi / inv_freq
inv_freq_llama = torch.where(
wavelen > low_freq_wavelen, inv_freq / freq_config["factor"], inv_freq
)
smooth_factor = (freq_config["original_context_length"] / wavelen - freq_config["low_freq_factor"]) / (
freq_config["high_freq_factor"] - freq_config["low_freq_factor"]
)
smoothed_inv_freq = (
(1 - smooth_factor) * (inv_freq / freq_config["factor"]) + smooth_factor * inv_freq
)
is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
inv_freq = inv_freq_llama
# Generate position indices
positions = torch.arange(context_length, dtype=dtype)
# Compute the angles
angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)
# Expand angles to match the head_dim
angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)
# Precompute sine and cosine
cos = torch.cos(angles)
sin = torch.sin(angles)
return cos, sin
def apply_rope(x, cos, sin):
# x: (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = x.shape
assert head_dim % 2 == 0, "Head dimension must be even"
# Split x into first half and second half
x1 = x[..., : head_dim // 2] # First half
x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation
rotated = torch.cat((-x2, x1), dim=-1)
x_rotated = (x * cos) + (rotated * sin)
# It's ok to use lower-precision after applying cos and sin rotation
return x_rotated.to(dtype=x.dtype)

View File

@@ -12,8 +12,11 @@ from llms_from_scratch.llama3 import (
GroupedQueryAttentionFast,
Llama3Model,
)
from llms_from_scratch.kv_cache.llama3 import Llama3Model as Llama3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
import importlib
import os
import pytest
import tiktoken
import torch
@@ -180,8 +183,20 @@ def llama3_weights_path(tmp_path_factory):
return path
@pytest.mark.parametrize("ModelClass", [Llama3Model])
def test_gpt_model_variants(ModelClass, llama3_weights_path):
@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="Skipping in GitHub Actions due to compute or memory constraints"
)
@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
# Skip incompatible combinations
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
return
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
return
torch.manual_seed(123)
model = ModelClass(LLAMA32_CONFIG_1B)
model.load_state_dict(torch.load(llama3_weights_path))
@@ -198,7 +213,7 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
print("Encoded input text:", encoded)
print("encoded_tensor.shape:", encoded_tensor.shape)
out = generate_text_simple(
out = generate_fn(
model=model,
idx=encoded_tensor,
max_new_tokens=5,