Add flexible padding bonus experiment (#438)

* Add flexible padding bonus experiment

* fix links
Sebastian Raschka
2024-11-15 08:51:01 +09:00
committed by GitHub
parent f6281ab91b
commit ccade77bf4
3 changed files with 83 additions and 35 deletions


@@ -184,16 +184,34 @@ def calc_loss_batch(input_batch, target_batch, model, device,
                     trainable_token_pos=-1, ignore_index=-100, average_embeddings=False):
     input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-    model_output = model(input_batch)
-    if average_embeddings:
-        # Average over the sequence dimension (dim=1)
-        logits = model_output.mean(dim=1)
-    else:
-        # Select embeddings at the specified token position
-        logits = model_output[:, trainable_token_pos, :]
-
-    loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
-    return loss
+    if trainable_token_pos == "flexible":  # Selects the last tokens before the padding tokens
+        # From https://github.com/rasbt/LLMs-from-scratch/discussions/434
+        # Find the last non-padding token for each sequence in the batch
+        pad_token_id = 50256  # <|endoftext|> token used for padding
+        mask = input_batch != pad_token_id
+        last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
+
+        # Get model outputs
+        logits = model(input_batch)  # shape: [batch_size, seq_len, num_classes]
+
+        # Select the logits corresponding to the last real token of each sequence
+        batch_size = logits.size(0)
+        selected_logits = logits[torch.arange(batch_size), last_token_pos]
+
+        loss = torch.nn.functional.cross_entropy(selected_logits, target_batch)
+        return loss
+
+    else:
+        model_output = model(input_batch)
+        if average_embeddings:
+            # Average over the sequence dimension (dim=1)
+            logits = model_output.mean(dim=1)
+        else:
+            # Select embeddings at the specified token position
+            logits = model_output[:, trainable_token_pos, :]
+
+        loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
+        return loss


 def calc_loss_loader(data_loader, model, device,
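
Note: the position arithmetic in the new branch relies on right-padding, i.e. all <|endoftext|> tokens sitting at the end of each sequence. A minimal standalone sketch of what `mask.sum(dim=1) - 1` computes (the token IDs below are made up for illustration):

import torch

pad_token_id = 50256  # <|endoftext|> token used for padding

# Two right-padded sequences of length 5 (IDs are illustrative only)
input_batch = torch.tensor([
    [15496,   995,    11,   703,   389],   # no padding; last real token at index 4
    [15496,   995, 50256, 50256, 50256],   # padded; last real token at index 1
])

mask = input_batch != pad_token_id     # True for real tokens, False for padding
last_token_pos = mask.sum(dim=1) - 1   # number of real tokens per row, minus 1
print(last_token_pos)                  # tensor([4, 1])

Because the count of real tokens is used as the position, this only works if no padding token occurs in the middle of a sequence, which holds for the right-padded batches used in this chapter.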
@@ -231,24 +249,48 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None,
         num_batches = len(data_loader)
     else:
         num_batches = min(num_batches, len(data_loader))
-    for i, (input_batch, target_batch) in enumerate(data_loader):
-        if i < num_batches:
-            input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-
-            model_output = model(input_batch)
-            if average_embeddings:
-                # Average over the sequence dimension (dim=1)
-                logits = model_output.mean(dim=1)
-            else:
-                # Select embeddings at the specified token position
-                logits = model_output[:, trainable_token_pos, :]
-
-            predicted_labels = torch.argmax(logits, dim=-1)
-            num_examples += predicted_labels.shape[0]
-            correct_predictions += (predicted_labels == target_batch).sum().item()
-        else:
-            break
+    if trainable_token_pos == "flexible":
+        for i, (input_batch, target_batch) in enumerate(data_loader):
+            if i < num_batches:
+                input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+
+                # Find the last non-padding token for each sequence in the batch
+                pad_token_id = 50256  # <|endoftext|> token used for padding
+                mask = input_batch != pad_token_id
+                last_token_pos = mask.sum(dim=1) - 1  # Get position of last real token
+
+                with torch.no_grad():
+                    logits = model(input_batch)  # shape: [batch_size, seq_len, num_classes]
+                    # Select the logits corresponding to the last real token of each sequence
+                    batch_size = logits.size(0)
+                    selected_logits = logits[torch.arange(batch_size), last_token_pos]
+
+                predicted_labels = torch.argmax(selected_logits, dim=-1)
+                num_examples += predicted_labels.shape[0]
+                correct_predictions += (predicted_labels == target_batch).sum().item()
+            else:
+                break
+    else:
+        for i, (input_batch, target_batch) in enumerate(data_loader):
+            if i < num_batches:
+                input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+
+                model_output = model(input_batch)
+                if average_embeddings:
+                    # Average over the sequence dimension (dim=1)
+                    logits = model_output.mean(dim=1)
+                else:
+                    # Select embeddings at the specified token position
+                    logits = model_output[:, trainable_token_pos, :]
+
+                predicted_labels = torch.argmax(logits, dim=-1)
+                num_examples += predicted_labels.shape[0]
+                correct_predictions += (predicted_labels == target_batch).sum().item()
+            else:
+                break
     return correct_predictions / num_examples
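
The expression `logits[torch.arange(batch_size), last_token_pos]` uses advanced indexing to pick one logits row per sequence. A small sketch showing that it is equivalent to a `torch.gather`-based formulation, included only to illustrate the design choice (the shapes are made up):

import torch

batch_size, seq_len, num_classes = 2, 5, 3
logits = torch.randn(batch_size, seq_len, num_classes)
last_token_pos = torch.tensor([4, 1])

# Advanced indexing, as in the diff above: one [num_classes] row per sequence
selected = logits[torch.arange(batch_size), last_token_pos]

# Equivalent gather-based formulation
idx = last_token_pos.view(-1, 1, 1).expand(-1, 1, num_classes)
selected_gather = logits.gather(dim=1, index=idx).squeeze(1)

assert torch.equal(selected, selected_gather)  # same [batch_size, num_classes] result

The advanced-indexing form is shorter and avoids reshaping the index tensor, which is presumably why it was chosen here.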
@@ -386,7 +428,7 @@ if __name__ == "__main__":
         type=str,
         default="last",
         help=(
-            "Which token position to train. Options: 'first', 'last'."
+            "Which token position to train. Options: 'first', 'last', 'flexible'."
         )
     )
     parser.add_argument(
@@ -483,6 +525,10 @@ if __name__ == "__main__":
         args.trainable_token_pos = 0
     elif args.trainable_token_pos == "last":
         args.trainable_token_pos = -1
+    # The "flexible" setting selects the last tokens before the padding tokens
+    # See https://github.com/rasbt/LLMs-from-scratch/discussions/434
+    elif args.trainable_token_pos == "flexible":
+        args.trainable_token_pos = "flexible"
     else:
         raise ValueError("Invalid --trainable_token_pos argument")
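
With these changes, the new mode should be selectable from the command line roughly as follows (the script name is an assumption; the commit view above does not show the file name):

# hypothetical invocation; substitute the actual script path from the repo
python additional_experiments.py --trainable_token_pos flexible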