Batched KV Cache Inference for Qwen3 (#735)

Sebastian Raschka
2025-07-10 08:09:35 -05:00
committed by GitHub
parent b8c8237251
commit a354555049
8 changed files with 506 additions and 6 deletions


@@ -292,4 +292,72 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 1.47 GB |
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90 | 1.48 GB |
Note that all settings above have been tested to produce the same text outputs.
 
#### Pro tip 3: batched inference
We can further increase throughput via batched inference. While this is not an apples-to-apples comparison, since we are now running inference on multiple input sequences at once, it increases the tokens-per-second throughput at the cost of higher memory usage.
This only requires a small code modification for preparing the prompts. For example, consider the batched prompts below:
```python
import torch

from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B

# ... (model, tokenizer, and device setup as before)

prompts = [
    "Give me a short introduction to neural networks.",
    "Give me a short introduction to machine learning.",
    "Give me a short introduction to deep learning models.",
    "Give me a short introduction to natural language processing.",
    "Give me a short introduction to generative AI systems.",
    "Give me a short introduction to transformer architectures.",
    "Give me a short introduction to supervised learning methods.",
    "Give me a short introduction to unsupervised learning.",
]

# Pad all prompts to the same length so they can be stacked into one batch tensor
tokenized_prompts = [tokenizer.encode(p) for p in prompts]
max_len = max(len(t) for t in tokenized_prompts)
padded_token_ids = [
    t + [tokenizer.pad_token_id] * (max_len - len(t)) for t in tokenized_prompts
]
input_tensor = torch.tensor(padded_token_ids).to(device)

output_token_ids = generate_text_simple(
    model=model,
    idx=input_tensor,
    max_new_tokens=150,
    context_size=QWEN_CONFIG_06_B["context_length"],
)
```
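Each row of `output_token_ids` corresponds to one prompt in the batch. As a minimal sketch (not part of this commit, and assuming the tokenizer from the setup above exposes a `decode` method), the continuation for each prompt can be recovered by slicing off the padded prompt portion and decoding the remaining token IDs:
```python
# Minimal sketch: decode each batch row individually
# (assumes the tokenizer provides a `decode` method for lists of token IDs)
for i, row in enumerate(output_token_ids.tolist()):
    generated_ids = row[max_len:]  # drop the (padded) prompt portion
    print(f"Prompt {i}: {tokenizer.decode(generated_ids)}")
```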
The code for the KV cache version is similar, except that it requires using these drop-in replacements:
```python
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model
```
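To put the numbers in the table below into context, throughput can be estimated by timing the generation call. The following is a minimal sketch, not part of this commit, of one way to measure it: it counts all newly generated tokens across the batch and divides by the elapsed wall-clock time (the exact measurement procedure used for the table is not shown here).
```python
import time

# Synchronize before and after timing so pending GPU work does not skew the result
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.time()

output_token_ids = generate_text_simple(
    model=model,
    idx=input_tensor,
    max_new_tokens=150,
    context_size=QWEN_CONFIG_06_B["context_length"],
)

if torch.cuda.is_available():
    torch.cuda.synchronize()
elapsed = time.time() - start

# Newly generated tokens = (total length - prompt length), summed over the batch
batch_size, total_len = output_token_ids.shape
num_new_tokens = batch_size * (total_len - input_tensor.shape[1])
print(f"{num_new_tokens / elapsed:.1f} tokens/sec")
```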
The experiments below are run with a batch size of 8.
| Model | Mode | Hardware | Batch size | Tokens/sec | GPU Memory (VRAM) |
| ---------- | ----------------- | --------------- | ---------- | ---------- | ----------------- |
| Qwen3Model | Regular | Mac Mini M4 CPU | 8 | 2 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | 8 | - | - |
| Qwen3Model | KV cache | Mac Mini M4 CPU | 8 | 92 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 8 | 128 | - |
| | | | | | |
| Qwen3Model | Regular | Mac Mini M4 GPU | 8 | 36 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | 8 | - | - |
| Qwen3Model | KV cache | Mac Mini M4 GPU | 8 | 61 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | 8 | - | - |
| | | | | | |
| Qwen3Model | Regular | Nvidia A100 GPU | 8 | 184 | 2.19 GB |
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 8 | 351 | 2.19 GB |
| Qwen3Model | KV cache | Nvidia A100 GPU | 8 | 140 | 3.13 GB |
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 8 | 280 | 1.75 GB |