Make quote style consistent (#891)

This commit is contained in:
Sebastian Raschka
2025-10-21 19:42:33 -05:00
committed by GitHub
parent 9276edbc37
commit 7ca7c47e4a
24 changed files with 239 additions and 81 deletions

View File

@@ -446,7 +446,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--average_embeddings",
action='store_true',
action="store_true",
default=False,
help=(
"Average the output embeddings from all tokens instead of using"
@@ -480,7 +480,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--no_padding",
action='store_true',
action="store_true",
default=False,
help=(
"Disable padding, which means each example may have a different length."
@@ -517,7 +517,7 @@ if __name__ == "__main__":
)
parser.add_argument(
"--disable_causal_mask",
action='store_true',
action="store_true",
default=False,
help=(
"Disables the causal attention mask."

View File

@@ -74,7 +74,7 @@ class MultiHeadAttention(nn.Module):
self.dropout = nn.Dropout(dropout)
if not disable_causal_mask:
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
self.disable_causal_mask = disable_causal_mask
def forward(self, x):
@@ -255,8 +255,8 @@ def assign(left, right):
def load_weights_into_gpt(gpt, params):
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])
gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"])
gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"])
for b in range(len(params["blocks"])):
q_w, k_w, v_w = np.split(
@@ -328,7 +328,7 @@ def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=No
# Keep only top_k values
top_logits, _ = torch.topk(logits, top_k)
min_val = top_logits[:, -1]
logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)
logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)
# New: Apply temperature scaling
if temperature > 0.0: