mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Add GPT-2 KV cache to pkg (#687)
This commit is contained in:
committed by
GitHub
parent
3be0f3202a
commit
fdc3e1b701
@@ -4,7 +4,9 @@
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from llms_from_scratch.ch04 import GPTModel, GPTModelFast
|
||||
from llms_from_scratch.kv_cache.gpt2 import GPTModel as GPTModelKV
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -22,8 +24,16 @@ GPT_CONFIG_124M = {
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast])
|
||||
def test_gpt_model_variants(ModelClass):
|
||||
@pytest.mark.parametrize("ModelClass", [GPTModel, GPTModelFast, GPTModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_gpt_model_variants(ModelClass, generate_fn):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(GPT_CONFIG_124M)
|
||||
model.eval() # disable dropout
|
||||
@@ -39,7 +49,7 @@ def test_gpt_model_variants(ModelClass):
|
||||
print("Encoded input text:", encoded)
|
||||
print("encoded_tensor.shape:", encoded_tensor.shape)
|
||||
|
||||
out = generate_text_simple(
|
||||
out = generate_fn(
|
||||
model=model,
|
||||
idx=encoded_tensor,
|
||||
max_new_tokens=10,
|
||||
|
||||
Reference in New Issue
Block a user