Improve KV cache code for torch.compile (#705)

* Improve KV cache code for torch.compile

* cleanup

* cleanup
This commit is contained in:
Sebastian Raschka
2025-06-23 18:08:49 -05:00
committed by GitHub
parent 6522be94be
commit 81eda38d3b
8 changed files with 593 additions and 315 deletions

View File

@@ -27,7 +27,7 @@ class MultiHeadAttention(nn.Module):
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.register_buffer( self.register_buffer(
"mask", "mask",
torch.triu(torch.ones(context_length, context_length),diagonal=1), torch.triu(torch.ones(context_length, context_length), diagonal=1),
persistent=False persistent=False
) )

View File

@@ -236,14 +236,14 @@ token_ids = generate_text_simple(
) )
``` ```
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements). Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements; and KV-cache memory may increase prohibitively for longer context lengths).
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) | | Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
|-------------|-------------------|-----------------|------------|-------------------| | ----------- | ----------------- | --------------- | ---------- | ----------------- |
| Llama3Model | Regular | Mac Mini M4 CPU | 1 | - | | Llama3Model | Regular | Mac Mini M4 CPU | 1 | - |
| Llama3Model | Regular compiled | Mac Mini M4 CPU | - | - | | Llama3Model | Regular compiled | Mac Mini M4 CPU | - | - |
| Llama3Model | KV cache | Mac Mini M4 CPU | 62 | - | | Llama3Model | KV cache | Mac Mini M4 CPU | 68 | - |
| Llama3Model | KV cache compiled | Mac Mini M4 CPU | - | - | | Llama3Model | KV cache compiled | Mac Mini M4 CPU | 86 | - |
| | | | | | | | | | | |
| Llama3Model | Regular | Mac Mini M4 GPU | 15 | - | | Llama3Model | Regular | Mac Mini M4 GPU | 15 | - |
| Llama3Model | Regular compiled | Mac Mini M4 GPU | - | - | | Llama3Model | Regular compiled | Mac Mini M4 GPU | - | - |
@@ -252,7 +252,7 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
| | | | | | | | | | | |
| Llama3Model | Regular | Nvidia A100 GPU | 42 | 2.91 GB | | Llama3Model | Regular | Nvidia A100 GPU | 42 | 2.91 GB |
| Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB | | Llama3Model | Regular compiled | Nvidia A100 GPU | 170 | 3.12 GB |
| Llama3Model | KV cache | Nvidia A100 GPU | 60 | 18.87 GB | | Llama3Model | KV cache | Nvidia A100 GPU | 58 | 2.87 GB |
| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59 | 19.12 GB | | Llama3Model | KV cache compiled | Nvidia A100 GPU | 161 | 3.61 GB |
Note that all settings above have been tested to produce the same text outputs. Note that all settings above have been tested to produce the same text outputs.

View File

@@ -209,23 +209,23 @@ token_ids = generate_text_simple(
) )
``` ```
Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements). Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements; and KV-cache memory may increase prohibitively for longer context lengths).
| Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) | | Model | Mode | Hardware | Tokens/sec | GPU Memory (VRAM) |
| ---------- | ----------------- | --------------- | ---------- | ----------------- | | ---------- | ----------------- | --------------- | ---------- | ----------------- |
| Qwen3Model | Regular | Mac Mini M4 CPU | 1 | - | | Qwen3Model | Regular | Mac Mini M4 CPU | 1 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 CPU | 1 | - | | Qwen3Model | Regular compiled | Mac Mini M4 CPU | 1 | - |
| Qwen3Model | KV cache | Mac Mini M4 CPU | 80 | - | | Qwen3Model | KV cache | Mac Mini M4 CPU | 80 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 82 | - | | Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 137 | - |
| | | | | | | | | | | |
| Qwen3Model | Regular | Mac Mini M4 GPU | 21 | - | | Qwen3Model | Regular | Mac Mini M4 GPU | 21 | - |
| Qwen3Model | Regular compiled | Mac Mini M4 GPU | Error | - | | Qwen3Model | Regular compiled | Mac Mini M4 GPU | Error | - |
| Qwen3Model | KV cache | Mac Mini M4 GPU | 32 | - | | Qwen3Model | KV cache | Mac Mini M4 GPU | 28 | - |
| Qwen3Model | KV cache compiled | Mac Mini M4 GPU | Error | - | | Qwen3Model | KV cache compiled | Mac Mini M4 GPU | Error | - |
| | | | | | | | | | | |
| Qwen3Model | Regular | Nvidia A100 GPU | 25 | 1.49 GB | | Qwen3Model | Regular | Nvidia A100 GPU | 26 | 1.49 GB |
| Qwen3Model | Regular compiled | Nvidia A100 GPU | 107 | 1.99 GB | | Qwen3Model | Regular compiled | Nvidia A100 GPU | 107 | 1.99 GB |
| Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 10.20 GB | | Qwen3Model | KV cache | Nvidia A100 GPU | 25 | 1.47 GB |
| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 24 | 10.61 GB | | Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90 | 1.48 GB |
Note that all settings above have been tested to produce the same text outputs. Note that all settings above have been tested to produce the same text outputs.

View File

@@ -3,23 +3,24 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch # - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from .utils import KVCache
import torch import torch
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True): def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
model.eval() model.eval()
ctx_len = context_size or model.cfg["context_length"] ctx_len = context_size or model.cfg["context_length"]
cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None
with torch.no_grad(): with torch.no_grad():
if use_cache: if use_cache:
model.reset_kv_cache() model.reset_kv_cache()
logits = model(idx[:, -ctx_len:], use_cache=True) logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
for _ in range(max_new_tokens): for _ in range(max_new_tokens):
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True) next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1) idx = torch.cat([idx, next_idx], dim=1)
logits = model(next_idx, use_cache=True) logits = model(next_idx, use_cache=True, cache=cache)
else: else:
for _ in range(max_new_tokens): for _ in range(max_new_tokens):
logits = model(idx[:, -ctx_len:], use_cache=False) logits = model(idx[:, -ctx_len:], use_cache=False)

