add ignore index experiment

This commit is contained in:
rasbt
2024-05-19 07:24:49 -05:00
parent 5541f7c8fe
commit faffebae4b
2 changed files with 40 additions and 19 deletions

View File

@@ -164,14 +164,16 @@ def instantiate_model(choose_model, load_weights):
return model
def calc_loss_batch(input_batch, target_batch, model, device,
                    trainable_token=-1, ignore_index=-100):
    """Compute the cross-entropy loss for a single batch.

    Both batches are moved to ``device``, the model is run on the inputs,
    and the logits at sequence position ``trainable_token`` are compared
    against the targets.

    Args:
        input_batch: Tensor of model inputs for one batch.
        target_batch: Tensor of class labels, one per example.
        model: Callable returning a 3-D tensor; the output is indexed as
            ``[:, trainable_token, :]``, i.e. (batch, position, classes).
        device: Device the batches are moved to before the forward pass.
        trainable_token: Index along the position axis whose logits are
            used for classification. The default ``-1`` selects the last
            output token.
        ignore_index: Target value that is excluded from the loss; it is
            forwarded unchanged to ``torch.nn.functional.cross_entropy``.

    Returns:
        The scalar loss tensor produced by ``cross_entropy``.
    """
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    # Keep only the logits at the selected token position.
    logits = model(input_batch)[:, trainable_token, :]
    loss = torch.nn.functional.cross_entropy(
        logits, target_batch, ignore_index=ignore_index
    )
    return loss
def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
def calc_loss_loader(data_loader, model, device,
num_batches=None, trainable_token=-1, ignore_index=-100):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
@@ -183,7 +185,10 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token=trainable_token, ignore_index=ignore_index
)
total_loss += loss.item()
else:
break
@@ -212,18 +217,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
return correct_predictions / num_examples
def evaluate_model(model, train_loader, val_loader, device,
                   eval_iter, trainable_token=-1, ignore_index=-100):
    """Compute training and validation loss without tracking gradients.

    The model is put into evaluation mode for the duration of the loss
    computation and switched back to training mode afterwards.

    Args:
        model: Model to evaluate.
        train_loader: Data loader for the training set.
        val_loader: Data loader for the validation set.
        device: Device forwarded to ``calc_loss_loader``.
        eval_iter: Passed to ``calc_loss_loader`` as ``num_batches`` for
            both loaders.
        trainable_token: Token position forwarded to ``calc_loss_loader``.
        ignore_index: Target value to exclude from the loss, forwarded to
            ``calc_loss_loader``.

    Returns:
        A ``(train_loss, val_loss)`` tuple.
    """
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(
            train_loader, model, device, num_batches=eval_iter,
            trainable_token=trainable_token, ignore_index=ignore_index
        )
        val_loss = calc_loss_loader(
            val_loader, model, device, num_batches=eval_iter,
            trainable_token=trainable_token, ignore_index=ignore_index
        )
    # Always return to training mode, since this is called mid-training.
    model.train()
    return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1,
accumulation_steps=1):
accumulation_steps=1, ignore_index=-100):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
@@ -233,7 +245,10 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
model.train() # Set model to training mode
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
loss = calc_loss_batch(
input_batch, target_batch, model, device,
trainable_token=trainable_token, ignore_index=ignore_index
)
# Use gradient accumulation if accumulation_steps > 1
# See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
@@ -253,7 +268,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
model, train_loader, val_loader, device, eval_iter,
trainable_token=trainable_token, ignore_index=ignore_index
)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
@@ -395,6 +412,15 @@ if __name__ == "__main__":
)
)
parser.add_argument(
"--ignore_index",
type=int,
default=-100,
help=(
"Sets the `ignore_index` in the cross entropy loss."
)
)
args = parser.parse_args()
if args.trainable_token == "first":