remove redundant next_cache (#817)

This commit is contained in:
Sebastian Raschka
2025-09-11 15:16:08 -05:00
committed by GitHub
parent c7a4362ca4
commit 32965e0edd
5 changed files with 0 additions and 10 deletions

View File

@@ -177,13 +177,11 @@ class GPTModel(nn.Module):
else:
start_pos = 0
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
if cache:
cache.update(i, new_cache)
next_cache.append(new_cache)
x = self.final_norm(x)
logits = self.out_head(x)

View File

@@ -97,7 +97,6 @@ class Llama3Model(nn.Module):
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
@@ -105,7 +104,6 @@ class Llama3Model(nn.Module):
cache=blk_cache)
if cache is not None:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))

View File

@@ -65,7 +65,6 @@ class Qwen3Model(nn.Module):
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
mask = mask[None, None, :, :]
next_cache = []
for i, block in enumerate(self.trf_blocks):
blk_cache = cache.get(i) if cache else None
x, new_blk_cache = block(x, mask, self.cos, self.sin,
@@ -73,7 +72,6 @@ class Qwen3Model(nn.Module):
cache=blk_cache)
if cache is not None:
cache.update(i, new_blk_cache)
next_cache.append(new_blk_cache)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))