mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Batched KV Cache Inference for Qwen3 (#735)
This commit is contained in:
committed by
GitHub
parent
b8c8237251
commit
a354555049
@@ -292,4 +292,72 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
|
||||
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 1.47 GB |
|
||||
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90 | 1.48 GB |
|
||||
|
||||
Note that all settings above have been tested to produce the same text outputs.
|
||||
Note that all settings above have been tested to produce the same text outputs.
|
||||
|
||||
|
||||
|
||||
#### Pro tip 3: batched inference
|
||||
|
||||
We can further increase the throughput via batched inference. While it's not an apples-to-apples comparison, as we are now running inference with a higher number of input sequences, this increases the tokens per second throughput while trading it off against increased memory usage.
|
||||
|
||||
This only requires a small code modification with respect to preparing the prompt. For example, consider this batched prompt below:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.ch04 import generate_text_simple
|
||||
from llms_from_scratch.qwen3 import Qwen3Model, QWEN_CONFIG_06_B
|
||||
# ...
|
||||
|
||||
prompts = [
|
||||
"Give me a short introduction to neural networks.",
|
||||
"Give me a short introduction to machine learning.",
|
||||
"Give me a short introduction to deep learning models.",
|
||||
"Give me a short introduction to natural language processing.",
|
||||
"Give me a short introduction to generative AI systems.",
|
||||
"Give me a short introduction to transformer architectures.",
|
||||
"Give me a short introduction to supervised learning methods.",
|
||||
"Give me a short introduction to unsupervised learning.",
|
||||
]
|
||||
|
||||
tokenized_prompts = [tokenizer.encode(p) for p in prompts]
|
||||
max_len = max(len(t) for t in tokenized_prompts)
|
||||
padded_token_ids = [
|
||||
t + [tokenizer.pad_token_id] * (max_len - len(t)) for t in tokenized_prompts
|
||||
]
|
||||
input_tensor = torch.tensor(padded_token_ids).to(device)
|
||||
|
||||
output_token_ids = generate_text_simple(
|
||||
model=model,
|
||||
idx=input_tensor,
|
||||
max_new_tokens=150,
|
||||
context_size=QWEN_CONFIG_06_B["context_length"],
|
||||
)
|
||||
```
|
||||
|
||||
The code for the KV cache version is similar, except that it requires using these drop-in replacements:
|
||||
|
||||
```python
|
||||
from llms_from_scratch.kv_cache_batched.generate import generate_text_simple
|
||||
from llms_from_scratch.kv_cache_batched.qwen3 import Qwen3Model
|
||||
```
|
||||
|
||||
|
||||
The experiments below are run with a batch size of 8.
|
||||
|
||||
| Model | Mode | Hardware | Batch size | Tokens/sec | GPU Memory (VRAM) |
|
||||
| ---------- | ----------------- | --------------- | ---------- | ---------- | ----------------- |
|
||||
| Qwen3Model | Regular | Mac Mini M4 CPU | 8 | 2 | - |
|
||||
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | 8 | - | - |
|
||||
| Qwen3Model | KV cache | Mac Mini M4 CPU | 8 | 92 | - |
|
||||
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 8 | 128 | - |
|
||||
| | | | | | |
|
||||
| Qwen3Model | Regular | Mac Mini M4 GPU | 8 | 36 | - |
|
||||
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | 8 | - | - |
|
||||
| Qwen3Model | KV cache | Mac Mini M4 GPU | 8 | 61 | - |
|
||||
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | 8 | - | - |
|
||||
| | | | | | |
|
||||
| Qwen3Model | Regular | Nvidia A100 GPU | 8 | 184 | 2.19 GB |
|
||||
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 8 | 351 | 2.19 GB |
|
||||
| Qwen3Model | KV cache | Nvidia A100 GPU | 8 | 140 | 3.13 GB |
|
||||
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 8 | 280 | 1.75 GB |
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user