mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Numerically stable generate on mps (#849)
* Numerically stable generate on mps
* add file
This commit is contained in:
committed by
GitHub
parent
f492c949d3
commit
9bc827ea7e
@@ -267,6 +267,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract rowwise max before softmax
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
@@ -267,8 +267,13 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract rowwise max before softmax
|
||||
#logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
#probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
probs = torch.log_softmax(logits, dim=-1)
|
||||
|
||||
# Sample from the distribution
|
||||
idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
|
||||
|
||||
Reference in New Issue
Block a user