mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
ouput -> output
This commit is contained in:
@@ -466,7 +466,7 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||
if i < num_batches:
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, -1, :] # Logits of last ouput token
|
||||
logits = model(input_batch)[:, -1, :] # Logits of last output token
|
||||
predicted_labels = torch.argmax(logits, dim=-1)
|
||||
|
||||
num_examples += predicted_labels.shape[0]
|
||||
@@ -478,7 +478,7 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device):
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
logits = model(input_batch)[:, -1, :] # Logits of last ouput token
|
||||
logits = model(input_batch)[:, -1, :] # Logits of last output token
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch)
|
||||
return loss
|
||||
|
||||
|
||||
Reference in New Issue
Block a user