View File

@@ -3,6 +3,8 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch # - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from .utils import KVCache # noqa: F401
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -11,7 +13,7 @@ import torch.nn as nn
# Chapter 3 # Chapter 3
##################################### #####################################
class MultiHeadAttention(nn.Module): class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None): def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__() super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads" assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@@ -25,80 +27,41 @@ class MultiHeadAttention(nn.Module):
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
#################################################### def forward(self, x, use_cache=False, start_pos=0, cache=None):
# NEW
self.max_seq_len = max_seq_len or context_length
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
####################################################
def forward(self, x, use_cache=False):
b, num_tokens, d_in = x.shape b, num_tokens, d_in = x.shape
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out) keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
values_new = self.W_value(x) values = self.W_value(x)
queries = self.W_query(x) queries = self.W_query(x)
# We implicitly split the matrix by adding a `num_heads` dimension # We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim) # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim) keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim) values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim) queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim) # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.transpose(1, 2) keys = keys.transpose(1, 2)
values_new = values_new.transpose(1, 2)
queries = queries.transpose(1, 2) queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
####################################################
# NEW
if use_cache: if use_cache:
if self.cache_k is None or self.cache_k.size(0) != b: if cache is not None:
self.cache_k = torch.zeros(b, self.num_heads, keys = torch.cat([cache[0], keys], dim=2)
self.window_size, self.head_dim, values = torch.cat([cache[1], values], dim=2)
device=x.device) next_cache = (keys, values)
self.cache_v = torch.zeros_like(self.cache_k)
self.ptr_cur = 0 # pointer to next free slot
# if incoming chunk would overflow discard oldest tokens
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
# shift everything left by `overflow` (cheap view-copy)
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow # pointer after shift
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
self.ptr_cur += num_tokens
keys = self.cache_k[:, :, :self.ptr_cur, :]
values = self.cache_v[:, :, :self.ptr_cur, :]
else: else:
keys, values = keys_new, values_new next_cache = None
self.ptr_cur = 0 # keep pointer sane if you interleave modes
#################################################### seq_len = keys.size(2)
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
causal_mask = causal_mask[-num_tokens:, :][None, None, :, :]
# Compute scaled dot-product attention (aka self-attention) with a causal mask # Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
####################################################
# NEW
K = attn_scores.size(-1)
if num_tokens == K:
# No cache → use the prebaked triangular mask slice
causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
else:
# Cached: need to offset the diagonal by (K num_tokens)
offset = K - num_tokens # number of tokens already in cache before this chunk
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
causal_mask = row_idx + offset < col_idx # True where j > i+offset
####################################################
# Use the mask to fill attention scores # Use the mask to fill attention scores
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf) attn_scores.masked_fill_(causal_mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights) attn_weights = self.dropout(attn_weights)
@@ -110,13 +73,7 @@ class MultiHeadAttention(nn.Module):
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection context_vec = self.out_proj(context_vec) # optional projection
return context_vec return context_vec, next_cache
####################################################
# NEW
def reset_cache(self):
self.cache_k, self.cache_v = None, None
####################################################
##################################### #####################################
@@ -169,25 +126,17 @@ class TransformerBlock(nn.Module):
context_length=cfg["context_length"], context_length=cfg["context_length"],
num_heads=cfg["n_heads"], num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"], dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"], qkv_bias=cfg["qkv_bias"])
window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"] # NEW
)
self.ff = FeedForward(cfg) self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"]) self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"]) self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x, use_cache=False): def forward(self, x, use_cache=False, start_pos=0, cache=None):
# Shortcut connection for attention block # Shortcut connection for attention block
shortcut = x shortcut = x
x = self.norm1(x) x = self.norm1(x)
x, next_cache = self.att(x, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
####################################################
# NEW
x = self.att(x, use_cache=use_cache)
####################################################
x = self.drop_shortcut(x) x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back x = x + shortcut # Add the original input back
@@ -198,7 +147,7 @@ class TransformerBlock(nn.Module):
x = self.drop_shortcut(x) x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back x = x + shortcut # Add the original input back
return x return x, next_cache
class GPTModel(nn.Module): class GPTModel(nn.Module):
@@ -208,80 +157,34 @@ class GPTModel(nn.Module):
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"]) self.drop_emb = nn.Dropout(cfg["drop_rate"])
# self.trf_blocks = nn.Sequential( self.trf_blocks = nn.Sequential(
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
####################################################
# NEW
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.ptr_current_pos = 0
####################################################
self.final_norm = LayerNorm(cfg["emb_dim"]) self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
self.current_pos = 0
def forward(self, in_idx, use_cache=False): def forward(self, in_idx, use_cache=False, cache=None):
batch_size, seq_len = in_idx.shape batch_size, seq_len = in_idx.shape
pos = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device)
tok_embeds = self.tok_emb(in_idx) tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(pos)
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) x = self.drop_emb(tok_embeds + pos_embeds)
####################################################
# NEW
if use_cache: if use_cache:
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long) start_pos = self.current_pos
self.ptr_current_pos += seq_len self.current_pos += seq_len
else: else:
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long) start_pos = 0
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
####################################################
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size] next_cache = []
x = self.drop_emb(x) for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
# x = self.trf_blocks(x) x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
#################################################### if cache:
# NEW cache.update(i, new_cache)
for blk in self.trf_blocks: next_cache.append(new_cache)
x = blk(x, use_cache=use_cache)
####################################################
x = self.final_norm(x) x = self.final_norm(x)
logits = self.out_head(x) logits = self.out_head(x)
return logits return logits
####################################################
# NEW
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
####################################################
def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]
# Get the predictions
with torch.no_grad():
logits = model(idx_cond)
# Focus only on the last time step
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]
# Get the idx of the vocab entry with the highest logits value
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx

