Numerically stable generate on mps (#849)

* Numerically stable generate on mps

* add file
This commit is contained in:
Sebastian Raschka
2025-09-26 22:42:44 -05:00
committed by GitHub
parent f492c949d3
commit 9bc827ea7e
8 changed files with 35 additions and 2 deletions

View File

@@ -235,6 +235,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
if temperature > 0.0:
logits = logits / temperature
# New (not in book): numerical stability tip to get equivalent results on mps device
# subtract rowwise max before softmax
logits = logits - logits.max(dim=-1, keepdim=True).values
# Apply softmax to get probabilities
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)