diff --git a/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb b/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb index 929de3d..25529de 100644 --- a/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb +++ b/ch05/16_qwen3.5/qwen3.5-plus-kv-cache.ipynb @@ -1449,17 +1449,23 @@ "\n", " model.eval()\n", " with torch.no_grad():\n", - " for _ in range(max_new_tokens):\n", - " out = model(token_ids)[:, -1]\n", - " next_token = torch.argmax(out, dim=-1, keepdim=True)\n", + " cache = KVCache(n_layers=model.cfg[\"n_layers\"])\n", + " model.reset_kv_cache()\n", "\n", - " if (eos_token_id is not None\n", - " and torch.all(next_token == eos_token_id)):\n", - " break\n", + " # Prime the cache with the initial context\n", + " logits = model(token_ids, cache=cache)\n", + "\n", + " for _ in range(max_new_tokens):\n", + " next_token = torch.argmax(logits[:, -1], dim=-1, keepdim=True)\n", + "\n", + " if eos_token_id is not None and torch.all(next_token == eos_token_id):\n", + " break\n", "\n", " yield next_token\n", - " \n", - " token_ids = torch.cat([token_ids, next_token], dim=1)" + "\n", + " token_ids = torch.cat([token_ids, next_token], dim=1)\n", + " # Feed only the new token to the model; cache handles history\n", + " logits = model(next_token, cache=cache)" ] }, {