Alternative weight loading via .safetensors (#507)

This commit is contained in:
Sebastian Raschka
2025-01-29 08:15:29 -06:00
committed by GitHub
parent 9daa7e7511
commit 25ea71e713
6 changed files with 336 additions and 6 deletions

View File

@@ -155,8 +155,8 @@ def assign(left, right):
def load_weights_into_gpt(gpt, params):
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])
for b in range(len(params["blocks"])):
q_w, k_w, v_w = np.split(
@@ -229,7 +229,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
# Keep only top_k values
top_logits, _ = torch.topk(logits, top_k)
min_val = top_logits[:, -1]
logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
# New: Apply temperature scaling
if temperature > 0.0: