Add eos_id option to generate() for ch07

This commit is contained in:
rasbt
2024-05-18 12:35:40 -05:00
parent 3b57b6d8c4
commit 4851d5a0fa
4 changed files with 18 additions and 6 deletions

View File

@@ -215,7 +215,7 @@ def load_weights_into_gpt(gpt, params):
gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
# For-loop is the same as before: Get logits, and only focus on last time step
for _ in range(max_new_tokens):
@@ -245,6 +245,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
else:
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)
if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified
break
# Same as before: append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)