mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00

Batched KV Cache Inference for Qwen3 (#735)

commit a354555049, parent b8c8237251, committed via GitHub
@@ -15,8 +15,8 @@ from llms_from_scratch.qwen3 import (
 from llms_from_scratch.kv_cache.qwen3 import Qwen3Model as Qwen3ModelKV
 from llms_from_scratch.kv_cache.generate import generate_text_simple as generate_text_simple_cached
 
-# from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
-# from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
+from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
+from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched
 
 import importlib
 import pytest
@@ -172,7 +172,7 @@ def test_model_KV_noKV(qwen3_weights_path):
     input_token_ids = tokenizer.encode(prompt)
     input_token_ids = torch.tensor([input_token_ids])
 
-    out_noKV = generate_text_simple_cached(
+    out_KV = generate_text_simple_cached(
         model=model_KV,
         idx=input_token_ids,
         max_new_tokens=5,
@@ -185,7 +185,7 @@ def test_model_KV_noKV(qwen3_weights_path):
     model_noKV.load_state_dict(torch.load(qwen3_weights_path))
     model_noKV.eval()
 
-    out_KV = generate_text_simple(
+    out_noKV = generate_text_simple(
         model=model_noKV,
         idx=input_token_ids,
         max_new_tokens=5,
@@ -195,6 +195,69 @@ def test_model_KV_noKV(qwen3_weights_path):
     assert torch.equal(out_noKV, out_KV)
 
 
+def test_model_batched_KV(qwen3_weights_path):
+
+    torch.manual_seed(123)
+    model_KV = Qwen3ModelKV(QWEN_CONFIG_06_B)
+    model_KV.load_state_dict(torch.load(qwen3_weights_path))
+    model_KV.eval()
+
+    tokenizer = Qwen3Tokenizer(
+        tokenizer_file_path="tokenizer-base.json",
+        repo_id="rasbt/qwen3-from-scratch",
+        add_generation_prompt=False,
+        add_thinking=False
+    )
+
+    # Batch size 1
+
+    prompt = "Give me a short introduction to large language models."
+    input_token_ids = tokenizer.encode(prompt)
+    input_token_ids = torch.tensor([input_token_ids])
+
+    out_KV = generate_text_simple_cached(
+        model=model_KV,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+    del model_KV
+
+    torch.manual_seed(123)
+    model_KV_batched = Qwen3ModelKVBatched(QWEN_CONFIG_06_B)
+    model_KV_batched.load_state_dict(torch.load(qwen3_weights_path))
+    model_KV_batched.eval()
+
+    out_KV_bs_1 = generate_text_simple_batched(
+        model=model_KV_batched,
+        idx=input_token_ids,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"]
+    )
+
+    assert torch.equal(out_KV, out_KV_bs_1)
+
+    # Batch size 2
+
+    prompts = [
+        "Give me a short introduction to large language models.",
+        "Give me a short introduction to large language models."
+    ]
+    tokenized_prompts = [tokenizer.encode(p) for p in prompts]
+    max_len = max(len(t) for t in tokenized_prompts)
+    padded_token_ids = [
+        t + [tokenizer.pad_token_id] * (max_len - len(t)) for t in tokenized_prompts
+    ]
+    input_tensor = torch.tensor(padded_token_ids)
+    out_KV_bs_2 = generate_text_simple_batched(
+        model=model_KV_batched,
+        idx=input_tensor,
+        max_new_tokens=5,
+        context_size=QWEN_CONFIG_06_B["context_length"],
+    )
+    assert torch.equal(out_KV.squeeze(0), out_KV_bs_2[0]), (out_KV.squeeze(0).shape, out_KV_bs_2[0].shape)
+
+
 def test_rmsnorm_equivalence():
     torch.manual_seed(42)
 
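
A minimal sketch of the padding-plus-batched-generation pattern that the new test_model_batched_KV test exercises, pulled out of the test harness for readability. The imports and call signatures mirror the diff above; the weights path ("qwen3-0.6B.pth") and the second prompt are illustrative placeholders, not part of the commit.

    import torch

    from llms_from_scratch.qwen3 import QWEN_CONFIG_06_B, Qwen3Tokenizer
    from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model as Qwen3ModelKVBatched
    from llms_from_scratch.kv_cache_batched.generate import generate_text_simple as generate_text_simple_batched

    tokenizer = Qwen3Tokenizer(
        tokenizer_file_path="tokenizer-base.json",
        repo_id="rasbt/qwen3-from-scratch",
        add_generation_prompt=False,
        add_thinking=False,
    )

    prompts = [
        "Give me a short introduction to large language models.",
        "What is a KV cache?",  # placeholder second prompt, unlike the test's identical pair
    ]

    # Right-pad each prompt with pad_token_id so the batch stacks into one tensor.
    tokenized = [tokenizer.encode(p) for p in prompts]
    max_len = max(len(t) for t in tokenized)
    padded = [t + [tokenizer.pad_token_id] * (max_len - len(t)) for t in tokenized]
    input_tensor = torch.tensor(padded)  # shape: (batch_size, max_len)

    model = Qwen3ModelKVBatched(QWEN_CONFIG_06_B)
    model.load_state_dict(torch.load("qwen3-0.6B.pth"))  # placeholder weights path
    model.eval()

    out = generate_text_simple_batched(
        model=model,
        idx=input_tensor,
        max_new_tokens=5,
        context_size=QWEN_CONFIG_06_B["context_length"],
    )

Note that the padding goes on the right, so with unequal-length prompts the shorter rows carry trailing pad tokens; whether their outputs match unbatched generation token for token depends on how the batched model handles padding, which may be why the test compares two identical prompts.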