fix: added KVCache in generate_text_basic_stream (#981)

This commit is contained in:
casinca
2026-03-21 01:47:52 +01:00
committed by GitHub
parent 130cc1f63c
commit 9320a5e252

View File

@@ -1449,17 +1449,23 @@
"\n",
" model.eval()\n",
" with torch.no_grad():\n",
" for _ in range(max_new_tokens):\n",
" out = model(token_ids)[:, -1]\n",
" next_token = torch.argmax(out, dim=-1, keepdim=True)\n",
" cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
" model.reset_kv_cache()\n",
"\n",
" if (eos_token_id is not None\n",
" and torch.all(next_token == eos_token_id)):\n",
" break\n",
" # Prime the cache with the initial context\n",
" logits = model(token_ids, cache=cache)\n",
"\n",
" for _ in range(max_new_tokens):\n",
" next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
"\n",
" if eos_token_id is not None and torch.all(next_token == eos_token_id):\n",
" break\n",
"\n",
" yield next_token\n",
" \n",
" token_ids = torch.cat([token_ids, next_token], dim=1)"
"\n",
" token_ids = torch.cat([token_ids, next_token], dim=1)\n",
" # Feed only the new token to the model; cache handles history\n",
" logits = model(next_token, cache=cache)"
]
},
{