From 9bc827ea7e4f411a3c83fdecf670fd022a97c809 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Fri, 26 Sep 2025 22:42:44 -0500 Subject: [PATCH] Numerically stable generate on mps (#849) * Numerically stable generate on mps * add file --- ch05/01_main-chapter-code/ch05.ipynb | 6 +++++- ch05/01_main-chapter-code/gpt_generate.py | 4 ++++ ch05/07_gpt_to_llama/previous_chapters.py | 4 ++++ ch05/07_gpt_to_llama/standalone-llama32.ipynb | 4 ++++ ch06/02_bonus_additional-experiments/previous_chapters.py | 4 ++++ ch07/01_main-chapter-code/previous_chapters.py | 4 ++++ ch07/04_preference-tuning-with-dpo/previous_chapters.py | 7 ++++++- pkg/llms_from_scratch/ch05.py | 4 ++++ 8 files changed, 35 insertions(+), 2 deletions(-) diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb index 7c46f2d..adbbcc2 100644 --- a/ch05/01_main-chapter-code/ch05.ipynb +++ b/ch05/01_main-chapter-code/ch05.ipynb @@ -1933,7 +1933,7 @@ "metadata": {}, "source": [ "- The previous two subsections introduced temperature sampling and top-k sampling\n", - "- Let's use these two concepts to modify the `generate_simple` function we used to generate text via the LLM earlier, creating a new `generate` function:" + "- Let's use these two concepts to modify the `generate_text_simple` function from chapter 4, creating a new `generate` function:" ] }, { @@ -1963,6 +1963,10 @@ " if temperature > 0.0:\n", " logits = logits / temperature\n", "\n", + " # New (not in book): numerical stability tip to get equivalent results on mps device\n", + " # subtract rowwise max before softmax\n", + " logits = logits - logits.max(dim=-1, keepdim=True).values\n", + " \n", " # Apply softmax to get probabilities\n", " probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n", "\n", diff --git a/ch05/01_main-chapter-code/gpt_generate.py b/ch05/01_main-chapter-code/gpt_generate.py index 4e49ccd..337059b 100644 --- a/ch05/01_main-chapter-code/gpt_generate.py +++ 
b/ch05/01_main-chapter-code/gpt_generate.py @@ -235,6 +235,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No if temperature > 0.0: logits = logits / temperature + # New (not in book): numerical stability tip to get equivalent results on mps device + # subtract rowwise max before softmax + logits = logits - logits.max(dim=-1, keepdim=True).values + # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) diff --git a/ch05/07_gpt_to_llama/previous_chapters.py b/ch05/07_gpt_to_llama/previous_chapters.py index 1ca678c..93411f5 100644 --- a/ch05/07_gpt_to_llama/previous_chapters.py +++ b/ch05/07_gpt_to_llama/previous_chapters.py @@ -44,6 +44,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No if temperature > 0.0: logits = logits / temperature + # New (not in book): numerical stability tip to get equivalent results on mps device + # subtract rowwise max before softmax + logits = logits - logits.max(dim=-1, keepdim=True).values + # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index 264e23c..4eec667 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -1047,6 +1047,10 @@ " if temperature > 0.0:\n", " logits = logits / temperature\n", "\n", + " # New (not in book): numerical stability tip to get equivalent results on mps device\n", + " # subtract rowwise max before softmax\n", + " logits = logits - logits.max(dim=-1, keepdim=True).values\n", + " \n", " # Apply softmax to get probabilities\n", " probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n", "\n", diff --git a/ch06/02_bonus_additional-experiments/previous_chapters.py b/ch06/02_bonus_additional-experiments/previous_chapters.py index 46549e9..a4a9baa 100644 --- 
a/ch06/02_bonus_additional-experiments/previous_chapters.py +++ b/ch06/02_bonus_additional-experiments/previous_chapters.py @@ -334,6 +334,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No if temperature > 0.0: logits = logits / temperature + # New (not in book): numerical stability tip to get equivalent results on mps device + # subtract rowwise max before softmax + logits = logits - logits.max(dim=-1, keepdim=True).values + # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) diff --git a/ch07/01_main-chapter-code/previous_chapters.py b/ch07/01_main-chapter-code/previous_chapters.py index 090eab5..0aadf9e 100644 --- a/ch07/01_main-chapter-code/previous_chapters.py +++ b/ch07/01_main-chapter-code/previous_chapters.py @@ -267,6 +267,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No if temperature > 0.0: logits = logits / temperature + # New (not in book): numerical stability tip to get equivalent results on mps device + # subtract rowwise max before softmax + logits = logits - logits.max(dim=-1, keepdim=True).values + # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len) diff --git a/ch07/04_preference-tuning-with-dpo/previous_chapters.py b/ch07/04_preference-tuning-with-dpo/previous_chapters.py index bd69339..829e92c 100644 --- a/ch07/04_preference-tuning-with-dpo/previous_chapters.py +++ b/ch07/04_preference-tuning-with-dpo/previous_chapters.py @@ -267,6 +267,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No if temperature > 0.0: logits = logits / temperature + # New (not in book): numerical stability tip to get equivalent results on mps device + # subtract rowwise max before softmax + logits = logits - logits.max(dim=-1, keepdim=True).values + # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)
diff --git a/pkg/llms_from_scratch/ch05.py b/pkg/llms_from_scratch/ch05.py index f0ef5d7..315e050 100644 --- a/pkg/llms_from_scratch/ch05.py +++ b/pkg/llms_from_scratch/ch05.py @@ -36,6 +36,10 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No if temperature > 0.0: logits = logits / temperature + # New (not in book): numerical stability tip to get equivalent results on mps device + # subtract rowwise max before softmax + logits = logits - logits.max(dim=-1, keepdim=True).values + # Apply softmax to get probabilities probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)