remove redundant next_cache (#817)

This commit is contained in:
Sebastian Raschka
2025-09-11 15:16:08 -05:00
committed by GitHub
parent c7a4362ca4
commit 32965e0edd
5 changed files with 0 additions and 10 deletions

View File

@@ -496,7 +496,6 @@
" # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
" mask = mask[None, None, :, :]\n",
"\n",
" next_cache = []\n",
" for i, block in enumerate(self.trf_blocks):\n",
" blk_cache = cache.get(i) if cache else None\n",
" x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
@@ -504,7 +503,6 @@
" cache=blk_cache)\n",
" if cache is not None:\n",
" cache.update(i, new_blk_cache)\n",
" next_cache.append(new_blk_cache)\n",
"\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",

View File

@@ -422,7 +422,6 @@
" # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
" mask = mask[None, None, :, :]\n",
"\n",
" next_cache = []\n",
" for i, block in enumerate(self.trf_blocks):\n",
" blk_cache = cache.get(i) if cache else None\n",
" x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
@@ -430,7 +429,6 @@
" cache=blk_cache)\n",
" if cache is not None:\n",
" cache.update(i, new_blk_cache)\n",
" next_cache.append(new_blk_cache)\n",
"\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",

View File

@@ -177,13 +177,11 @@ class GPTModel(nn.Module):
else:
start_pos = 0
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
if cache:
cache.update(i, new_cache)
next_cache.append(new_cache)
x = self.final_norm(x)
logits = self.out_head(x)

View File

@@ -97,7 +97,6 @@ class Llama3Model(nn.Module):
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
@@ -105,7 +104,6 @@ class Llama3Model(nn.Module):
cache=blk_cache)
if cache is not None:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))

View File

@@ -65,7 +65,6 @@ class Qwen3Model(nn.Module):
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
@@ -73,7 +72,6 @@ class Qwen3Model(nn.Module):
cache=blk_cache)
if cache is not None:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))