Add GPT-2 KV cache to pkg (#687)

This commit is contained in:
Sebastian Raschka
2025-06-21 12:29:04 -05:00
committed by GitHub
parent 3be0f3202a
commit fdc3e1b701
4 changed files with 315 additions and 5 deletions

View File

@@ -4,7 +4,9 @@
# Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch04 import GPTModel, GPTModelFast
from llms_from_scratch.kv_cache.gpt2 import GPTModel as GPTModelKV
from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
import pytest
import torch
@@ -22,8 +24,16 @@ GPT_CONFIG_124M = {
}
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
def test_gpt_model_variants(ModelClass):
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast, GPTModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_gpt_model_variants(ModelClass, generate_fn):
# Skip incompatible combinations
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
return
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
return
torch.manual_seed(123)
model = ModelClass(GPT_CONFIG_124M)
model.eval() # disable dropout
@@ -39,7 +49,7 @@ def test_gpt_model_variants(ModelClass):
print("Encoded input text:", encoded)
print("encoded_tensor.shape:", encoded_tensor.shape)
out = generate_text_simple(
out = generate_fn(
model=model,
idx=encoded_tensor,
max_new_tokens=10,