mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
simplify calc_loss_loader
This commit is contained in:
@@ -256,17 +256,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
|
||||
|
||||
|
||||
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||
total_loss, batches_seen = 0., 0.
|
||||
total_loss = 0.
|
||||
if num_batches is None:
|
||||
num_batches = len(data_loader)
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
||||
total_loss += loss.item()
|
||||
batches_seen += 1
|
||||
else:
|
||||
break
|
||||
return total_loss / batches_seen
|
||||
return total_loss / num_batches
|
||||
|
||||
|
||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||
|
||||
Reference in New Issue
Block a user