Return nan if val loader is empty (#124)

This commit is contained in:
Sebastian Raschka
2024-04-20 08:02:30 -05:00
committed by GitHub
parent b5878a80ff
commit 4557d5830e
5 changed files with 15 additions and 5 deletions

View File

@@ -255,7 +255,9 @@ def calc_loss_batch(input_batch, target_batch, model, device):
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))