mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Update generate script (#847)
* Custom python 3.13 entry in pyproject.toml * amend * Update generate script * update * Update pyproject.toml
This commit is contained in:
committed by
GitHub
parent
9bc827ea7e
commit
47867bc1cb
@@ -3,6 +3,7 @@
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
@@ -258,9 +259,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
return idx
|
||||
|
||||
|
||||
def main(gpt_config, input_prompt, model_size):
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
def main(gpt_config, input_prompt, model_size, device):
|
||||
|
||||
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
|
||||
|
||||
@@ -286,10 +285,30 @@ def main(gpt_config, input_prompt, model_size):
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser(description="Generate text with a pretrained GPT-2 model.")
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
default="Every effort moves you",
|
||||
help="Prompt text used to seed the generation (default matches the script's built-in prompt)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="cpu",
|
||||
help="Device for running inference, e.g., cpu, cuda, mps, or auto. Defaults to cpu."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
torch.manual_seed(123)
|
||||
|
||||
CHOOSE_MODEL = "gpt2-small (124M)"
|
||||
INPUT_PROMPT = "Every effort moves you"
|
||||
INPUT_PROMPT = args.prompt
|
||||
DEVICE = torch.device(args.device)
|
||||
|
||||
print("PyTorch:", torch.__version__)
|
||||
print("Device:", DEVICE)
|
||||
|
||||
|
||||
BASE_CONFIG = {
|
||||
"vocab_size": 50257, # Vocabulary size
|
||||
@@ -309,4 +328,4 @@ if __name__ == "__main__":
|
||||
|
||||
BASE_CONFIG.update(model_configs[CHOOSE_MODEL])
|
||||
|
||||
main(BASE_CONFIG, INPUT_PROMPT, model_size)
|
||||
main(BASE_CONFIG, INPUT_PROMPT, model_size, DEVICE)
|
||||
|
||||
Reference in New Issue
Block a user