simplify calc_loss_loader

This commit is contained in:
rasbt
2024-03-26 20:34:50 -05:00
parent c88e8edf72
commit 3cb5a52a1b
5 changed files with 33 additions and 71 deletions

View File

@@ -256,17 +256,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss, batches_seen = 0., 0.
total_loss = 0.
if num_batches is None:
num_batches = len(data_loader)
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device)
total_loss += loss.item()
batches_seen += 1
else:
break
return total_loss / batches_seen
return total_loss / num_batches
def evaluate_model(model, train_loader, val_loader, device, eval_iter):