Files
talentJay-ux d7f178d28b Sliding window KV Cache bug fix (#925)
1. Fix a bug where the KV cache and GPT's ptr pointer are not reset when window_size > context_length
2. Fix a bug where the KV cache and GPT's ptr pointer are not reset
3. Fix KV Cache import issue for gpt_with_kv_cache_optimized
2025-12-15 18:47:01 -06:00

189 lines
5.9 KiB
Python

# Code to test the GPT model implementation against the KV cache variants
import pytest
import torch
import tiktoken
from gpt_ch04 import GPTModel as GPTModelBase
from gpt_ch04 import generate_text_simple
from gpt_with_kv_cache import GPTModel as GPTModelKV1
from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
from gpt_with_kv_cache import generate_text_simple_cached as generate_text_simple_cachedKV1
from gpt_with_kv_cache_optimized import generate_text_simple_cached as generate_text_simple_cachedKV2
# Hyperparameters shared by the equivalence tests below.
GPT_CONFIG_124M = {
    "vocab_size": 50257,
    "context_length": 1024,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.1,
    "qkv_bias": False,
    "kv_window_size": 1024,  # window size used by the KV-cache variants
}

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_not_cached(ModelClass):
    """Check that ``ModelClass`` matches the baseline when no KV cache is used.

    The same prompt is generated once with ``GPTModelBase`` and once with
    ``ModelClass``. Both models are built from ``GPT_CONFIG_124M`` right after
    seeding PyTorch's RNG with 123, and both are driven by
    ``generate_text_simple``. The two token sequences must be identical.

    Every parametrized case produces its own reference output instead of
    accumulating results on a function attribute. The comparison therefore
    runs regardless of which cases are selected, their execution order, or
    the process they run in, and a failure is reported for the case that
    actually mismatched.

    Args:
        ModelClass: Model class under test (supplied by ``parametrize``).
    """
    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded = tokenizer.encode(prompt)

    def generate_with(model_cls):
        # Re-seed before each construction so both models are initialised
        # from the same RNG state.
        torch.manual_seed(123)
        model = model_cls(GPT_CONFIG_124M).to(device)
        model.eval()
        # Build a fresh input tensor per run so the runs share no tensors.
        encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
        return generate_text_simple(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"],
        )

    base_output = generate_with(GPTModelBase)
    other_output = generate_with(ModelClass)

    base_name = f"{GPTModelBase.__module__}.{GPTModelBase.__name__}"
    other_name = f"{ModelClass.__module__}.{ModelClass.__name__}"
    assert torch.equal(base_output, other_output), (
        f"Mismatch between {base_name} and {other_name}"
    )
@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
def test_gpt_model_equivalence_cached(ModelClass):
    """Check that cached generation matches the baseline model's output.

    The reference output always comes from ``GPTModelBase`` driven by
    ``generate_text_simple``. The model under test is driven by the generation
    function that belongs to it: ``generate_text_simple`` for the baseline,
    ``generate_text_simple_cachedKV1`` for ``GPTModelKV1`` and
    ``generate_text_simple_cachedKV2`` otherwise. Both models are built from
    ``GPT_CONFIG_124M`` right after seeding PyTorch's RNG with 123, and the
    two token sequences must be identical.

    Every parametrized case produces its own reference output instead of
    accumulating results on a function attribute. The comparison therefore
    runs regardless of which cases are selected, their execution order, or
    the process they run in, and a failure is reported for the case that
    actually mismatched.

    Args:
        ModelClass: Model class under test (supplied by ``parametrize``).
    """
    tokenizer = tiktoken.get_encoding("gpt2")
    prompt = "Hello, I am"
    encoded = tokenizer.encode(prompt)

    # Pick the generation routine that matches the model under test.
    if ModelClass is GPTModelBase:
        generate_fn = generate_text_simple
    elif ModelClass is GPTModelKV1:
        generate_fn = generate_text_simple_cachedKV1
    else:
        generate_fn = generate_text_simple_cachedKV2

    def generate_with(model_cls, generate):
        # Re-seed before each construction so both models are initialised
        # from the same RNG state.
        torch.manual_seed(123)
        model = model_cls(GPT_CONFIG_124M).to(device)
        model.eval()
        # Build a fresh input tensor per run so the runs share no tensors.
        encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
        return generate(
            model=model,
            idx=encoded_tensor,
            max_new_tokens=30,
            context_size=GPT_CONFIG_124M["context_length"],
        )

    base_output = generate_with(GPTModelBase, generate_text_simple)
    other_output = generate_with(ModelClass, generate_fn)

    base_name = f"{GPTModelBase.__module__}.{GPTModelBase.__name__}"
    other_name = f"{ModelClass.__module__}.{ModelClass.__name__}"
    assert torch.equal(base_output, other_output), (
        f"Mismatch between {base_name} and {other_name}"
    )
def test_context_overflow_bug():
    """
    Regression test for the ptr_current_pos overflow bug.

    With the old implementation:
    - context_length = 10, so only positions 0-9 are available
    - 5 input tokens + 10 generated tokens = 15 tokens in total
    - at token 11 (position 10) it crashed trying to access pos_emb[10]

    There are no assertions: the test passes when generation completes
    without raising.
    """
    small_cfg = {
        "vocab_size": 50257,
        "context_length": 10,  # deliberately tiny context
        "emb_dim": 768,
        "n_heads": 12,
        "n_layers": 12,
        "drop_rate": 0.1,
        "qkv_bias": False,
        "kv_window_size": 20,  # larger than context_length
    }

    torch.manual_seed(123)
    kv_model = GPTModelKV2(small_cfg).to(device)
    kv_model.eval()

    # Prompt of 5 random token ids.
    prompt_ids = torch.randint(low=0, high=50257, size=(1, 5), device=device)

    # 5 prompt tokens + 10 new tokens = 15, which exceeds context_length = 10.
    generate_text_simple_cachedKV2(
        model=kv_model,
        idx=prompt_ids,
        max_new_tokens=10,
        context_size=small_cfg["context_length"],
        use_cache=True,
    )
def test_prefill_chunking_basic():
    """
    Check generation when the prompt is longer than kv_window_size.

    Setup:
    - kv_window_size = 4
    - input_length = 10
    - prefill is expected to run in 3 chunks: [0:4], [4:8], [8:10]

    The test verifies the length of the returned sequence and that the
    prompt tokens come back unchanged at the front of it.
    """
    chunk_cfg = {
        "vocab_size": 50257,
        "context_length": 20,
        "emb_dim": 768,
        "n_heads": 12,
        "n_layers": 12,
        "drop_rate": 0.1,
        "qkv_bias": False,
        "kv_window_size": 4,  # small window so the prompt must be chunked
    }

    torch.manual_seed(123)
    kv_model = GPTModelKV2(chunk_cfg).to(device)
    kv_model.eval()

    # Prompt of 10 random token ids (more than kv_window_size = 4).
    input_tokens = torch.randint(low=0, high=50257, size=(1, 10), device=device)

    # The whole prompt should be consumed, chunk by chunk, before generating.
    token_ids = generate_text_simple_cachedKV2(
        model=kv_model,
        idx=input_tokens,
        max_new_tokens=2,
        use_cache=True,
    )

    # 10 prompt tokens + 2 generated tokens = 12 tokens in total.
    assert token_ids.shape[1] == 12, f"Expected 12 tokens, got {token_ids.shape[1]}"

    # The leading 10 tokens must be exactly the prompt.
    prompt_part = token_ids[:, :10]
    assert torch.equal(prompt_part, input_tokens), "Input tokens should be preserved"