mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
add ignore index experiment
This commit is contained in:
@@ -164,14 +164,16 @@ def instantiate_model(choose_model, load_weights):
|
||||
return model
|
||||
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
|
||||
def calc_loss_batch(input_batch, target_batch, model, device,
|
||||
trainable_token=-1, ignore_index=-100):
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
|
||||
return loss
|
||||
|
||||
|
||||
def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
|
||||
def calc_loss_loader(data_loader, model, device,
|
||||
num_batches=None, trainable_token=-1, ignore_index=-100):
|
||||
total_loss = 0.
|
||||
if len(data_loader) == 0:
|
||||
return float("nan")
|
||||
@@ -183,7 +185,10 @@ def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_tok
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token=trainable_token, ignore_index=ignore_index
|
||||
)
|
||||
total_loss += loss.item()
|
||||
else:
|
||||
break
|
||||
@@ -212,18 +217,25 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
|
||||
return correct_predictions / num_examples
|
||||
|
||||
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
|
||||
def evaluate_model(model, train_loader, val_loader, device,
|
||||
eval_iter, trainable_token=-1, ignore_index=-100):
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
|
||||
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
|
||||
train_loss = calc_loss_loader(
|
||||
train_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token=trainable_token, ignore_index=ignore_index
|
||||
)
|
||||
val_loss = calc_loss_loader(
|
||||
val_loader, model, device, num_batches=eval_iter,
|
||||
trainable_token=trainable_token, ignore_index=ignore_index
|
||||
)
|
||||
model.train()
|
||||
return train_loss, val_loss
|
||||
|
||||
|
||||
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1,
|
||||
accumulation_steps=1):
|
||||
accumulation_steps=1, ignore_index=-100):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, train_accs, val_accs = [], [], [], []
|
||||
examples_seen, global_step = 0, -1
|
||||
@@ -233,7 +245,10 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
model.train() # Set model to training mode
|
||||
|
||||
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
|
||||
loss = calc_loss_batch(
|
||||
input_batch, target_batch, model, device,
|
||||
trainable_token=trainable_token, ignore_index=ignore_index
|
||||
)
|
||||
|
||||
# Use gradient accumulation if accumulation_steps > 1
|
||||
# See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
|
||||
@@ -253,7 +268,9 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
# Optional evaluation step
|
||||
if global_step % eval_freq == 0:
|
||||
train_loss, val_loss = evaluate_model(
|
||||
model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
|
||||
model, train_loader, val_loader, device, eval_iter,
|
||||
trainable_token=trainable_token, ignore_index=ignore_index
|
||||
)
|
||||
train_losses.append(train_loss)
|
||||
val_losses.append(val_loss)
|
||||
print(f"Ep {epoch+1} (Step {global_step:06d}): "
|
||||
@@ -395,6 +412,15 @@ if __name__ == "__main__":
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ignore_index",
|
||||
type=int,
|
||||
default=-100,
|
||||
help=(
|
||||
"Sets the `ignore_index` in the cross entropy loss."
|
||||
)
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token == "first":
|
||||
|
||||
Reference in New Issue
Block a user