mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
remove redundant next_cache (#817)
This commit is contained in:
committed by
GitHub
parent
c7a4362ca4
commit
32965e0edd
@@ -496,7 +496,6 @@
|
||||
" # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
|
||||
" mask = mask[None, None, :, :]\n",
|
||||
"\n",
|
||||
" next_cache = []\n",
|
||||
" for i, block in enumerate(self.trf_blocks):\n",
|
||||
" blk_cache = cache.get(i) if cache else None\n",
|
||||
" x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
|
||||
@@ -504,7 +503,6 @@
|
||||
" cache=blk_cache)\n",
|
||||
" if cache is not None:\n",
|
||||
" cache.update(i, new_blk_cache)\n",
|
||||
" next_cache.append(new_blk_cache)\n",
|
||||
"\n",
|
||||
" x = self.final_norm(x)\n",
|
||||
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
|
||||
|
||||
@@ -422,7 +422,6 @@
|
||||
" # Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads\n",
|
||||
" mask = mask[None, None, :, :]\n",
|
||||
"\n",
|
||||
" next_cache = []\n",
|
||||
" for i, block in enumerate(self.trf_blocks):\n",
|
||||
" blk_cache = cache.get(i) if cache else None\n",
|
||||
" x, new_blk_cache = block(x, mask, self.cos, self.sin,\n",
|
||||
@@ -430,7 +429,6 @@
|
||||
" cache=blk_cache)\n",
|
||||
" if cache is not None:\n",
|
||||
" cache.update(i, new_blk_cache)\n",
|
||||
" next_cache.append(new_blk_cache)\n",
|
||||
"\n",
|
||||
" x = self.final_norm(x)\n",
|
||||
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
|
||||
|
||||
@@ -177,13 +177,11 @@ class GPTModel(nn.Module):
|
||||
else:
|
||||
start_pos = 0
|
||||
|
||||
next_cache = []
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = cache.get(i) if cache else None
|
||||
x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
|
||||
if cache:
|
||||
cache.update(i, new_cache)
|
||||
next_cache.append(new_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x)
|
||||
|
||||
@@ -97,7 +97,6 @@ class Llama3Model(nn.Module):
|
||||
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
|
||||
mask = mask[None, None, :, :]
|
||||
|
||||
next_cache = []
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = cache.get(i) if cache else None
|
||||
x, new_blk_cache = block(x, mask, self.cos, self.sin,
|
||||
@@ -105,7 +104,6 @@ class Llama3Model(nn.Module):
|
||||
cache=blk_cache)
|
||||
if cache is not None:
|
||||
cache.update(i, new_blk_cache)
|
||||
next_cache.append(new_blk_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
|
||||
@@ -65,7 +65,6 @@ class Qwen3Model(nn.Module):
|
||||
# Shape (1, 1, num_tokens, num_tokens) to broadcast across batch and heads
|
||||
mask = mask[None, None, :, :]
|
||||
|
||||
next_cache = []
|
||||
for i, block in enumerate(self.trf_blocks):
|
||||
blk_cache = cache.get(i) if cache else None
|
||||
x, new_blk_cache = block(x, mask, self.cos, self.sin,
|
||||
@@ -73,7 +72,6 @@ class Qwen3Model(nn.Module):
|
||||
cache=blk_cache)
|
||||
if cache is not None:
|
||||
cache.update(i, new_blk_cache)
|
||||
next_cache.append(new_blk_cache)
|
||||
|
||||
x = self.final_norm(x)
|
||||
logits = self.out_head(x.to(self.cfg["dtype"]))
|
||||
|
||||
Reference in New Issue
Block a user