mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Numerically stable generate on mps (#849)
* Numerically stable generate on mps * add file
This commit is contained in:
committed by
GitHub
parent
f492c949d3
commit
9bc827ea7e
@@ -36,6 +36,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract rowwise max before softmax
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user