From b8e12e1dd1da48abac4cb2d2019237e4582a36b5 Mon Sep 17 00:00:00 2001
From: rasbt
Date: Thu, 9 Oct 2025 10:59:17 -0500
Subject: [PATCH] Use inference_device

---
 ch05/01_main-chapter-code/ch05.ipynb | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/ch05/01_main-chapter-code/ch05.ipynb b/ch05/01_main-chapter-code/ch05.ipynb
index 22809c7..7c6a56f 100644
--- a/ch05/01_main-chapter-code/ch05.ipynb
+++ b/ch05/01_main-chapter-code/ch05.ipynb
@@ -1519,14 +1519,19 @@
     }
    ],
    "source": [
-    "model.to(\"cpu\")\n",
+    "# NEW: use CPU here as inference is cheap with \n",
+    "# this model and to ensure readers get same results in the\n",
+    "# remaining sections of this book\n",
+    "inference_device = torch.device(\"cpu\")\n",
+    "\n",
+    "model.to(inference_device)\n",
     "model.eval()\n",
     "\n",
     "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
     "\n",
     "token_ids = generate_text_simple(\n",
     "    model=model,\n",
-    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
+    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(inference_device),\n",
     "    max_new_tokens=25,\n",
     "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
     ")\n",
@@ -2030,7 +2035,7 @@
     "\n",
     "token_ids = generate(\n",
     "    model=model,\n",
-    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
+    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(inference_device),\n",
     "    max_new_tokens=15,\n",
     "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
     "    top_k=25,\n",
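
For reference, here is how the first edited cell reads once the patch is
applied. This is a sketch assembled from the hunk above, not a full excerpt
of the notebook: model, text_to_token_ids, generate_text_simple, and
GPT_CONFIG_124M are assumed to be defined by earlier cells in ch05.ipynb,
with torch and tiktoken imported there as well.

    # NEW: use CPU here as inference is cheap with
    # this model and to ensure readers get same results in the
    # remaining sections of this book
    inference_device = torch.device("cpu")

    model.to(inference_device)
    model.eval()

    tokenizer = tiktoken.get_encoding("gpt2")

    token_ids = generate_text_simple(
        model=model,
        idx=text_to_token_ids("Every effort moves you", tokenizer).to(inference_device),
        max_new_tokens=25,
        context_size=GPT_CONFIG_124M["context_length"]
    )

Pinning inference_device to CPU trades a little speed for reproducibility:
inference with this 124M-parameter model is cheap, and CPU results match
across machines. A reader with a GPU could point inference_device at
torch.device("cuda") instead; the important part is that the input idx is
moved with .to(inference_device) so the token tensor lives on the same
device as the model.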