mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
make batch loss calculatution more efficient
This commit is contained in:
@@ -27,6 +27,8 @@ def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||
total_loss = 0.
|
||||
if num_batches is None:
|
||||
num_batches = len(data_loader)
|
||||
else:
|
||||
num_batches = min(num_batches, len(data_loader))
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
||||
|
||||
Reference in New Issue
Block a user