mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Sliding window KV Cache bug fix (#925)
1. Fix a bug where the KV cache and GPT's `ptr` pointer were not reset when window_size > context_length. 2. Fix a bug where the KV cache and GPT's `ptr` pointer were not reset. 3. Fix a KV cache import issue for gpt_with_kv_cache_optimized.
This commit is contained in:
@@ -37,6 +37,12 @@ class MultiHeadAttention(nn.Module):
|
|||||||
def forward(self, x, use_cache=False):
|
def forward(self, x, use_cache=False):
|
||||||
b, num_tokens, d_in = x.shape
|
b, num_tokens, d_in = x.shape
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
# to prevent self.ptr_cur became negative
|
||||||
|
assert num_tokens <= self.window_size, (
|
||||||
|
f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size}). "
|
||||||
|
)
|
||||||
|
|
||||||
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
|
||||||
values_new = self.W_value(x)
|
values_new = self.W_value(x)
|
||||||
queries = self.W_query(x)
|
queries = self.W_query(x)
|
||||||
@@ -221,6 +227,7 @@ class GPTModel(nn.Module):
|
|||||||
|
|
||||||
self.final_norm = LayerNorm(cfg["emb_dim"])
|
self.final_norm = LayerNorm(cfg["emb_dim"])
|
||||||
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
|
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
|
||||||
|
self.kv_window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
|
||||||
|
|
||||||
def forward(self, in_idx, use_cache=False):
|
def forward(self, in_idx, use_cache=False):
|
||||||
batch_size, seq_len = in_idx.shape
|
batch_size, seq_len = in_idx.shape
|
||||||
@@ -232,6 +239,12 @@ class GPTModel(nn.Module):
|
|||||||
# NEW
|
# NEW
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
|
context_length = self.pos_emb.num_embeddings
|
||||||
|
# to prevent generate more sequence than context_length
|
||||||
|
# since longer than context_length will cause model out of bound error when reading the position embedding
|
||||||
|
assert self.ptr_current_pos + seq_len <= context_length, (
|
||||||
|
f"Position embedding overflow. Want to read {self.ptr_current_pos + seq_len} which excceded size of {context_length}"
|
||||||
|
)
|
||||||
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
|
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
|
||||||
self.ptr_current_pos += seq_len
|
self.ptr_current_pos += seq_len
|
||||||
else:
|
else:
|
||||||
@@ -294,11 +307,24 @@ def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, u
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
ctx_len = context_size or model.pos_emb.num_embeddings
|
ctx_len = context_size or model.pos_emb.num_embeddings
|
||||||
|
kv_window_size = model.kv_window_size
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if use_cache:
|
if use_cache:
|
||||||
model.reset_kv_cache()
|
model.reset_kv_cache()
|
||||||
logits = model(idx[:, -ctx_len:], use_cache=True)
|
|
||||||
|
input_tokens = idx[:, -ctx_len:]
|
||||||
|
input_tokens_length = input_tokens.size(1)
|
||||||
|
|
||||||
|
# prefill to handle input_tokens_length > kv_window_size
|
||||||
|
for i in range(0, input_tokens_length, kv_window_size):
|
||||||
|
chunk = input_tokens[:, i:i+kv_window_size]
|
||||||
|
logits = model(chunk, use_cache=True)
|
||||||
|
|
||||||
|
# can't generate more than ctx_len of result
|
||||||
|
# due to the limitation of position embedding
|
||||||
|
max_generable = ctx_len - input_tokens_length
|
||||||
|
max_new_tokens = min(max_new_tokens, max_generable)
|
||||||
|
|
||||||
for _ in range(max_new_tokens):
|
for _ in range(max_new_tokens):
|
||||||
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
|
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ from gpt_ch04 import generate_text_simple
|
|||||||
|
|
||||||
from gpt_with_kv_cache import GPTModel as GPTModelKV1
|
from gpt_with_kv_cache import GPTModel as GPTModelKV1
|
||||||
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
|
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
|
||||||
from gpt_with_kv_cache import generate_text_simple_cached
|
from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
|
||||||
|
from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2
|
||||||
|
|
||||||
|
|
||||||
GPT_CONFIG_124M = {
|
GPT_CONFIG_124M = {
|
||||||
@@ -20,6 +21,7 @@ GPT_CONFIG_124M = {
|
|||||||
"n_layers": 12,
|
"n_layers": 12,
|
||||||
"drop_rate": 0.1,
|
"drop_rate": 0.1,
|
||||||
"qkv_bias": False,
|
"qkv_bias": False,
|
||||||
|
"kv_window_size": 1024 # NEW: KV cache window size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -80,8 +82,15 @@ def test_gpt_model_equivalence_cached(ModelClass):
|
|||||||
max_new_tokens=30,
|
max_new_tokens=30,
|
||||||
context_size=GPT_CONFIG_124M["context_length"]
|
context_size=GPT_CONFIG_124M["context_length"]
|
||||||
)
|
)
|
||||||
|
elif ModelClass is GPTModelKV1:
|
||||||
|
token_ids = generate_text_simple_cachedKV1(
|
||||||
|
model=model,
|
||||||
|
idx=encoded_tensor,
|
||||||
|
max_new_tokens=30,
|
||||||
|
context_size=GPT_CONFIG_124M["context_length"]
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
token_ids = generate_text_simple_cached(
|
token_ids = generate_text_simple_cachedKV2(
|
||||||
model=model,
|
model=model,
|
||||||
idx=encoded_tensor,
|
idx=encoded_tensor,
|
||||||
max_new_tokens=30,
|
max_new_tokens=30,
|
||||||
@@ -99,3 +108,82 @@ def test_gpt_model_equivalence_cached(ModelClass):
|
|||||||
assert torch.equal(base_output, other_output), (
|
assert torch.equal(base_output, other_output), (
|
||||||
f"Mismatch between {base_name} and {other_name}"
|
f"Mismatch between {base_name} and {other_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_context_overflow_bug():
|
||||||
|
"""
|
||||||
|
Test that demonstrates the ptr_current_pos overflow bug.
|
||||||
|
|
||||||
|
In old implementation:
|
||||||
|
- context_length = 10 (positions 0-9 available)
|
||||||
|
- We try to generate 15 tokens total (5 input + 10 generated)
|
||||||
|
- At token 11 (position 10), it crashes trying to access pos_emb[10]
|
||||||
|
"""
|
||||||
|
GPT_CONFIG_SMALL = {
|
||||||
|
"vocab_size": 50257,
|
||||||
|
"context_length": 10, # Very small context
|
||||||
|
"emb_dim": 768,
|
||||||
|
"n_heads": 12,
|
||||||
|
"n_layers": 12,
|
||||||
|
"drop_rate": 0.1,
|
||||||
|
"qkv_bias": False,
|
||||||
|
"kv_window_size": 20 # Larger than context_length
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.manual_seed(123)
|
||||||
|
|
||||||
|
model = GPTModelKV2(GPT_CONFIG_SMALL).to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 5 input tokens
|
||||||
|
input_tokens = torch.randint(0, 50257, (1, 5), device=device)
|
||||||
|
|
||||||
|
generate_text_simple_cachedKV2(
|
||||||
|
model=model,
|
||||||
|
idx=input_tokens,
|
||||||
|
max_new_tokens=10, # 5 + 10 = 15 > 10 context_length
|
||||||
|
context_size=GPT_CONFIG_SMALL["context_length"],
|
||||||
|
use_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefill_chunking_basic():
|
||||||
|
"""
|
||||||
|
Test that prefill correctly chunks input when input_length > kv_window_size.
|
||||||
|
|
||||||
|
Setup:
|
||||||
|
- kv_window_size = 4
|
||||||
|
- input_length = 10
|
||||||
|
- Should process in 3 chunks: [0:4], [4:8], [8:10]
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"vocab_size": 50257,
|
||||||
|
"context_length": 20,
|
||||||
|
"emb_dim": 768,
|
||||||
|
"n_heads": 12,
|
||||||
|
"n_layers": 12,
|
||||||
|
"drop_rate": 0.1,
|
||||||
|
"qkv_bias": False,
|
||||||
|
"kv_window_size": 4 # Small window to force chunking
|
||||||
|
}
|
||||||
|
|
||||||
|
torch.manual_seed(123)
|
||||||
|
model = GPTModelKV2(config).to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# 10 input tokens (> kv_window_size of 4)
|
||||||
|
input_tokens = torch.randint(0, 50257, (1, 10), device=device)
|
||||||
|
|
||||||
|
# Should successfully process all input in chunks
|
||||||
|
token_ids = generate_text_simple_cachedKV2(
|
||||||
|
model=model,
|
||||||
|
idx=input_tokens,
|
||||||
|
max_new_tokens=2,
|
||||||
|
use_cache=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have 10 input + 2 generated = 12 total
|
||||||
|
assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"
|
||||||
|
|
||||||
|
# First 10 tokens should match input
|
||||||
|
assert torch.equal(token_ids[:, :10], input_tokens), "Input tokens should be preserved"
|
||||||
Reference in New Issue
Block a user