Alt weight loading code via PyTorch (#585)

* Alt weight loading code via PyTorch

* commit additional files
This commit is contained in:
Sebastian Raschka
2025-03-27 20:10:23 -05:00
committed by GitHub
parent ffd4035144
commit 3f93d73d6d
7 changed files with 535 additions and 18 deletions

@@ -2133,20 +2133,53 @@
"id": "127ddbdb-3878-4669-9a39-d231fbdfb834",
"metadata": {},
"source": [
"<span style=\"color:darkred\">\n",
" <ul>\n",
" <li>For an alternative way to load the weights from the Hugging Face Hub, see <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a></li>\n",
" <ul>\n",
" <li>This is useful if:</li>\n",
" <ul>\n",
" <li>the weights are temporarily unavailable</li>\n",
" <li>a company VPN permits downloads from the Hugging Face Hub but not from the OpenAI CDN, for example</li>\n",
" <li>you are having trouble with the TensorFlow installation (the original weights are stored in TensorFlow files)</li>\n",
" </ul>\n",
" </ul>\n",
" <li>The <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a> code notebooks are a replacement for the remainder of section 5.5</li>\n",
" </ul>\n",
"</span>\n",
"---\n",
"\n",
"---\n",
"\n",
"\n",
"⚠️ **Note:** Some users may encounter issues in this section due to TensorFlow compatibility problems, particularly on certain Windows systems. TensorFlow is required here only to load the original OpenAI GPT-2 weight files, which we then convert to PyTorch.\n",
"If you run into TensorFlow-related issues, you can use the alternative code below instead of the remaining code in this section.\n",
"This alternative relies on pre-converted PyTorch weights, created with the same conversion process described in the previous section. For details, see the [../02_alternative_weight_loading/weight-loading-pytorch.ipynb](../02_alternative_weight_loading/weight-loading-pytorch.ipynb) notebook.\n",
"\n",
"```python\n",
"import os\n",
"import urllib.request\n",
"\n",
"import torch\n",
"\n",
"file_name = \"gpt2-small-124M.pth\"\n",
"# file_name = \"gpt2-medium-355M.pth\"\n",
"# file_name = \"gpt2-large-774M.pth\"\n",
"# file_name = \"gpt2-xl-1558M.pth\"\n",
"\n",
"url = f\"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}\"\n",
"\n",
"if not os.path.exists(file_name):\n",
" urllib.request.urlretrieve(url, file_name)\n",
" print(f\"Downloaded to {file_name}\")\n",
"\n",
"gpt = GPTModel(BASE_CONFIG)\n",
"gpt.load_state_dict(torch.load(file_name, weights_only=True))\n",
"gpt.eval()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"gpt.to(device);\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"token_ids = generate(\n",
" model=gpt,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
" max_new_tokens=25,\n",
"    context_size=BASE_CONFIG[\"context_length\"],\n",
" top_k=50,\n",
" temperature=1.5\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))\n",
"```\n",
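"\n",
"As an optional sanity check (not part of the original snippet), you can verify that the checkpoint fully populated the model by counting its parameters; for the 124M model, this implementation should report 163,009,536 parameters, since the output layer is not weight-tied with the token embedding layer:\n",
"\n",
"```python\n",
"total_params = sum(p.numel() for p in gpt.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"```\n",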
"\n",
"---\n",
"\n",
"---"
]
},
{
@@ -2197,7 +2230,10 @@
"outputs": [],
"source": [
"# Relative import from gpt_download.py in this folder\n",
"from gpt_download import download_and_load_gpt2"
"\n",
"from gpt_download import download_and_load_gpt2\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch05 import download_and_load_gpt2"
]
},
{