Fix Loss in Gutenberg bonus section (#109)

This commit is contained in:
Sebastian Raschka
2024-04-04 20:54:09 -05:00
committed by GitHub
parent c8cffefb6f
commit 25f533efe0
4 changed files with 58 additions and 16 deletions

View File

@@ -1081,10 +1081,8 @@
"source": [
"def calc_loss_batch(input_batch, target_batch, model, device):\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
"\n",
" logits = model(input_batch)\n",
" logits = logits.flatten(0, 1)\n",
" loss = torch.nn.functional.cross_entropy(logits, target_batch.flatten())\n",
" loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())\n",
" return loss\n",
"\n",
"\n",
@@ -2403,7 +2401,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,