mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Numerically stable generate on mps (#849)
* Numerically stable generate on mps
* add file
This commit is contained in:
committed by
GitHub
parent
f492c949d3
commit
9bc827ea7e
@@ -1933,7 +1933,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"- The previous two subsections introduced temperature sampling and top-k sampling\n",
|
"- The previous two subsections introduced temperature sampling and top-k sampling\n",
|
||||||
"- Let's use these two concepts to modify the `generate_simple` function we used to generate text via the LLM earlier, creating a new `generate` function:"
|
"- Let's use these two concepts to modify the `generate_text_simple` function from chapter 4, creating a new `generate` function:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -1963,6 +1963,10 @@
|
|||||||
" if temperature > 0.0:\n",
|
" if temperature > 0.0:\n",
|
||||||
" logits = logits / temperature\n",
|
" logits = logits / temperature\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" # New (not in book): numerical stability tip to get equivalent results on mps device\n",
|
||||||
|
" # subtract rowwise max before softmax\n",
|
||||||
|
" logits = logits - logits.max(dim=-1, keepdim=True).values\n",
|
||||||
|
" \n",
|
||||||
" # Apply softmax to get probabilities\n",
|
" # Apply softmax to get probabilities\n",
|
||||||
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
|
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -235,6 +235,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
|||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||||
|
# subtract rowwise max before softmax
|
||||||
|
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
# Apply softmax to get probabilities
|
||||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
|||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||||
|
# subtract rowwise max before softmax
|
||||||
|
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
# Apply softmax to get probabilities
|
||||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||||
|
|
||||||
|
|||||||
@@ -1047,6 +1047,10 @@
|
|||||||
" if temperature > 0.0:\n",
|
" if temperature > 0.0:\n",
|
||||||
" logits = logits / temperature\n",
|
" logits = logits / temperature\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" # New (not in book): numerical stability tip to get equivalent results on mps device\n",
|
||||||
|
" # subtract rowwise max before softmax\n",
|
||||||
|
" logits = logits - logits.max(dim=-1, keepdim=True).values\n",
|
||||||
|
" \n",
|
||||||
" # Apply softmax to get probabilities\n",
|
" # Apply softmax to get probabilities\n",
|
||||||
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
|
" probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -334,6 +334,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
|||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||||
|
# subtract rowwise max before softmax
|
||||||
|
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
# Apply softmax to get probabilities
|
||||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||||
|
|
||||||
|
|||||||
@@ -267,6 +267,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
|||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||||
|
# subtract rowwise max before softmax
|
||||||
|
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
# Apply softmax to get probabilities
|
||||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||||
|
|
||||||
|
|||||||
@@ -267,8 +267,13 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
|||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||||
|
# subtract rowwise max before softmax
|
||||||
|
#logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
# Apply softmax to get probabilities
|
||||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
#probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||||
|
probs = torch.log_softmax(logits, dim=-1)
|
||||||
|
|
||||||
# Sample from the distribution
|
# Sample from the distribution
|
||||||
idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
|
idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
|
||||||
|
|||||||
@@ -36,6 +36,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
|||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
logits = logits / temperature
|
logits = logits / temperature
|
||||||
|
|
||||||
|
# New (not in book): numerical stability tip to get equivalent results on mps device
|
||||||
|
# subtract rowwise max before softmax
|
||||||
|
logits = logits - logits.max(dim=-1, keepdim=True).values
|
||||||
|
|
||||||
# Apply softmax to get probabilities
|
# Apply softmax to get probabilities
|
||||||
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user