mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
add eos_id option for ch07
This commit is contained in:
@@ -215,7 +215,7 @@ def load_weights_into_gpt(gpt, params):
|
||||
gpt.out_head.weight = assign(gpt.out_head.weight, params["wte"])
|
||||
|
||||
|
||||
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
|
||||
def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None, eos_id=None):
|
||||
|
||||
# For-loop is the same as before: Get logits, and only focus on last time step
|
||||
for _ in range(max_new_tokens):
|
||||
@@ -245,6 +245,9 @@ def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):
|
||||
else:
|
||||
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)
|
||||
|
||||
if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified
|
||||
break
|
||||
|
||||
# Same as before: append sampled index to the running sequence
|
||||
idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user