View File

@@ -3,15 +3,20 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch # - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from ..llama3 import Llama3Tokenizer, ChatFormat, clean_text # noqa: F401 from .utils import KVCache # noqa: F401
import os
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
import tiktoken
from tiktoken.load import load_tiktoken_bpe
LLAMA32_CONFIG_1B = { LLAMA32_CONFIG_1B = {
"vocab_size": 128_256, # Vocabulary size "vocab_size": 128_256, # Vocabulary size
"context_length": 131_072, # Context length that was used to train the model "context_length": 131_072, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 2048, # Embedding dimension "emb_dim": 2048, # Embedding dimension
"n_heads": 32, # Number of attention heads "n_heads": 32, # Number of attention heads
"n_layers": 16, # Number of layers "n_layers": 16, # Number of layers
@@ -30,7 +35,6 @@ LLAMA32_CONFIG_1B = {
LLAMA32_CONFIG_3B = { LLAMA32_CONFIG_3B = {
"vocab_size": 128_256, # Vocabulary size "vocab_size": 128_256, # Vocabulary size
"context_length": 131_072, # Context length that was used to train the model "context_length": 131_072, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 3072, # Embedding dimension "emb_dim": 3072, # Embedding dimension
"n_heads": 24, # Number of attention heads "n_heads": 24, # Number of attention heads
"n_layers": 28, # Number of layers "n_layers": 28, # Number of layers
@@ -71,21 +75,45 @@ class Llama3Model(nn.Module):
self.register_buffer("cos", cos, persistent=False) self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False) self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg self.cfg = cfg
self.current_pos = 0 # Track current position in KV cache
def forward(self, in_idx, use_cache=False): def forward(self, in_idx, use_cache=False, cache=None):
tok_embeds = self.tok_emb(in_idx) tok_embeds = self.tok_emb(in_idx)
x = tok_embeds x = tok_embeds
for block in self.trf_blocks: num_tokens = x.shape[1]
x = block(x, self.cos, self.sin, use_cache) if use_cache:
pos_start = self.current_pos
pos_end = pos_start + num_tokens
self.current_pos = pos_end
mask = torch.triu(
torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
)[pos_start:pos_end, :pos_end]
else:
pos_start = 0 # Not strictly necessary but helps torch.compile
mask = torch.triu(
torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
)
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
use_cache=use_cache,
start_pos=pos_start,
cache=blk_cache)
if cache:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x) x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"])) logits = self.out_head(x.to(self.cfg["dtype"]))
return logits return logits
def reset_kv_cache(self): def reset_kv_cache(self):
for blk in self.trf_blocks: self.current_pos = 0
blk.att.reset_cache()
self.ptr_current_pos = 0
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
@@ -96,18 +124,17 @@ class TransformerBlock(nn.Module):
d_out=cfg["emb_dim"], d_out=cfg["emb_dim"],
num_heads=cfg["n_heads"], num_heads=cfg["n_heads"],
num_kv_groups=cfg["n_kv_groups"], num_kv_groups=cfg["n_kv_groups"],
max_seq_len=cfg["context_length"],
dtype=cfg["dtype"] dtype=cfg["dtype"]
) )
self.ff = FeedForward(cfg) self.ff = FeedForward(cfg)
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
def forward(self, x, cos, sin, use_cache=False): def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
# Shortcut connection for attention block # Shortcut connection for attention block
shortcut = x shortcut = x
x = self.norm1(x) x = self.norm1(x)
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size] x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block # Shortcut connection for feed-forward block
@@ -116,7 +143,7 @@ class TransformerBlock(nn.Module):
x = self.ff(x) x = self.ff(x)
x = x + shortcut # Add the original input back x = x + shortcut # Add the original input back
return x return x, next_cache
class FeedForward(nn.Module): class FeedForward(nn.Module):
@@ -135,7 +162,7 @@ class FeedForward(nn.Module):
class GroupedQueryAttention(nn.Module): class GroupedQueryAttention(nn.Module):
def __init__( def __init__(
self, d_in, d_out, num_heads, num_kv_groups, dtype=None, max_seq_len=None, window_size=None self, d_in, d_out, num_heads, num_kv_groups, dtype=None
): ):
super().__init__() super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads" assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@@ -153,45 +180,40 @@ class GroupedQueryAttention(nn.Module):
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype) self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype) self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)
# For optional KV cache def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
self.max_seq_len = max_seq_len b, num_tokens, _ = x.shape
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.cache_initialized = False
self.ptr = 0
def forward(self, x, cos, sin, use_cache=False): # Apply projections
b, num_tokens, d_in = x.shape queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
queries = self.W_query(x) # Shape: (b, num_tokens, d_out) # Reshape
keys_new = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim) queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
values_new = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim) keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
# Reshape queries, keys, and values
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim)
# Transpose keys, values, and queries
queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
values_new = values_new.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)
# For KV cache
pos_start = self.ptr
pos_end = pos_start + num_tokens
cos_slice = cos[pos_start:pos_end]
sin_slice = sin[pos_start:pos_end]
# Apply RoPE # Apply RoPE
keys_new = apply_rope(keys_new, cos_slice, sin_slice) queries = apply_rope(queries, cos, sin, offset=start_pos)
queries = apply_rope(queries, cos_slice, sin_slice) keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
if use_cache:
if cache is None:
keys = keys_new
values = values_new
else:
prev_k, prev_v = cache
keys = torch.cat([prev_k, keys_new], dim=2)
values = torch.cat([prev_v, values_new], dim=2)
next_cache = (keys, values)
else:
keys, values = keys_new, values_new
next_cache = None
# Expand keys and values to match the number of heads # Expand keys and values to match the number of heads
# Shape: (b, num_heads, num_tokens, head_dim) # Shape: (b, num_heads, num_tokens, head_dim)
keys_new = keys_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim) keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
values_new = values_new.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim) values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)
# For example, before repeat_interleave along dim=1 (query groups): # For example, before repeat_interleave along dim=1 (query groups):
# [K1, K2] # [K1, K2]
# After repeat_interleave (each query group is repeated group_size times): # After repeat_interleave (each query group is repeated group_size times):
@@ -199,38 +221,12 @@ class GroupedQueryAttention(nn.Module):
# If we used regular repeat instead of repeat_interleave, we'd get: # If we used regular repeat instead of repeat_interleave, we'd get:
# [K1, K2, K1, K2] # [K1, K2, K1, K2]
if use_cache:
if not self.cache_initialized:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype)
self.ptr = 0
self.cache_initialized = True
# In-place update
end = self.ptr + num_tokens
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
self.cache_v[:, :, self.ptr:end].copy_(values_new)
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
values = self.cache_v[:, :, max(0, end - self.window_size):end]
self.ptr = end
else:
keys, values = keys_new, values_new
# Compute scaled dot-product attention (aka self-attention) with a causal mask # Compute scaled dot-product attention (aka self-attention) with a causal mask
# Shape: (b, num_heads, num_tokens, num_tokens) # Shape: (b, num_heads, num_tokens, num_tokens)
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Create causal mask to fill attention scores # Use the mask to fill attention scores
T_q = queries.shape[-2] attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)
T_k = keys.shape[-2]
if not use_cache or T_q > 1:
causal_mask = torch.triu(
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1) attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
assert keys.shape[-1] == self.head_dim assert keys.shape[-1] == self.head_dim
@@ -242,13 +238,7 @@ class GroupedQueryAttention(nn.Module):
context_vec = context_vec.reshape(b, num_tokens, self.d_out) context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection context_vec = self.out_proj(context_vec) # optional projection
return context_vec return context_vec, next_cache
def reset_cache(self):
if self.cache_k is not None:
self.cache_k.zero_()
self.cache_v.zero_()
self.ptr = 0
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32): def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):
@@ -296,7 +286,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_c
return cos, sin return cos, sin
def apply_rope(x, cos, sin, offset=0):
    """Apply rotary position embeddings (RoPE) to ``x``.

    Args:
        x: Tensor of shape (batch_size, num_heads, seq_len, head_dim);
            head_dim must be even.
        cos: Cosine table indexed by absolute position along dim 0; rows
            ``offset : offset + seq_len`` are used.
        sin: Sine table laid out like ``cos``.
        offset: Absolute position of the first token in ``x``. Defaults to 0
            so that callers without a KV cache rotate from position 0; pass
            the number of already-cached tokens when decoding incrementally.

    Returns:
        The rotated tensor, with the same shape and dtype as ``x``.
    """
    # x: (batch_size, num_heads, seq_len, head_dim)
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]  # First half
    x2 = x[..., head_dim // 2:]  # Second half

    # Adjust sin and cos shapes
    cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)
    sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    # It's ok to use lower-precision after applying cos and sin rotation
    return x_rotated.to(dtype=x.dtype)
