From d7f178d28bdf56732b01544d5718bfed637ac6f9 Mon Sep 17 00:00:00 2001
From: talentJay-ux
Date: Mon, 15 Dec 2025 16:47:01 -0800
Subject: [PATCH] Sliding window KV Cache bug fix (#925)

1. Fix a bug where the KV cache and GPT's ptr_current_pos pointer don't get
   reset when window_size > context_length
2. Fix a bug where the KV cache and GPT's ptr_current_pos pointer don't get
   reset
3. Fix the KV cache import issue for gpt_with_kv_cache_optimized
---
 .../gpt_with_kv_cache_optimized.py | 28 +++++-
 ch04/03_kv-cache/tests.py          | 92 ++++++++++++++++++-
 2 files changed, 117 insertions(+), 3 deletions(-)

diff --git a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
index 745cac6..98f15a6 100644
--- a/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
+++ b/ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
@@ -37,6 +37,12 @@ class MultiHeadAttention(nn.Module):
     def forward(self, x, use_cache=False):
         b, num_tokens, d_in = x.shape
 
+        if use_cache:
+            # Guard against self.ptr_cur becoming negative
+            assert num_tokens <= self.window_size, (
+                f"Input chunk size ({num_tokens}) exceeds KV cache window size ({self.window_size})."
+            )
+
         keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
         values_new = self.W_value(x)
         queries = self.W_query(x)
@@ -221,6 +227,7 @@ class GPTModel(nn.Module):
 
         self.final_norm = LayerNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+        self.kv_window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
 
     def forward(self, in_idx, use_cache=False):
         batch_size, seq_len = in_idx.shape
@@ -232,6 +239,12 @@ class GPTModel(nn.Module):
 
         # NEW
         if use_cache:
+            context_length = self.pos_emb.num_embeddings
+            # Don't generate more tokens than context_length: positions beyond
+            # it would index out of bounds in the position embedding
+            assert self.ptr_current_pos + seq_len <= context_length, (
+                f"Position embedding overflow: tried to read position {self.ptr_current_pos + seq_len}, which exceeds the table size of {context_length}"
+            )
             pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
             self.ptr_current_pos += seq_len
         else:
@@ -294,11 +307,24 @@ def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
     ctx_len = context_size or model.pos_emb.num_embeddings
+    kv_window_size = model.kv_window_size
 
     with torch.no_grad():
         if use_cache:
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True)
+
+            input_tokens = idx[:, -ctx_len:]
+            input_tokens_length = input_tokens.size(1)
+
+            # Prefill in chunks so inputs longer than kv_window_size still fit the cache
+            for i in range(0, input_tokens_length, kv_window_size):
+                chunk = input_tokens[:, i:i+kv_window_size]
+                logits = model(chunk, use_cache=True)
+
+            # Can't generate more than ctx_len tokens in total, because the
+            # position embedding only covers ctx_len positions
+            max_generable = ctx_len - input_tokens_length
+            max_new_tokens = min(max_new_tokens, max_generable)
 
             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
diff --git a/ch04/03_kv-cache/tests.py b/ch04/03_kv-cache/tests.py
index 83aae44..d2452b5 100644
--- a/ch04/03_kv-cache/tests.py
+++ b/ch04/03_kv-cache/tests.py
@@ -9,7 +9,8 @@ from gpt_ch04 import generate_text_simple
 from gpt_with_kv_cache import GPTModel as GPTModelKV1
 from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
 
-from gpt_with_kv_cache import generate_text_simple_cached
+from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
+from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2
 
 
 GPT_CONFIG_124M = {
@@ -20,6 +21,7 @@ GPT_CONFIG_124M = {
     "n_layers": 12,
     "drop_rate": 0.1,
     "qkv_bias": False,
+    "kv_window_size": 1024  # NEW: KV cache window size
 }
 
 
@@ -80,8 +82,15 @@ def test_gpt_model_equivalence_cached(ModelClass):
             max_new_tokens=30,
             context_size=GPT_CONFIG_124M["context_length"]
         )
+    elif ModelClass is GPTModelKV1:
+        token_ids = generate_text_simple_cachedKV1(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
     else:
-        token_ids = generate_text_simple_cached(
+        token_ids = generate_text_simple_cachedKV2(
             model=model,
             idx=encoded_tensor,
             max_new_tokens=30,
@@ -99,3 +108,82 @@ def test_gpt_model_equivalence_cached(ModelClass):
     assert torch.equal(base_output, other_output), (
         f"Mismatch between {base_name} and {other_name}"
     )
+
+
+def test_context_overflow_bug():
+    """
+    Test that demonstrates the ptr_current_pos overflow bug.
+
+    In the old implementation:
+    - context_length = 10 (positions 0-9 available)
+    - We try to generate 15 tokens total (5 input + 10 generated)
+    - At token 11 (position 10), the old code crashed reading pos_emb[10]; the fixed code clamps max_new_tokens, so this call now succeeds
+    """
+    GPT_CONFIG_SMALL = {
+        "vocab_size": 50257,
+        "context_length": 10,  # Very small context
+        "emb_dim": 768,
+        "n_heads": 12,
+        "n_layers": 12,
+        "drop_rate": 0.1,
+        "qkv_bias": False,
+        "kv_window_size": 20  # Larger than context_length
+    }
+
+    torch.manual_seed(123)
+
+    model = GPTModelKV2(GPT_CONFIG_SMALL).to(device)
+    model.eval()
+
+    # 5 input tokens
+    input_tokens = torch.randint(0, 50257, (1, 5), device=device)
+
+    generate_text_simple_cachedKV2(
+        model=model,
+        idx=input_tokens,
+        max_new_tokens=10,  # 5 input + 10 new = 15 > context_length of 10
+        context_size=GPT_CONFIG_SMALL["context_length"],
+        use_cache=True
+    )
+
+
+def test_prefill_chunking_basic():
+    """
+    Test that prefill correctly chunks the input when input_length > kv_window_size.
+
+    Setup:
+    - kv_window_size = 4
+    - input_length = 10
+    - Should process in 3 chunks: [0:4], [4:8], [8:10]
+    """
+    config = {
+        "vocab_size": 50257,
+        "context_length": 20,
+        "emb_dim": 768,
+        "n_heads": 12,
+        "n_layers": 12,
+        "drop_rate": 0.1,
+        "qkv_bias": False,
+        "kv_window_size": 4  # Small window to force chunking
+    }
+
+    torch.manual_seed(123)
+    model = GPTModelKV2(config).to(device)
+    model.eval()
+
+    # 10 input tokens (> kv_window_size of 4)
+    input_tokens = torch.randint(0, 50257, (1, 10), device=device)
+
+    # Should successfully process all of the input in chunks
+    token_ids = generate_text_simple_cachedKV2(
+        model=model,
+        idx=input_tokens,
+        max_new_tokens=2,
+        use_cache=True
+    )
+
+    # Should have 10 input + 2 generated = 12 tokens total
+    assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"
+
+    # The first 10 tokens should match the input
+    assert torch.equal(token_ids[:, :10], input_tokens), "Input tokens should be preserved"
\ No newline at end of file
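
Note: for readers skimming the diff, here is a condensed, standalone sketch of the
control flow that the patched generate_text_simple_cached follows. It assumes only
the interface this patch relies on (model.reset_kv_cache(), and
model(ids, use_cache=True) returning logits). The function name
generate_with_chunked_prefill is hypothetical, used here purely for illustration;
it is not part of the patch.

import torch

def generate_with_chunked_prefill(model, idx, max_new_tokens, ctx_len, kv_window_size):
    with torch.no_grad():
        model.reset_kv_cache()
        input_tokens = idx[:, -ctx_len:]
        # 1) Prefill: feed the prompt in slices of at most kv_window_size tokens
        #    so each forward pass fits inside the sliding KV cache window.
        for i in range(0, input_tokens.size(1), kv_window_size):
            logits = model(input_tokens[:, i:i + kv_window_size], use_cache=True)
        # 2) Clamp: absolute positions index the position-embedding table, so
        #    prompt length plus generated tokens must stay within ctx_len.
        max_new_tokens = min(max_new_tokens, ctx_len - input_tokens.size(1))
        # 3) Decode greedily, one token per step, reusing the cached keys/values.
        for _ in range(max_new_tokens):
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
            logits = model(next_idx, use_cache=True)
    return idx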