diff --git a/ch05/01_main-chapter-code/gpt_generate.py b/ch05/01_main-chapter-code/gpt_generate.py index 337059b..b68d170 100644 --- a/ch05/01_main-chapter-code/gpt_generate.py +++ b/ch05/01_main-chapter-code/gpt_generate.py @@ -3,6 +3,7 @@ # - https://www.manning.com/books/build-a-large-language-model-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch +import argparse import json import numpy as np import os @@ -258,9 +259,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No return idx -def main(gpt_config, input_prompt, model_size): - - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def main(gpt_config, input_prompt, model_size, device): settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2") @@ -286,10 +285,30 @@ def main(gpt_config, input_prompt, model_size): if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate text with a pretrained GPT-2 model.") + parser.add_argument( + "--prompt", + default="Every effort moves you", + help="Prompt text used to seed the generation (default matches the script's built-in prompt)." + ) + parser.add_argument( + "--device", + default="cpu", + help="Device for running inference, e.g., cpu, cuda, mps, or auto (cuda if available, otherwise cpu). Defaults to cpu." + ) + + args = parser.parse_args() + + torch.manual_seed(123) CHOOSE_MODEL = "gpt2-small (124M)" - INPUT_PROMPT = "Every effort moves you" + INPUT_PROMPT = args.prompt + DEVICE = torch.device(("cuda" if torch.cuda.is_available() else "cpu") if args.device == "auto" else args.device)  # "auto" is not a torch device string; resolve it here + + print("PyTorch:", torch.__version__) + print("Device:", DEVICE) + BASE_CONFIG = { "vocab_size": 50257, # Vocabulary size @@ -309,4 +328,4 @@ if __name__ == "__main__": BASE_CONFIG.update(model_configs[CHOOSE_MODEL]) - main(BASE_CONFIG, INPUT_PROMPT, model_size) + main(BASE_CONFIG, INPUT_PROMPT, model_size, DEVICE)