##########################################
# Tokenizer
##########################################
class Llama3Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.

    Attributes:
        special: Mapping from special-token string to token ID.
        model: The underlying ``tiktoken.Encoding`` instance.
    """

    def __init__(self, model_path):
        """Load the BPE ranks from ``model_path`` and build the encoding.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file.
        """
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # hard-coded from Meta's tokenizer.json
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        # Fill the remaining IDs in 128002..128257 with reserved placeholders,
        # skipping IDs that are already assigned to a named special token.
        taken_ids = set(self.special.values())
        self.special.update({
            f"<|reserved_{i}|>": 128002 + i
            for i in range(256)
            if 128002 + i not in taken_ids
        })

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False, **kwargs):
        """Encode ``text`` into token IDs.

        Args:
            text: The string to tokenize.
            bos: If True, prepend the ``<|begin_of_text|>`` ID.
            eos: If True, append the ``<|end_of_text|>`` ID.
            **kwargs: ``allowed_special`` and ``disallowed_special`` are
                forwarded to the underlying encoder when given and not None.
                Any other keyword is accepted and ignored, so existing
                callers keep working.

        Returns:
            A list of token IDs.
        """
        # Only forward the special-token controls; a None value means
        # "use the encoder's default" and must not be passed through.
        forwarded = {
            name: kwargs[name]
            for name in ("allowed_special", "disallowed_special")
            if kwargs.get(name) is not None
        }

        ids = [self.special["<|begin_of_text|>"]] if bos else []
        ids.extend(self.model.encode(text, **forwarded))
        if eos:
            ids.append(self.special["<|end_of_text|>"])
        return ids

    def decode(self, ids):
        """Decode a sequence of token IDs back into a string."""
        return self.model.decode(ids)
class ChatFormat:
    """Build Llama-3 chat prompts (system + user + open assistant turn) as token IDs."""

    def __init__(self, tokenizer: Llama3Tokenizer, *,
                 default_system="You are a helpful assistant."):
        # Tokenizer used for all text encoding and special-token lookups
        self.tok = tokenizer
        # System prompt used when encode() is not given one
        self.default_system = default_system

    def _header(self, role):
        """Return the IDs for <|start_header_id|>, the role text, <|end_header_id|> and a blank line."""
        special = self.tok.special
        return [
            special["<|start_header_id|>"],
            *self.tok.encode(role),
            special["<|end_header_id|>"],
            *self.tok.encode("\n\n"),
        ]

    def encode(self, user_message, system_message=None, allowed_special=None):
        """Encode one system message and one user message, ending with an open assistant header.

        Args:
            user_message: Text of the user turn.
            system_message: Text of the system turn; ``default_system`` is
                used when this is None.
            allowed_special: Passed to the tokenizer when encoding the
                system message only.

        Returns:
            A list of token IDs.
        """
        if system_message is None:
            system_text = self.default_system
        else:
            system_text = system_message

        ids = [self.tok.special["<|begin_of_text|>"]]

        # System turn
        ids.extend(self._header("system"))
        ids.extend(self.tok.encode(system_text, allowed_special=allowed_special))
        ids.append(self.tok.special["<|eot_id|>"])

        # User turn
        ids.extend(self._header("user"))
        ids.extend(self.tok.encode(user_message))
        ids.append(self.tok.special["<|eot_id|>"])

        # Assistant header only; the model generates the content
        ids.extend(self._header("assistant"))
        return ids

    def decode(self, ids):
        """Decode token IDs with the wrapped tokenizer."""
        return self.tok.decode(ids)
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    """Return the part of ``text`` that follows the first ``header_end`` marker.

    The returned tail has leading and trailing whitespace stripped. If the
    marker does not occur, ``text`` is returned unchanged (not stripped).
    """
    marker_pos = text.find(header_end)
    if marker_pos == -1:
        # Marker absent: hand back the input as-is
        return text

    tail_start = marker_pos + len(header_end)
    return text[tail_start:].strip()
######################################################################
# Llama 3 fast (alternative code geared towards efficiency)
######################################################################
class GroupedQueryAttentionFast(nn.Module):
    """
    Grouped-query attention built on PyTorch's
    ``scaled_dot_product_attention``, which may dispatch to fused kernels
    (e.g. FlashAttention) depending on hardware and dtype.

    This variant has no KV cache: every forward pass treats the input as a
    full sequence that starts at position 0.
    """

    def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
        """
        Args:
            d_in: Input feature size.
            d_out: Total output feature size; must be divisible by num_heads.
            num_heads: Number of query heads; must be divisible by num_kv_groups.
            num_kv_groups: Number of key/value heads shared across query heads.
            dtype: Optional dtype for the projection weights.
        """
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.num_kv_groups = num_kv_groups
        # Number of query heads that share one key/value head
        self.group_size = num_heads // num_kv_groups

        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, cos, sin):
        """Run causal self-attention over ``x`` of shape (b, num_tokens, d_in).

        Returns a tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        # Project to queries, keys, values: (b, heads, num_tokens, head_dim)
        q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Apply Rotary Positional Embedding. The offset is passed explicitly:
        # without a cache the sequence always starts at position 0, and this
        # must not depend on apply_rope's default.
        q = apply_rope(q, cos, sin, offset=0)
        k = apply_rope(k, cos, sin, offset=0)

        # Expand key/value groups to full head count
        k = k.repeat_interleave(self.group_size, dim=1)
        v = v.repeat_interleave(self.group_size, dim=1)

        # Efficient scaled dot-product attention
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v,
            is_causal=True  # Enables Flash/FlexAttention kernels
        )

        # Combine heads and project
        attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(attn_output)
