Simplify KV cache usage (#728)

* Simplify KV cache usage

* Swap mark text with ghostwriter
Sebastian Raschka
2025-07-08 12:56:55 -05:00
committed by GitHub
parent b5bd8d2de2
commit 90c824506c
4 changed files with 31 additions and 39 deletions

@@ -10,20 +10,20 @@ import torch
 def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
     ctx_len = context_size or model.cfg["context_length"]
+    cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None
 
     with torch.no_grad():
         if use_cache:
-            cache = KVCache(n_layers=model.cfg["n_layers"])
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
+            logits = model(idx[:, -ctx_len:], cache=cache)
 
             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
-                logits = model(next_idx, use_cache=True, cache=cache)
+                logits = model(next_idx, cache=cache)
         else:
             for _ in range(max_new_tokens):
-                logits = model(idx[:, -ctx_len:], use_cache=False)
+                logits = model(idx[:, -ctx_len:], cache=None)
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
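For readers skimming the diff: the change constructs the cache once, outside the model call, and the forward pass then takes a single cache argument (a KVCache object or None) in place of the redundant use_cache keyword. Below is a minimal sketch of what such a per-layer KVCache container could look like. Only the KVCache(n_layers=...) constructor appears in the diff; the get/update/reset methods and the attention-side append_kv helper are illustrative assumptions, not the repository's actual API.

    import torch

    class KVCache:
        # Minimal sketch: one (keys, values) slot per transformer block.
        def __init__(self, n_layers):
            self.cache = [None] * n_layers

        def get(self, layer_idx):
            return self.cache[layer_idx]

        def update(self, layer_idx, kv):
            self.cache[layer_idx] = kv

        def reset(self):
            self.cache = [None] * len(self.cache)

    def append_kv(cache, layer_idx, k_new, v_new):
        # Hypothetical attention-side helper: append this step's keys/values
        # to what the cache already holds for this layer, store the result,
        # and return the full tensors so attention covers the whole sequence.
        # Assumes tensors shaped (batch, heads, seq_len, head_dim).
        if cache is None:
            return k_new, v_new
        cached = cache.get(layer_idx)
        if cached is not None:
            k_prev, v_prev = cached
            k_new = torch.cat([k_prev, k_new], dim=2)  # concat along seq dim
            v_new = torch.cat([v_prev, v_new], dim=2)
        cache.update(layer_idx, (k_new, v_new))
        return k_new, v_new

With a container of this shape, generate_text_simple pays the full prompt cost once, and each subsequent decoding step feeds the model only the newly sampled token, since earlier keys and values are read back from the cache instead of being recomputed.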