From 47867bc1cb60e714f892cff40442e01efa53a322 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Sat, 27 Sep 2025 08:03:54 -0500 Subject: [PATCH] Update generate script (#847) * Custom python 3.13 entry in pyproject.toml * amend * Update generate script * update * Update pyproject.toml --- ch05/01_main-chapter-code/gpt_generate.py | 29 +++++++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/ch05/01_main-chapter-code/gpt_generate.py b/ch05/01_main-chapter-code/gpt_generate.py index 337059b..b68d170 100644 --- a/ch05/01_main-chapter-code/gpt_generate.py +++ b/ch05/01_main-chapter-code/gpt_generate.py @@ -3,6 +3,7 @@ # - https://www.manning.com/books/build-a-large-language-model-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch +import argparse import json import numpy as np import os @@ -258,9 +259,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No return idx -def main(gpt_config, input_prompt, model_size): - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def main(gpt_config, input_prompt, model_size, device): settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2") @@ -286,10 +285,30 @@ def main(gpt_config, input_prompt, model_size): if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate text with a pretrained GPT-2 model.") + parser.add_argument( + "--prompt", + default="Every effort moves you", + help="Prompt text used to seed the generation (default matches the script's built-in prompt)." + ) + parser.add_argument( + "--device", + default="cpu", + help="Device for running inference, e.g., cpu, cuda, mps, or auto. Defaults to cpu." + ) + + args = parser.parse_args() + + torch.manual_seed(123) CHOOSE_MODEL = "gpt2-small (124M)" - INPUT_PROMPT = "Every effort moves you" + INPUT_PROMPT = args.prompt + DEVICE = torch.device(args.device) + + print("PyTorch:", torch.__version__) + print("Device:", DEVICE) + BASE_CONFIG = { "vocab_size": 50257, # Vocabulary size @@ -309,4 +328,4 @@ if __name__ == "__main__": BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) - main(BASE_CONFIG, INPUT_PROMPT, model_size) + main(BASE_CONFIG, INPUT_PROMPT, model_size, DEVICE)