From 32965e0edde230fc835947f0a0e61f800acc90c7 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 11 Sep 2025 15:16:08 -0500 Subject: [PATCH] remove redundant next_cache (#817) --- ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb | 2 -- ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb | 2 -- pkg/llms_from_scratch/kv_cache/gpt2.py | 2 -- pkg/llms_from_scratch/kv_cache/llama3.py | 2 -- pkg/llms_from_scratch/kv_cache/qwen3.py | 2 -- 5 files changed, 10 deletions(-) diff --git a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb index 792c35c..0b72499 100644 --- a/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-moe-plus-kvcache.ipynb @@ -496,7 +496,6 @@ " # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n", " mask = mask[None, None, :, :]\n", "\n", - " next_cache = []\n", " for i, block in enumerate(self.trf_blocks):\n", " blk_cache = cache.get(i) if cache else None\n", " x, new_blk_cache = block(x, mask, self.cos, self.sin,\n", @@ -504,7 +503,6 @@ " cache=blk_cache)\n", " if cache is not None:\n", " cache.update(i, new_blk_cache)\n", - " next_cache.append(new_blk_cache)\n", "\n", " x = self.final_norm(x)\n", " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", diff --git a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb index 9b0f987..2452535 100644 --- a/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb +++ b/ch05/11_qwen3/standalone-qwen3-plus-kvcache.ipynb @@ -422,7 +422,6 @@ " # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n", " mask = mask[None, None, :, :]\n", "\n", - " next_cache = []\n", " for i, block in enumerate(self.trf_blocks):\n", " blk_cache = cache.get(i) if cache else None\n", " x, new_blk_cache = block(x, mask, self.cos, self.sin,\n", @@ -430,7 +429,6 @@ " cache=blk_cache)\n", " if cache is not None:\n", " cache.update(i, new_blk_cache)\n", - " next_cache.append(new_blk_cache)\n", "\n", " x = self.final_norm(x)\n", " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", diff --git a/pkg/llms_from_scratch/kv_cache/gpt2.py b/pkg/llms_from_scratch/kv_cache/gpt2.py index bc9e4dd..adb7b67 100644 --- a/pkg/llms_from_scratch/kv_cache/gpt2.py +++ b/pkg/llms_from_scratch/kv_cache/gpt2.py @@ -177,13 +177,11 @@ class GPTModel(nn.Module): else: start_pos = 0 - next_cache = [] for i, block in enumerate(self.trf_blocks): blk_cache = cache.get(i) if cache else None x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache) if cache: cache.update(i, new_cache) - next_cache.append(new_cache) x = self.final_norm(x) logits = self.out_head(x) diff --git a/pkg/llms_from_scratch/kv_cache/llama3.py b/pkg/llms_from_scratch/kv_cache/llama3.py index 70258d0..098dee4 100644 --- a/pkg/llms_from_scratch/kv_cache/llama3.py +++ b/pkg/llms_from_scratch/kv_cache/llama3.py @@ -97,7 +97,6 @@ class Llama3Model(nn.Module): # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads mask = mask[None, None, :, :] - next_cache = [] for i, block in enumerate(self.trf_blocks): blk_cache = cache.get(i) if cache else None x, new_blk_cache = block(x, mask, self.cos, self.sin, @@ -105,7 +104,6 @@ class Llama3Model(nn.Module): cache=blk_cache) if cache is not None: cache.update(i, new_blk_cache) - next_cache.append(new_blk_cache) x = self.final_norm(x) logits = self.out_head(x.to(self.cfg["dtype"])) diff --git a/pkg/llms_from_scratch/kv_cache/qwen3.py b/pkg/llms_from_scratch/kv_cache/qwen3.py index 652e5fe..4960234 100644 --- a/pkg/llms_from_scratch/kv_cache/qwen3.py +++ b/pkg/llms_from_scratch/kv_cache/qwen3.py @@ -65,7 +65,6 @@ class Qwen3Model(nn.Module): # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads mask = mask[None, None, :, :] - next_cache = [] for i, block in enumerate(self.trf_blocks): blk_cache = cache.get(i) if cache else None x, new_blk_cache = block(x, mask, self.cos, self.sin, @@ -73,7 +72,6 @@ class Qwen3Model(nn.Module): cache=blk_cache) if cache is not None: cache.update(i, new_blk_cache) - next_cache.append(new_blk_cache) x = self.final_norm(x) logits = self.out_head(x.to(self.cfg["dtype"]))