class TransformerBlockFast(nn.Module):
    """
    Pre-norm transformer block: RMSNorm -> GroupedQueryAttentionFast -> residual,
    then RMSNorm -> FeedForward -> residual.
    """

    def __init__(self, cfg):
        """Build the block from ``cfg`` (reads emb_dim, n_heads, n_kv_groups, dtype here)."""
        super().__init__()
        emb_dim = cfg["emb_dim"]
        dtype = cfg["dtype"]

        self.att = GroupedQueryAttentionFast(
            d_in=emb_dim,
            d_out=emb_dim,
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            dtype=dtype
        )
        self.ff = FeedForward(cfg)
        self.norm1 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)
        self.norm2 = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)

    def forward(self, x, cos, sin):
        """Apply the attention and feed-forward sub-layers, each with a residual connection."""
        # Attention sub-layer; output shape [batch_size, num_tokens, emb_size]
        attn_out = self.att(self.norm1(x), cos, sin)
        x = attn_out + x  # Add the block input back

        # Feed-forward sub-layer
        ff_out = self.ff(self.norm2(x))
        return ff_out + x  # Add the attention output back
class Llama3ModelFast(nn.Module):
    """
    Llama 3 decoder stack built from TransformerBlockFast layers:
    token embedding -> n_layers blocks -> final RMSNorm -> output head.
    """

    def __init__(self, cfg):
        """Build the model from ``cfg``.

        Keys read here: vocab_size, emb_dim, dtype, n_layers, n_heads,
        rope_base, context_length, rope_freq.
        """
        super().__init__()
        emb_dim = cfg["emb_dim"]
        dtype = cfg["dtype"]

        # Main model parameters
        self.tok_emb = nn.Embedding(cfg["vocab_size"], emb_dim, dtype=dtype)

        # ModuleList rather than Sequential because each block takes `x, cos, sin`
        blocks = [TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.ModuleList(blocks)

        self.final_norm = nn.RMSNorm(emb_dim, eps=1e-5, dtype=dtype)
        self.out_head = nn.Linear(emb_dim, cfg["vocab_size"], bias=False, dtype=dtype)

        # RoPE tables are kept as non-persistent buffers, so they are not
        # written to the state dict
        rope_cos, rope_sin = compute_rope_params(
            head_dim=emb_dim // cfg["n_heads"],
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"],
            freq_config=cfg["rope_freq"]
        )
        self.register_buffer("cos", rope_cos, persistent=False)
        self.register_buffer("sin", rope_sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx):
        """Map token IDs ``in_idx`` to output-head logits."""
        hidden = self.tok_emb(in_idx)

        for layer in self.trf_blocks:
            hidden = layer(hidden, self.cos, self.sin)

        hidden = self.final_norm(hidden)
        return self.out_head(hidden.to(self.cfg["dtype"]))

View File

@@ -3,7 +3,11 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch # - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from ..qwen3 import Qwen3Tokenizer, download_from_huggingface, load_weights_into_qwen # noqa: F401 from .utils import KVCache # noqa: F401
import os
import urllib.request
from pathlib import Path
import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -12,7 +16,6 @@ import torch.nn as nn
QWEN_CONFIG_06_B = { QWEN_CONFIG_06_B = {
"vocab_size": 151_936, # Vocabulary size "vocab_size": 151_936, # Vocabulary size
"context_length": 40_960, # Context length that was used to train the model "context_length": 40_960, # Context length that was used to train the model
"window_size": None, # Window size for the KV cache; context_length if None
"emb_dim": 1024, # Embedding dimension "emb_dim": 1024, # Embedding dimension
"n_heads": 16, # Number of attention heads "n_heads": 16, # Number of attention heads
"n_layers": 28, # Number of layers "n_layers": 28, # Number of layers
@@ -51,22 +54,46 @@ class Qwen3Model(nn.Module):
self.register_buffer("cos", cos, persistent=False) self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False) self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg self.cfg = cfg
self.current_pos = 0 # Track current position in KV cache
def forward(self, in_idx, use_cache=False): def forward(self, in_idx, use_cache=False, cache=None):
# Forward pass # Forward pass
tok_embeds = self.tok_emb(in_idx) tok_embeds = self.tok_emb(in_idx)
x = tok_embeds x = tok_embeds
for block in self.trf_blocks: num_tokens = x.shape[1]
x = block(x, self.cos, self.sin, use_cache) if use_cache:
pos_start = self.current_pos
pos_end = pos_start + num_tokens
self.current_pos = pos_end
mask = torch.triu(
torch.ones(pos_end, pos_end, device=x.device, dtype=torch.bool), diagonal=1
)[pos_start:pos_end, :pos_end]
else:
pos_start = 0 # Not strictly necessary but helps torch.compile
mask = torch.triu(
torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1
)
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
use_cache=use_cache,
start_pos=pos_start,
cache=blk_cache)
if cache:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x) x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"])) logits = self.out_head(x.to(self.cfg["dtype"]))
return logits return logits
def reset_kv_cache(self): def reset_kv_cache(self):
for blk in self.trf_blocks: self.current_pos = 0
blk.att.reset_cache()
self.ptr_current_pos = 0
class TransformerBlock(nn.Module): class TransformerBlock(nn.Module):
@@ -78,18 +105,17 @@ class TransformerBlock(nn.Module):
head_dim=cfg["head_dim"], head_dim=cfg["head_dim"],
num_kv_groups=cfg["n_kv_groups"], num_kv_groups=cfg["n_kv_groups"],
qk_norm=cfg["qk_norm"], qk_norm=cfg["qk_norm"],
max_seq_len=cfg["context_length"],
dtype=cfg["dtype"] dtype=cfg["dtype"]
) )
self.ff = FeedForward(cfg) self.ff = FeedForward(cfg)
self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6) self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6) self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)
def forward(self, x, cos, sin, use_cache=False): def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
# Shortcut connection for attention block # Shortcut connection for attention block
shortcut = x shortcut = x
x = self.norm1(x) x = self.norm1(x)
x = self.att(x, cos, sin, use_cache) # Shape [batch_size, num_tokens, emb_size] x, next_cache = self.att(x, mask, cos, sin, use_cache=use_cache, start_pos=start_pos, cache=cache) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block # Shortcut connection for feed-forward block
@@ -98,7 +124,7 @@ class TransformerBlock(nn.Module):
x = self.ff(x) x = self.ff(x)
x = x + shortcut # Add the original input back x = x + shortcut # Add the original input back
return x return x, next_cache
class FeedForward(nn.Module): class FeedForward(nn.Module):
@@ -117,8 +143,7 @@ class FeedForward(nn.Module):
class GroupedQueryAttention(nn.Module): class GroupedQueryAttention(nn.Module):
def __init__( def __init__(
self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None, self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None
max_seq_len=None, window_size=None
): ):
super().__init__() super().__init__()
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups" assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
@@ -146,26 +171,18 @@ class GroupedQueryAttention(nn.Module):
else: else:
self.q_norm = self.k_norm = None self.q_norm = self.k_norm = None
# For optional KV cache def forward(self, x, mask, cos, sin, use_cache=False, start_pos=0, cache=None):
self.max_seq_len = max_seq_len
self.window_size = window_size or self.max_seq_len
self.register_buffer("cache_k", None, persistent=False)
self.register_buffer("cache_v", None, persistent=False)
self.cache_initialized = False
self.ptr = 0
def forward(self, x, cos, sin, use_cache=False):
b, num_tokens, _ = x.shape b, num_tokens, _ = x.shape
# Apply projections # Apply projections
queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim) queries = self.W_query(x) # (b, num_tokens, num_heads * head_dim)
keys_new = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim) keys = self.W_key(x) # (b, num_tokens, num_kv_groups * head_dim)
values_new = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim) values = self.W_value(x) # (b, num_tokens, num_kv_groups * head_dim)
# Reshape # Reshape
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2) queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
keys_new = keys_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2) keys_new = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
values_new = values_new.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2) values_new = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
# Optional normalization # Optional normalization
if self.q_norm: if self.q_norm:
@@ -173,62 +190,34 @@ class GroupedQueryAttention(nn.Module):
if self.k_norm: if self.k_norm:
keys_new = self.k_norm(keys_new) keys_new = self.k_norm(keys_new)
# For KV cache
pos_start = self.ptr
pos_end = pos_start + num_tokens
cos_slice = cos[pos_start:pos_end]
sin_slice = sin[pos_start:pos_end]
# Apply RoPE # Apply RoPE
keys_new = apply_rope(keys_new, cos_slice, sin_slice) queries = apply_rope(queries, cos, sin, offset=start_pos)
queries = apply_rope(queries, cos_slice, sin_slice) keys_new = apply_rope(keys_new, cos, sin, offset=start_pos)
# Expand K and V to match number of heads
keys_new = keys_new.repeat_interleave(self.group_size, dim=1)
values_new = values_new.repeat_interleave(self.group_size, dim=1)
if use_cache: if use_cache:
if not self.cache_initialized: if cache is None:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=keys_new.dtype) keys = keys_new
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device, dtype=values_new.dtype) values = values_new
self.ptr = 0 else:
self.cache_initialized = True prev_k, prev_v = cache
keys = torch.cat([prev_k, keys_new], dim=2)
# In-place update values = torch.cat([prev_v, values_new], dim=2)
end = self.ptr + num_tokens next_cache = (keys, values)
self.cache_k[:, :, self.ptr:end].copy_(keys_new)
self.cache_v[:, :, self.ptr:end].copy_(values_new)
keys = self.cache_k[:, :, max(0, end - self.window_size):end]
values = self.cache_v[:, :, max(0, end - self.window_size):end]
self.ptr = end
else: else:
keys, values = keys_new, values_new keys, values = keys_new, values_new
next_cache = None
# Expand K and V to match number of heads
keys = keys.repeat_interleave(self.group_size, dim=1)
values = values.repeat_interleave(self.group_size, dim=1)
# Attention # Attention
attn_scores = queries @ keys.transpose(2, 3) attn_scores = queries @ keys.transpose(2, 3)
attn_scores = attn_scores.masked_fill(mask, -torch.inf)
# Create causal mask to fill attention scores
T_q = queries.shape[-2]
T_k = keys.shape[-2]
if not use_cache or T_q > 1:
causal_mask = torch.triu(
torch.ones((T_q, T_k), device=x.device, dtype=torch.bool),
diagonal=1
)
attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)
attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1) attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out) context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
return self.out_proj(context) return self.out_proj(context), next_cache
def reset_cache(self):
if self.cache_k is not None:
self.cache_k.zero_()
self.cache_v.zero_()
self.ptr = 0
def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32): def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
@@ -253,7 +242,7 @@ def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=
return cos, sin return cos, sin
def apply_rope(x, cos, sin): def apply_rope(x, cos, sin, offset=0):
# x: (batch_size, num_heads, seq_len, head_dim) # x: (batch_size, num_heads, seq_len, head_dim)
batch_size, num_heads, seq_len, head_dim = x.shape batch_size, num_heads, seq_len, head_dim = x.shape
assert head_dim % 2 == 0, "Head dimension must be even" assert head_dim % 2 == 0, "Head dimension must be even"
@@ -263,8 +252,8 @@ def apply_rope(x, cos, sin):
x2 = x[..., head_dim // 2:] # Second half x2 = x[..., head_dim // 2:] # Second half
# Adjust sin and cos shapes # Adjust sin and cos shapes
cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim) cos = cos[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)
sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0) sin = sin[offset:offset + seq_len, :].unsqueeze(0).unsqueeze(0)
# Apply the rotary transformation # Apply the rotary transformation
rotated = torch.cat((-x2, x1), dim=-1) rotated = torch.cat((-x2, x1), dim=-1)
@@ -297,3 +286,149 @@ class RMSNorm(nn.Module):
return norm_x.to(input_dtype) return norm_x.to(input_dtype)
def load_weights_into_qwen(model, param_config, params):
    """Copy pretrained tensors from ``params`` into ``model``.

    Args:
        model: Model exposing tok_emb, trf_blocks, final_norm and out_head.
        param_config: Config mapping; only ``"n_layers"`` is read.
        params: Mapping from checkpoint tensor name (``model.*``) to tensor.

    Raises:
        ValueError: If a checkpoint tensor's shape differs from the shape of
            the parameter it replaces.
    """

    def assign(left, right, tensor_name="unknown"):
        # Refuse silently mis-shaped weights, then wrap a detached copy
        if left.shape != right.shape:
            raise ValueError(f"Shape mismatch in tensor '{tensor_name}'. Left: {left.shape}, Right: {right.shape}")
        if isinstance(right, torch.Tensor):
            data = right.clone().detach()
        else:
            data = torch.tensor(right)
        return torch.nn.Parameter(data)

    def copy_into(owner, attr, key):
        # Replace owner.<attr> with the checkpoint tensor stored under `key`
        setattr(owner, attr, assign(getattr(owner, attr), params[key], key))

    copy_into(model.tok_emb, "weight", "model.embed_tokens.weight")

    for layer_idx in range(param_config["n_layers"]):
        block = model.trf_blocks[layer_idx]
        att = block.att
        prefix = f"model.layers.{layer_idx}"

        # Q, K, V projections followed by the output projection
        for proj_attr, ckpt_name in (
            ("W_query", "q_proj"),
            ("W_key", "k_proj"),
            ("W_value", "v_proj"),
            ("out_proj", "o_proj"),
        ):
            copy_into(getattr(att, proj_attr), "weight", f"{prefix}.self_attn.{ckpt_name}.weight")

        # QK norms exist only when the attention module was built with them
        for norm_attr in ("q_norm", "k_norm"):
            qk_norm = getattr(att, norm_attr, None)
            if qk_norm is not None:
                copy_into(qk_norm, "scale", f"{prefix}.self_attn.{norm_attr}.weight")

        # Attention layernorm
        copy_into(block.norm1, "scale", f"{prefix}.input_layernorm.weight")

        # Feedforward weights
        for ff_attr, ckpt_name in (
            ("fc1", "gate_proj"),
            ("fc2", "up_proj"),
            ("fc3", "down_proj"),
        ):
            copy_into(getattr(block.ff, ff_attr), "weight", f"{prefix}.mlp.{ckpt_name}.weight")

        copy_into(block.norm2, "scale", f"{prefix}.post_attention_layernorm.weight")

    # Final normalization and output head
    copy_into(model.final_norm, "scale", "model.norm.weight")

    # The output head is loaded from the embedding tensor; assign() copies it,
    # so the two parameters start equal but are separate objects
    copy_into(model.out_head, "weight", "model.embed_tokens.weight")
