mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Qwen3 KV cache (#688)
This commit is contained in:
committed by
GitHub
parent
2a530b49fe
commit
0b15a00574
@@ -189,7 +189,7 @@ def llama3_weights_path(tmp_path_factory):
|
||||
)
|
||||
@pytest.mark.parametrize("ModelClass", [Llama3Model, Llama3ModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_gpt_model_variants(ModelClass, generate_fn, llama3_weights_path):
|
||||
def test_model_variants(ModelClass, generate_fn, llama3_weights_path):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
|
||||
@@ -12,6 +12,9 @@ from llms_from_scratch.qwen3 import (
|
||||
Qwen3Model,
|
||||
Qwen3Tokenizer
|
||||
)
|
||||
from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
|
||||
from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
|
||||
|
||||
|
||||
import importlib
|
||||
import pytest
|
||||
@@ -110,8 +113,16 @@ def qwen3_weights_path(tmp_path_factory):
|
||||
return path
|
||||
|
||||
|
||||
@pytest.mark.parametrize("ModelClass", [Qwen3Model])
|
||||
def test_gpt_model_variants(ModelClass, qwen3_weights_path):
|
||||
@pytest.mark.parametrize("ModelClass", [Qwen3Model, Qwen3ModelKV])
|
||||
@pytest.mark.parametrize("generate_fn", [generate_text_simple, generate_text_simple_cached])
|
||||
def test_model_variants(ModelClass, qwen3_weights_path, generate_fn):
|
||||
|
||||
# Skip incompatible combinations
|
||||
if generate_fn is generate_text_simple and getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
if generate_fn is generate_text_simple_cached and not getattr(ModelClass, "reset_kv_cache", False):
|
||||
return
|
||||
|
||||
torch.manual_seed(123)
|
||||
model = ModelClass(QWEN_CONFIG_06_B)
|
||||
model.load_state_dict(torch.load(qwen3_weights_path))
|
||||
|
||||
Reference in New Issue
Block a user