From 9320a5e25216e8ea4b1d6528081e2edbf1bd50f2 Mon Sep 17 00:00:00 2001 From: casinca <47400729+casinca@users.noreply.github.com> Date: Sat, 21 Mar 2026 01:47:52 +0100 Subject: [PATCH] fix: added KVcache in `generate_text_basic_stream` (#981) --- ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb b/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb index 929de3d..25529de 100644 --- a/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb +++ b/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb @@ -1449,17 +1449,23 @@ "\n", " model.eval()\n", " with torch.no_grad():\n", - " for _ in range(max_new_tokens):\n", - " out = model(token_ids)[:, -1]\n", - " next_token = torch.argmax(out, dim=-1, keepdim=True)\n", + " cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n", + " model.reset_kv_cache()\n", "\n", - " if (eos_token_id is not None\n", - " and torch.all(next_token == eos_token_id)):\n", - " break\n", + " # Prime the cache with the initial context\n", + " logits = model(token_ids, cache=cache)\n", + "\n", + " for _ in range(max_new_tokens):\n", + " next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n", + "\n", + " if eos_token_id is not None and torch.all(next_token == eos_token_id):\n", + " break\n", "\n", " yield next_token\n", - " \n", - " token_ids = torch.cat([token_ids, next_token], dim=1)" + "\n", + " token_ids = torch.cat([token_ids, next_token], dim=1)\n", + " # Feed only the new token to the model; cache handles history\n", + " logits = model(next_token, cache=cache)" ] }, {