Numerically stable generate on mps (#849)

* Numerically stable generate on mps

* add file
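
Note on the change: subtracting the row-wise maximum from the logits before softmax leaves the probabilities mathematically unchanged, because the shift cancels between numerator and denominator, and it keeps every exponent at or below zero. The snippet below is a minimal pure-Python sketch of that property, not the notebook's code. The helper names `softmax_naive` and `softmax_shifted` and the sample values are hypothetical, and it makes no claim about how the mps backend computes softmax internally.

```python
import math


def softmax_naive(logits):
    # Direct definition: exp(x_i) / sum_j exp(x_j).
    # math.exp raises OverflowError for large inputs.
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_shifted(logits):
    # Subtract the maximum first, analogous to
    # logits - logits.max(dim=-1, keepdim=True).values in the patch.
    # Every exponent is then <= 0, so exp() stays within (0, 1].
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical sample values, chosen only for illustration.
small = [2.0, 1.0, 0.1]
large = [1000.0, 999.0, 998.1]  # same pairwise differences as `small`

naive_small = softmax_naive(small)
shifted_small = softmax_shifted(small)

try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

shifted_large = softmax_shifted(large)
```

For the small logits both versions agree to floating-point precision. For the large logits the naive version overflows, while the shifted version returns the same distribution as for `small`, since only the differences between logits matter.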
Sebastian Raschka
2025-09-26 22:42:44 -05:00
committed by GitHub
parent f492c949d3
commit 9bc827ea7e
8 changed files with 35 additions and 2 deletions


@@ -1933,7 +1933,7 @@
"metadata": {},
"source": [
"- The previous two subsections introduced temperature sampling and top-k sampling\n",
"- Let's use these two concepts to modify the `generate_simple` function we used to generate text via the LLM earlier, creating a new `generate` function:"
"- Let's use these two concepts to modify the `generate_text_simple` function from chapter 4, creating a new `generate` function:"
]
},
{
@@ -1963,6 +1963,10 @@
" if temperature > 0.0:\n",
" logits = logits / temperature\n",
"\n",
" # New (not in book): numerical stability tip to get equivalent results on mps device\n",
" # subtract rowwise max before softmax\n",
" logits = logits - logits.max(dim=-1, keepdim=True).values\n",
" \n",
" # Apply softmax to get probabilities\n",
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
"\n",