class Qwen3Tokenizer():
    """Wrapper around a ``tokenizers.Tokenizer`` that applies the Qwen3 chat format.

    Every prompt passed to ``encode`` is wrapped as a single user message.
    """

    def __init__(self, tokenizer_file_path="tokenizer.json",
                 repo_id=None, add_generation_prompt=False, add_thinking=False):
        """
        Args:
            tokenizer_file_path: Path to the tokenizer JSON file.
            repo_id: If given and the file does not exist, the file is
                downloaded from this Hugging Face repository into the
                directory named by ``tokenizer_file_path``.
            add_generation_prompt: Append the assistant header to prompts.
            add_thinking: Leave the assistant turn open for reasoning output.

        Raises:
            ValueError: If ``add_generation_prompt != add_thinking``.
        """
        # Validate arguments first so bad settings fail before the
        # third-party import and any download
        if add_generation_prompt != add_thinking:
            raise ValueError(
                "Only add_generation_prompt==add_thinking settings are currently supported"
            )

        from tokenizers import Tokenizer

        self.tokenizer_file_path = tokenizer_file_path
        self.add_generation_prompt = add_generation_prompt
        self.add_thinking = add_thinking

        tokenizer_file_path_obj = Path(tokenizer_file_path)
        if not tokenizer_file_path_obj.is_file() and repo_id is not None:
            # Download into the file's full parent directory (not just its
            # last component) so the file lands where from_file() reads it
            _ = download_from_huggingface(
                repo_id=repo_id,
                filename=str(tokenizer_file_path_obj.name),
                local_dir=str(tokenizer_file_path_obj.parent)
            )
        self.tokenizer = Tokenizer.from_file(tokenizer_file_path)

    def encode(self, prompt):
        """Wrap ``prompt`` as a single user message and return its token IDs."""
        messages = [
            {"role": "user", "content": prompt}
        ]
        formatted_prompt = self.format_qwen_chat(
            messages,
            add_generation_prompt=self.add_generation_prompt,
            add_thinking=self.add_thinking
        )
        return self.tokenizer.encode(formatted_prompt).ids

    def decode(self, token_ids):
        """Decode token IDs to text, keeping special tokens in the output."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=False)

    @staticmethod
    def format_qwen_chat(messages, add_generation_prompt=False, add_thinking=False):
        """Render ``messages`` (dicts with 'role' and 'content') in the Qwen chat format.

        With ``add_generation_prompt`` the assistant header is appended. If
        ``add_thinking`` is False, an empty ``<think>`` block is added after
        the header so the model skips its reasoning phase.
        """
        parts = [
            f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
            for msg in messages
        ]
        if add_generation_prompt:
            parts.append("<|im_start|>assistant\n")
            if not add_thinking:
                # Qwen3's reasoning markers are <think>/</think>
                parts.append("<think>\n\n</think>\n\n")
        return "".join(parts)
def download_from_huggingface(repo_id, filename, local_dir, revision="main"):
    """Download one file from a Hugging Face repository.

    Args:
        repo_id: Repository identifier used in the URL path.
        filename: File name inside the repository; also the local file name.
        local_dir: Destination directory, created if it does not exist.
        revision: Branch, tag or commit to resolve against.

    Returns:
        The local path the file was written to (``local_dir`` joined with
        ``filename``).
    """
    url = f"https://huggingface.co/{repo_id}/resolve/{revision}/{filename}"

    # Make sure the target directory exists before writing into it
    target_dir = Path(local_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    dest_path = os.path.join(local_dir, filename)
    print(f"Downloading {url} to {dest_path}...")
    urllib.request.urlretrieve(url, dest_path)

    return dest_path

View File

@@ -0,0 +1,21 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
class KVCache:
    """Per-layer storage for cached values, indexed by layer number."""

    def __init__(self, n_layers):
        # One slot per layer; None marks a layer with nothing cached yet
        self.cache = [None for _ in range(n_layers)]

    def get(self, layer_idx):
        """Return the entry stored for ``layer_idx`` (None if unset)."""
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        """Store ``value`` as the entry for ``layer_idx``."""
        self.cache[layer_idx] = value

    def get_all(self):
        """Return the underlying list of per-layer entries (not a copy)."""
        return self.cache

    def reset(self):
        """Clear every slot in place, keeping the same list object."""
        self.cache[:] = [None] * len(self.cache)