mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Alt weight loading code via PyTorch (#585)
* Alt weight loading code via PyTorch * commit additional files
This commit is contained in:
committed by
GitHub
parent
e07a7abdd5
commit
e55e3e88e1
@@ -2133,20 +2133,53 @@
|
||||
"id": "127ddbdb-3878-4669-9a39-d231fbdfb834",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<span style=\"color:darkred\">\n",
|
||||
" <ul>\n",
|
||||
" <li>For an alternative way to load the weights from the Hugging Face Hub, see <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a></li>\n",
|
||||
" <ul>\n",
|
||||
" <li>This is useful if:</li>\n",
|
||||
" <ul>\n",
|
||||
" <li>the weights are temporarily unavailable</li>\n",
|
||||
" <li>a company VPN only permits downloads from the Hugging Face Hub but not from the OpenAI CDN, for example</li>\n",
|
||||
" <li>you are having trouble with the TensorFlow installation (the original weights are stored in TensorFlow files)</li>\n",
|
||||
" </ul>\n",
|
||||
" </ul>\n",
|
||||
" <li>The <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a> code notebooks are replacements for the remainder of this section 5.5</li>\n",
|
||||
" </ul>\n",
|
||||
"</span>\n"
|
||||
"---\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"⚠️ **Note: Some users may encounter issues in this section due to TensorFlow compatibility problems, particularly on certain Windows systems. TensorFlow is required here only to load the original OpenAI GPT-2 weight files, which we then convert to PyTorch.\n",
|
||||
"If you're running into TensorFlow-related issues, you can use the alternative code below instead of the remaining code in this section.\n",
|
||||
"This alternative is based on pre-converted PyTorch weights, created using the same conversion process described in the previous section. For details, refer to the notebook:\n",
|
||||
"[../02_alternative_weight_loading/weight-loading-pytorch.ipynb](../02_alternative_weight_loading/weight-loading-pytorch.ipynb).**\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"file_name = \"gpt2-small-124M.pth\"\n",
|
||||
"# file_name = \"gpt2-medium-355M.pth\"\n",
|
||||
"# file_name = \"gpt2-large-774M.pth\"\n",
|
||||
"# file_name = \"gpt2-xl-1558M.pth\"\n",
|
||||
"\n",
|
||||
"url = f\"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}\"\n",
|
||||
"\n",
|
||||
"if not os.path.exists(file_name):\n",
|
||||
" urllib.request.urlretrieve(url, file_name)\n",
|
||||
" print(f\"Downloaded to {file_name}\")\n",
|
||||
"\n",
|
||||
"gpt = GPTModel(BASE_CONFIG)\n",
|
||||
"gpt.load_state_dict(torch.load(file_name, weights_only=True))\n",
|
||||
"gpt.eval()\n",
|
||||
"\n",
|
||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"gpt.to(device);\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"token_ids = generate(\n",
|
||||
" model=gpt,\n",
|
||||
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
|
||||
" max_new_tokens=25,\n",
|
||||
" context_size=NEW_CONFIG[\"context_length\"],\n",
|
||||
" top_k=50,\n",
|
||||
" temperature=1.5\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"---\n",
|
||||
"\n",
|
||||
"---"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -2197,7 +2230,10 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Relative import from the gpt_download.py contained in this folder\n",
|
||||
"from gpt_download import download_and_load_gpt2"
|
||||
"\n",
|
||||
"from gpt_download import download_and_load_gpt2\n",
|
||||
"# Alternatively:\n",
|
||||
"# from llms_from_scratch.ch05 import download_and_load_gpt2"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user