mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
simplify .view code
This commit is contained in:
@@ -250,10 +250,8 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device):
|
||||
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
|
||||
|
||||
logits = model(input_batch)
|
||||
logits = logits.view(-1, logits.size(-1))
|
||||
loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
|
||||
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
|
||||
return loss
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user