Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2026-04-10 12:33:42 +00:00)
Add flexible padding bonus experiment (#438)
* Add flexible padding bonus experiment
* fix links
Commit ccade77bf4 (parent f6281ab91b), committed by GitHub
@@ -184,16 +184,34 @@ def calc_loss_batch(input_batch, target_batch, model, device,
                     trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
     input_batch, target_batch = input_batch.to(device), target_batch.to(device)
 
-    model_output = model(input_batch)
-    if average_embeddings:
-        # Average over the sequence dimension (dim=1)
-        logits = model_output.mean(dim=1)
-    else:
-        # Select embeddings at the specified token position
-        logits = model_output[:, trainable_token_pos, :]
+    if trainable_token_pos == "flexible":  # Selects the last tokens before the padding tokens
+        # From https://github.com/rasbt/LLMs-from-scratch/discussions/434
+        # Find the last non-padding token for each sequence in the batch
+        pad_token_id = 50256  # <|endoftext|> token used for padding
+        mask = input_batch != pad_token_id
+        last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
 
-    loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
-    return loss
+        # Get model outputs
+        logits = model(input_batch)  # shape: [batch_size, seq_len, num_classes]
+
+        # Select the logits corresponding to the last real token of each sequence
+        batch_size = logits.size(0)
+        selected_logits = logits[torch.arange(batch_size), last_token_pos]
+
+        loss = torch.nn.functional.cross_entropy(selected_logits, target_batch)
+        return loss
+
+    else:
+        model_output = model(input_batch)
+        if average_embeddings:
+            # Average over the sequence dimension (dim=1)
+            logits = model_output.mean(dim=1)
+        else:
+            # Select embeddings at the specified token position
+            logits = model_output[:, trainable_token_pos, :]
+
+        loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
+        return loss
 
 
 def calc_loss_loader(data_loader, model, device,
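For readers skimming the diff, here is a minimal, self-contained sketch (not repository code) of the selection trick the new "flexible" branch of calc_loss_batch uses: count the non-padding tokens per sequence to locate each sequence's last real token, then gather that token's logits with advanced indexing. The tiny random model_output tensor stands in for the GPT model's per-token outputs and is purely an assumption for illustration.

```python
import torch

torch.manual_seed(123)

pad_token_id = 50256  # <|endoftext|> token used for padding
num_classes = 2

# Toy right-padded batch of token IDs; the second sequence has 2 padding tokens
input_batch = torch.tensor([
    [15496, 995, 318, 1049],     # 4 real tokens
    [15496, 995, 50256, 50256],  # 2 real tokens + 2 padding tokens
])
batch_size, seq_len = input_batch.shape

# Stand-in for model(input_batch): per-token logits of shape [batch_size, seq_len, num_classes]
model_output = torch.randn(batch_size, seq_len, num_classes)

# Locate the last real (non-padding) token in each sequence
mask = input_batch != pad_token_id
last_token_pos = mask.sum(dim=1) - 1  # tensor([3, 1])

# One row index and one column index per sequence -> shape [batch_size, num_classes]
selected_logits = model_output[torch.arange(batch_size), last_token_pos]

target_batch = torch.tensor([0, 1])
loss = torch.nn.functional.cross_entropy(selected_logits, target_batch)
print(selected_logits.shape, loss.item())
```

Note that counting non-padding tokens gives the last real token's index only when all padding sits at the end of a sequence (right padding), which is the layout this experiment assumes.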
@@ -231,24 +249,48 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))
-    for i, (input_batch, target_batch) in enumerate(data_loader):
-        if i < num_batches:
-            input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-
-            model_output = model(input_batch)
-            if average_embeddings:
-                # Average over the sequence dimension (dim=1)
-                logits = model_output.mean(dim=1)
-            else:
-                # Select embeddings at the specified token position
-                logits = model_output[:, trainable_token_pos, :]
-
-            predicted_labels = torch.argmax(logits, dim=-1)
-
-            num_examples += predicted_labels.shape[0]
-            correct_predictions += (predicted_labels == target_batch).sum().item()
-        else:
-            break
+    if trainable_token_pos == "flexible":
+        for i, (input_batch, target_batch) in enumerate(data_loader):
+            if i < num_batches:
+                input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+
+                # Find the last non-padding token for each sequence in the batch
+                pad_token_id = 50256  # <|endoftext|> token used for padding
+                mask = input_batch != pad_token_id
+                last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
+
+                with torch.no_grad():
+                    logits = model(input_batch)  # Logits of last output token
+                # Select the logits corresponding to the last real token of each sequence
+                batch_size = logits.size(0)
+                selected_logits = logits[torch.arange(batch_size), last_token_pos]
+                predicted_labels = torch.argmax(selected_logits, dim=-1)
+
+                num_examples += predicted_labels.shape[0]
+                correct_predictions += (predicted_labels == target_batch).sum().item()
+            else:
+                break
+
+    else:
+        for i, (input_batch, target_batch) in enumerate(data_loader):
+            if i < num_batches:
+                input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+
+                model_output = model(input_batch)
+                if average_embeddings:
+                    # Average over the sequence dimension (dim=1)
+                    logits = model_output.mean(dim=1)
+                else:
+                    # Select embeddings at the specified token position
+                    logits = model_output[:, trainable_token_pos, :]
+
+                predicted_labels = torch.argmax(logits, dim=-1)
+
+                num_examples += predicted_labels.shape[0]
+                correct_predictions += (predicted_labels == target_batch).sum().item()
+            else:
+                break
     return correct_predictions / num_examples
 
 
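To see why this branch can change evaluation behavior: with trainable_token_pos=-1, accuracy is computed from the logits at the final sequence position, which for shorter, right-padded sequences sits on top of a padding token, whereas the flexible variant reads the logits at each sequence's last real token. A quick standalone comparison with toy tensors (not repository code):

```python
import torch

pad_token_id = 50256
input_batch = torch.tensor([[1, 2, 3, 4],
                            [5, 6, pad_token_id, pad_token_id]])

# Pretend per-token logits for a 2-class head: shape [batch_size, seq_len, 2]
logits = torch.arange(16, dtype=torch.float32).reshape(2, 4, 2)

# "last" (-1): always the final position, i.e. a padding position for sequence 2
last_logits = logits[:, -1, :]

# "flexible": the last non-padding position of each sequence
last_token_pos = (input_batch != pad_token_id).sum(dim=1) - 1  # tensor([3, 1])
flexible_logits = logits[torch.arange(2), last_token_pos]

print(last_logits)      # rows taken at sequence positions [3, 3]
print(flexible_logits)  # rows taken at sequence positions [3, 1]
```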
@@ -386,7 +428,7 @@ if __name__ == "__main__":
         type=str,
         default="last",
         help=(
-            "Which token position to train. Options: 'first', 'last'."
+            "Which token position to train. Options: 'first', 'last', 'flexible'."
         )
     )
     parser.add_argument(
@@ -483,6 +525,10 @@ if __name__ == "__main__":
         args.trainable_token_pos = 0
     elif args.trainable_token_pos == "last":
         args.trainable_token_pos = -1
+    # The "flexible" setting selects the last tokens before the padding tokens
+    # See https://github.com/rasbt/LLMs-from-scratch/discussions/434
+    elif args.trainable_token_pos == "flexible":
+        args.trainable_token_pos = "flexible"
     else:
         raise ValueError("Invalid --trainable_token_pos argument")
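Finally, a compact sketch of how the new CLI value travels from argparse into the sentinel that the loss and accuracy helpers branch on. The wrapper function parse_trainable_token_pos is hypothetical; in the actual script this logic sits inline under if __name__ == "__main__":, and the script's file name is not visible in this diff.

```python
import argparse

def parse_trainable_token_pos(argv):
    # Hypothetical wrapper around the argument handling shown in the diff above
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trainable_token_pos",
        type=str,
        default="last",
        help="Which token position to train. Options: 'first', 'last', 'flexible'.",
    )
    args = parser.parse_args(argv)

    if args.trainable_token_pos == "first":
        return 0
    elif args.trainable_token_pos == "last":
        return -1
    # "flexible" stays a string sentinel; calc_loss_batch and calc_accuracy_loader
    # check for it and select the last non-padding token per sequence
    elif args.trainable_token_pos == "flexible":
        return "flexible"
    else:
        raise ValueError("Invalid --trainable_token_pos argument")

print(parse_trainable_token_pos([]))                                     # -1
print(parse_trainable_token_pos(["--trainable_token_pos", "flexible"]))  # flexible
```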