mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
fix: added KVcache in generate_text_basic_stream (#981)
This commit is contained in:
@@ -1449,17 +1449,23 @@
|
|||||||
"\n",
|
"\n",
|
||||||
" model.eval()\n",
|
" model.eval()\n",
|
||||||
" with torch.no_grad():\n",
|
" with torch.no_grad():\n",
|
||||||
" for _ in range(max_new_tokens):\n",
|
" cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
|
||||||
" out = model(token_ids)[:, -1]\n",
|
" model.reset_kv_cache()\n",
|
||||||
" next_token = torch.argmax(out, dim=-1, keepdim=True)\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
" if (eos_token_id is not None\n",
|
" # Prime the cache with the initial context\n",
|
||||||
" and torch.all(next_token == eos_token_id)):\n",
|
" logits = model(token_ids, cache=cache)\n",
|
||||||
" break\n",
|
"\n",
|
||||||
|
" for _ in range(max_new_tokens):\n",
|
||||||
|
" next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
|
||||||
|
"\n",
|
||||||
|
" if eos_token_id is not None and torch.all(next_token == eos_token_id):\n",
|
||||||
|
" break\n",
|
||||||
"\n",
|
"\n",
|
||||||
" yield next_token\n",
|
" yield next_token\n",
|
||||||
" \n",
|
"\n",
|
||||||
" token_ids = torch.cat([token_ids, next_token], dim=1)"
|
" token_ids = torch.cat([token_ids, next_token], dim=1)\n",
|
||||||
|
" # Feed only the new token to the model; cache handles history\n",
|
||||||
|
" logits = model(next_token, cache=cache)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Reference in New Issue
Block a user