mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
fix: added KVcache in generate_text_basic_stream (#981)
This commit is contained in:
@@ -1449,17 +1449,23 @@
|
||||
"\n",
|
||||
" model.eval()\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for _ in range(max_new_tokens):\n",
|
||||
" out = model(token_ids)[:, -1]\n",
|
||||
" next_token = torch.argmax(out, dim=-1, keepdim=True)\n",
|
||||
" cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n",
|
||||
" model.reset_kv_cache()\n",
|
||||
"\n",
|
||||
" if (eos_token_id is not None\n",
|
||||
" and torch.all(next_token == eos_token_id)):\n",
|
||||
" break\n",
|
||||
" # Prime the cache with the initial context\n",
|
||||
" logits = model(token_ids, cache=cache)\n",
|
||||
"\n",
|
||||
" for _ in range(max_new_tokens):\n",
|
||||
" next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n",
|
||||
"\n",
|
||||
" if eos_token_id is not None and torch.all(next_token == eos_token_id):\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" yield next_token\n",
|
||||
" \n",
|
||||
" token_ids = torch.cat([token_ids, next_token], dim=1)"
|
||||
"\n",
|
||||
" token_ids = torch.cat([token_ids, next_token], dim=1)\n",
|
||||
" # Feed only the new token to the model; cache handles history\n",
|
||||
" logits = model(next_token, cache=cache)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user