mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Simplify KV cache usage (#728)
* Simplify KV cache usage * Swap mark text with ghostwriter
This commit is contained in:
@@ -10,20 +10,20 @@ import torch
|
||||
def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
|
||||
model.eval()
|
||||
ctx_len = context_size or model.cfg["context_length"]
|
||||
cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None
|
||||
|
||||
with torch.no_grad():
|
||||
if use_cache:
|
||||
cache = KVCache(n_layers=model.cfg["n_layers"])
|
||||
model.reset_kv_cache()
|
||||
logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)
|
||||
logits = model(idx[:, -ctx_len:], cache=cache)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
logits = model(next_idx, use_cache=True, cache=cache)
|
||||
logits = model(next_idx, cache=cache)
|
||||
else:
|
||||
for _ in range(max_new_tokens):
|
||||
logits = model(idx[:, -ctx_len:], use_cache=False)
|
||||
logits = model(idx[:, -ctx_len:], cache=None)
|
||||
next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
|
||||
idx = torch.cat([idx, next_idx], dim=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user