mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Numerically stable generate on mps
This commit is contained in:
@@ -1933,7 +1933,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- The previous two subsections introduced temperature sampling and top-k sampling\n",
|
||||
"- Let's use these two concepts to modify the `generate_simple` function we used to generate text via the LLM earlier, creating a new `generate` function:"
|
||||
"- Let's use these two concepts to modify the `generate_text_simple` function from chapter 4, creating a new `generate` function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -1963,6 +1963,10 @@
|
||||
" if temperature > 0.0:\n",
|
||||
" logits = logits / temperature\n",
|
||||
"\n",
|
||||
" # New (not in book): numerical stability tip to get equivalent results on mps device\n",
|
||||
" # subtract rowwise max before softmax\n",
|
||||
" logits = logits - logits.max(dim=-1, keepdim=True).values\n",
|
||||
" \n",
|
||||
" # Apply softmax to get probabilities\n",
|
||||
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
|
||||
"\n",
|
||||
|
||||
@@ -235,6 +235,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract row-wise max before softmax (shift-invariant, so probs are unchanged)
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
@@ -44,6 +44,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract row-wise max before softmax (shift-invariant, so probs are unchanged)
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
@@ -1047,6 +1047,10 @@
|
||||
" if temperature > 0.0:\n",
|
||||
" logits = logits / temperature\n",
|
||||
"\n",
|
||||
" # New (not in book): numerical stability tip to get equivalent results on mps device\n",
|
||||
" # subtract rowwise max before softmax\n",
|
||||
" logits = logits - logits.max(dim=-1, keepdim=True).values\n",
|
||||
" \n",
|
||||
" # Apply softmax to get probabilities\n",
|
||||
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
|
||||
"\n",
|
||||
|
||||
@@ -334,6 +334,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract row-wise max before softmax (shift-invariant, so probs are unchanged)
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
@@ -267,6 +267,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract row-wise max before softmax (shift-invariant, so probs are unchanged)
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
@@ -267,6 +267,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract row-wise max before softmax (shift-invariant, so probs are unchanged)
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
@@ -36,6 +36,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
if temperature > 0.0:
|
||||
logits = logits / temperature
|
||||
|
||||
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||
# subtract row-wise max before softmax (shift-invariant, so probs are unchanged)
|
||||
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||
|
||||
# Apply softmax to get probabilities
|
||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user