ouput -> output

This commit is contained in:
rasbt
2024-05-05 12:21:10 -05:00
parent d361cef65f
commit 9457676640
4 changed files with 8 additions and 8 deletions

View File

@@ -466,7 +466,7 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last ouput token
logits = model(input_batch)[:, -1, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
@@ -478,7 +478,7 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None):
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, -1, :] # Logits of last ouput token
logits = model(input_batch)[:, -1, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss