make batch loss calculatution more efficient

This commit is contained in:
rasbt
2024-03-27 07:11:56 -05:00
parent 3cb5a52a1b
commit 88b2dd780a
5 changed files with 79 additions and 73 deletions

View File

@@ -259,6 +259,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device)