Optimize KV cache (#673)

* Optimize KV cache

* style

* interpretable generate

* interpretable generate

* update readme
This commit is contained in:
Sebastian Raschka
2025-06-16 16:00:50 -05:00
committed by GitHub
parent ba0370abd1
commit ece59ba587
4 changed files with 98 additions and 68 deletions

View File

@@ -56,28 +56,29 @@ class MultiHeadAttention(nn.Module):
# NEW
if use_cache:
if self.cache_k is None or self.cache_k.size(0) != b:
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
self.current_pos = 0
self.cache_k = torch.zeros(b, self.num_heads,
self.window_size, self.head_dim,
device=x.device)
self.cache_v = torch.zeros_like(self.cache_k)
self.ptr_cur = 0 # pointer to next free slot
# write new entries
start = self.current_pos
end = start + num_tokens
self.cache_k[:, :, start:end, :] = keys_new
self.cache_v[:, :, start:end, :] = values_new
self.current_pos = end
# if the incoming chunk would overflow the window, discard the oldest tokens
if self.ptr_cur + num_tokens > self.window_size:
overflow = self.ptr_cur + num_tokens - self.window_size
# shift everything left by `overflow`; clone() materializes the source slice first because source and destination overlap
self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
self.ptr_cur -= overflow # pointer after shift
# sliding window truncation
if self.current_pos > self.window_size:
self.cache_k = self.cache_k[:, :, -self.window_size:, :]
self.cache_v = self.cache_v[:, :, -self.window_size:, :]
self.current_pos = self.window_size
self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
self.ptr_cur += num_tokens
keys = self.cache_k[:, :, :self.current_pos, :]
values = self.cache_v[:, :, :self.current_pos, :]
keys = self.cache_k[:, :, :self.ptr_cur, :]
values = self.cache_v[:, :, :self.ptr_cur, :]
else:
keys = keys_new
values = values_new
keys, values = keys_new, values_new
self.ptr_cur = 0 # keep pointer sane if you interleave modes
####################################################
@@ -216,7 +217,7 @@ class GPTModel(nn.Module):
self.trf_blocks = nn.ModuleList(
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.current_pos = 0
self.ptr_current_pos = 0
####################################################
self.final_norm = LayerNorm(cfg["emb_dim"])
@@ -232,8 +233,8 @@ class GPTModel(nn.Module):
# NEW
if use_cache:
pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.current_pos += seq_len
pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.ptr_current_pos += seq_len
else:
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
@@ -258,7 +259,7 @@ class GPTModel(nn.Module):
def reset_kv_cache(self):
for blk in self.trf_blocks:
blk.att.reset_cache()
self.ptr_current_pos = 0
####################################################
@@ -290,19 +291,30 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
####################################################
# NEW
def generate_text_simple_cached(model, idx, max_new_tokens):
def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
model.eval()
model.reset_kv_cache()
# Init cache with full prompt
logits = model(idx, use_cache=True)
ctx_len = model.pos_emb.num_embeddings # max supported length, e.g. 1024
if use_cache:
# Init cache with full prompt
model.reset_kv_cache()
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=True)
for _ in range(max_new_tokens):
last_logits = logits[:, -1]
next_idx = last_logits.argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
logits = model(next_idx, use_cache=True)
for _ in range(max_new_tokens):
# a) pick the token with the highest logit (greedy decoding)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
# b) append it to the running sequence
idx = torch.cat([idx, next_idx], dim=1)
# c) feed model only the new token
with torch.no_grad():
logits = model(next_idx, use_cache=True)
else:
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(idx[:, -ctx_len:], use_cache=False)
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
idx = torch.cat([idx, next_idx], dim=1)
return idx
####################################################
@@ -317,7 +329,7 @@ def main():
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False, # Query-Key-Value bias
"kv_window_size": 48 # NEW: KV cache window size
"kv_window_size": 1024 # NEW: KV cache window size
}
torch.manual_seed(123)