simplify .view code

This commit is contained in:
rasbt
2024-03-25 08:09:31 -05:00
parent d4989e01c5
commit de576296de
4 changed files with 60 additions and 70 deletions

View File

@@ -250,10 +250,8 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)
logits = logits.view(-1, logits.size(-1))
loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
return loss