Qwen3 KV cache (#688)

This commit is contained in:
Sebastian Raschka
2025-06-21 17:34:39 -05:00
committed by GitHub
parent 2a530b49fe
commit 0b15a00574
8 changed files with 370 additions and 11 deletions

View File

@@ -189,7 +189,7 @@ def llama3_weights_path(tmp_path_factory):
)
@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
def test_model_variants(ModelClass, generate_fn, llama3_weights_path):
# Skip incompatible combinations
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):

View File

@@ -12,6 +12,9 @@ from llms_from_scratch.qwen3 import (
Qwen3Model,
Qwen3Tokenizer
)
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
import importlib
import pytest
@@ -110,8 +113,16 @@ def qwen3_weights_path(tmp_path_factory):
return path
@pytest.mark.parametrize("ModelClass", [Qwen3Model])
def test_gpt_model_variants(ModelClass, qwen3_weights_path):
@pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
# Skip incompatible combinations
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
return
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
return
torch.manual_seed(123)
model = ModelClass(QWEN_CONFIG_06_B)
model.load_state_dict(torch.load(qwen3_weights_path))