Sliding window KV Cache bug fix (#925)

1. Fix bug because of KV cache and GPT's ptr pointer doesn't get reset when window_size > context_length
2. Fix bug because of KV cache and GPT's ptr pointer doesn't get reset
3. Fix KV Cache import issue for gpt_with_kv_cache_optimized
This commit is contained in:
talentJay-ux
2025-12-15 16:47:01 -08:00
committed by GitHub
parent a11965fbd9
commit d7f178d28b
2 changed files with 117 additions and 3 deletions

View File

@@ -37,6 +37,12 @@ class MultiHeadAttention(nn.Module):
def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape
if use_cache:
# to prevent self.ptr_cur became negative
assert num_tokens <= self.window_size, (
f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size}). "
)
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
values_new = self.W_value(x)
queries = self.W_query(x)
@@ -221,6 +227,7 @@ class GPTModel(nn.Module):
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
self.kv_window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
def forward(self, in_idx, use_cache=False):
batch_size, seq_len = in_idx.shape
@@ -232,6 +239,12 @@ class GPTModel(nn.Module):
# NEW
if use_cache:
context_length = self.pos_emb.num_embeddings
# to prevent generate more sequence than context_length
# since longer than context_length will cause model out of bound error when reading the position embedding
assert self.ptr_current_pos + seq_len <= context_length, (
f"Position embedding overflow. Want to read {self.ptr_current_pos + seq_len} which excceded size of {context_length}"
)
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.ptr_current_pos += seq_len
else:
@@ -294,11 +307,24 @@ def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, u
model.eval()
ctx_len = context_size or model.pos_emb.num_embeddings
kv_window_size = model.kv_window_size
with torch.no_grad():
if use_cache:
model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True)
input_tokens = idx[:, -ctx_len:]
input_tokens_length = input_tokens.size(1)
# prefill to handle input_tokens_length > kv_window_size
for i in range(0, input_tokens_length, kv_window_size):
chunk = input_tokens[:, i:i+kv_window_size]
logits = model(chunk, use_cache=True)
# can't generate more than ctx_len of result
# due to the limitation of position embedding
max_generable = ctx_len - input_tokens_length
max_new_tokens = min(max_new_tokens, max_generable)
for _ in range(max_new_tokens):
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)

View File

@@ -9,7 +9,8 @@ from gpt_ch04 import generate_text_simple
from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
from gpt_with_kv_cache import generate_text_simple_cached
from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2
GPT_CONFIG_124M = {
@@ -20,6 +21,7 @@ GPT_CONFIG_124M = {
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 1024 # NEW: KV cache window size
}
@@ -80,8 +82,15 @@ def test_gpt_model_equivalence_cached(ModelClass):
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
elif ModelClass is GPTModelKV1:
token_ids = generate_text_simple_cachedKV1(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
context_size=GPT_CONFIG_124M["context_length"]
)
else:
token_ids = generate_text_simple_cached(
token_ids = generate_text_simple_cachedKV2(
model=model,
idx=encoded_tensor,
max_new_tokens=30,
@@ -99,3 +108,82 @@ def test_gpt_model_equivalence_cached(ModelClass):
assert torch.equal(base_output, other_output), (
f"Mismatch between {base_name} and {other_name}"
)
def test_context_overflow_bug():
"""
Test that demonstrates the ptr_current_pos overflow bug.
In old implementation:
- context_length = 10 (positions 0-9 available)
- We try to generate 15 tokens total (5 input + 10 generated)
- At token 11 (position 10), it crashes trying to access pos_emb[10]
"""
GPT_CONFIG_SMALL = {
"vocab_size": 50257,
"context_length": 10, # Very small context
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 20 # Larger than context_length
}
torch.manual_seed(123)
model = GPTModelKV2(GPT_CONFIG_SMALL).to(device)
model.eval()
# 5 input tokens
input_tokens = torch.randint(0, 50257, (1, 5), device=device)
generate_text_simple_cachedKV2(
model=model,
idx=input_tokens,
max_new_tokens=10, # 5 + 10 = 15 > 10 context_length
context_size=GPT_CONFIG_SMALL["context_length"],
use_cache=True
)
def test_prefill_chunking_basic():
"""
Test that prefill correctly chunks input when input_length > kv_window_size.
Setup:
- kv_window_size = 4
- input_length = 10
- Should process in 3 chunks: [0:4], [4:8], [8:10]
"""
config = {
"vocab_size": 50257,
"context_length": 20,
"emb_dim": 768,
"n_heads": 12,
"n_layers": 12,
"drop_rate": 0.1,
"qkv_bias": False,
"kv_window_size": 4 # Small window to force chunking
}
torch.manual_seed(123)
model = GPTModelKV2(config).to(device)
model.eval()
# 10 input tokens (> kv_window_size of 4)
input_tokens = torch.randint(0, 50257, (1, 10), device=device)
# Should successfully process all input in chunks
token_ids = generate_text_simple_cachedKV2(
model=model,
idx=input_tokens,
max_new_tokens=2,
use_cache=True
)
# Should have 10 input + 2 generated = 12 total
assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"
# First 10 tokens should match input
assert torch.equal(token_ids[:, :10], input_tokens), "Input tokens should be preserved"