mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Alternative weight loading via .safetensors (#507)
This commit is contained in:
committed by
GitHub
parent
9daa7e7511
commit
25ea71e713
@@ -2103,7 +2103,20 @@
|
||||
"id": "127ddbdb-3878-4669-9a39-d231fbdfb834",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- For an alternative way to load the weights from the Hugging Face Hub, see [../02_alternative_weight_loading](../02_alternative_weight_loading)"
|
||||
"<span style=\"color:darkred\">\n",
|
||||
" <ul>\n",
|
||||
" <li>For an alternative way to load the weights from the Hugging Face Hub, see <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a></li>\n",
|
||||
" <ul>\n",
|
||||
" <li>This is useful if:</li>\n",
|
||||
" <ul>\n",
|
||||
" <li>the weights are temporarily unavailable</li>\n",
|
||||
" <li>a company VPN only permits downloads from the Hugging Face Hub but not from the OpenAI CDN, for example</li>\n",
|
||||
" <li>you are having trouble with the TensorFlow installation (the original weights are stored in TensorFlow files)</li>\n",
|
||||
" </ul>\n",
|
||||
" </ul>\n",
|
||||
" <li>The <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a> code notebooks are replacements for the remainder of this section 5.5</li>\n",
|
||||
" </ul>\n",
|
||||
"</span>\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -2505,7 +2518,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -155,8 +155,8 @@ def assign(left, right):
|
||||
|
||||
|
||||
def load_weights_into_gpt(gpt, params):
|
||||
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
|
||||
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
|
||||
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
|
||||
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])
|
||||
|
||||
for b in range(len(params["blocks"])):
|
||||
q_w, k_w, v_w = np.split(
|
||||
@@ -229,7 +229,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
|
||||
# Keep only top_k values
|
||||
top_logits, _ = torch.topk(logits, top_k)
|
||||
min_val = top_logits[:, -1]
|
||||
logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
|
||||
logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
|
||||
|
||||
# New: Apply temperature scaling
|
||||
if temperature > 0.0:
|
||||
|
||||
Reference in New Issue
Block a user