mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
simplify .view code
This commit is contained in:
@@ -250,10 +250,8 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of ``model`` on a single batch.

    Args:
        input_batch: Tensor of inputs passed to ``model``; moved to ``device``.
        target_batch: Tensor of target class indices; moved to ``device``.
            Its flattened length must equal the number of logit rows left
            after merging the first two logit dimensions.
        model: Callable that maps ``input_batch`` to logits. Presumably the
            logits are (batch_size, num_tokens, vocab_size) -- confirm
            against the model definition.
        device: Device (or device string) both batches are moved to.

    Returns:
        The tensor produced by ``torch.nn.functional.cross_entropy`` with its
        default arguments.
    """
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    logits = model(input_batch)
    # Flatten exactly once: merge the first two logit dimensions into rows of
    # class scores and turn the targets into a matching 1-D index vector.
    # The earlier `.view(-1, ...)` reshape is superseded by this call and
    # must not be applied in addition to it.
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
"id": "66dd524e-864c-4012-b0a2-ccfc56e80024"
|
||||
},
|
||||
"source": [
|
||||
"# Chapter 5: Pretraining on Unlabeled Data\n"
|
||||
"# Chapter 5: Pretraining on Unlabeled Data"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -260,7 +260,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -290,7 +290,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "e7b6ec51-6f8c-49bd-a349-95ba38b46fb6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@@ -345,7 +345,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"id": "34ebd76a-16ec-4c17-8958-8a135735cc1c",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -385,23 +385,22 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "5e777e53-0a1d-4929-bcd5-b0ea2469bac6",
|
||||
"execution_count": 7,
|
||||
"id": "c990ead6-53cd-49a7-a6d1-14d8c1518249",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Armed heNetflix pressuring empoweredfaith'"
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Targets batch 1: effort moves you\n",
|
||||
"Outputs batch 1: Armed heNetflix\n"
|
||||
]
|
||||
},
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(token_ids_to_text(token_ids.flatten(), tokenizer))"
|
||||
"print(f\"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}\")\n",
|
||||
"print(f\"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -426,12 +425,12 @@
|
||||
"id": "c7251bf5-a079-4782-901d-68c9225d3157",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- In the first input batch, the token probabilities corresponding to the target indices are as follows:"
|
||||
"- The token probabilities corresponding to the target indices are as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 19,
|
||||
"id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -445,48 +444,19 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n"
|
||||
"Batch 1: tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n",
|
||||
"Batch 2: tensor([3.9836e-05, 1.6783e-05, 4.7559e-06])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"batch_idx = 0\n",
|
||||
"target_probas_1 = probas[batch_idx, [0, 1, 2], targets[batch_idx]]\n",
|
||||
"print(target_probas_1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5305e2af-e973-4fc3-b717-a065a0f6ceec",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- And in the second input batch, the token probabilities corresponding to the target indices are as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "ea8ee24d-04f9-45d4-bef1-088c83cb8e0d",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "ea8ee24d-04f9-45d4-bef1-088c83cb8e0d",
|
||||
"outputId": "5a439322-6e4f-49e9-b462-6d28b92415ab"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor([3.9836e-05, 1.6783e-05, 4.7559e-06])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Batch 1:\", target_probas_1)\n",
|
||||
"\n",
|
||||
"batch_idx = 1\n",
|
||||
"target_probas_2 = probas[1, [0, 1, 2], targets[1]]\n",
|
||||
"print(target_probas_2)"
|
||||
"print(\"Batch 2:\", target_probas_2)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -500,7 +470,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 22,
|
||||
"id": "31402a67-a16e-4aeb-977e-70abb9c9949b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -534,7 +504,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 23,
|
||||
"id": "9b003797-161b-4d98-81dc-e68320e09fec",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@@ -573,7 +543,33 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- In deep learning, instead of maximizing the average log-probability, it's a standard convention to minimize the *negative* average log-probability value; in our case, instead of maximizing -10.7722 so that it approaches 0, in deep learning, we would minimize 10.7722 so that it approaches 0\n",
|
||||
"- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning\n",
|
||||
"- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "176ddf35-1c5f-4d7c-bf17-70f3e7069bd4",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"tensor(10.7722)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"neg_avg_log_probas = avg_log_probas * -1\n",
|
||||
"print(neg_avg_log_probas)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84eeb868-abd8-4028-82db-107546bf7c2c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- PyTorch already implements a `cross_entropy` function that carries out the previous steps"
|
||||
]
|
||||
},
|
||||
@@ -615,7 +611,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Logits have shape (batch_size, num_tokens, vocab_size\n",
|
||||
"# Logits have shape (batch_size, num_tokens, vocab_size)\n",
|
||||
"print(\"Logits shape:\", logits.shape)\n",
|
||||
"\n",
|
||||
"# Targets have shape (batch_size, num_tokens)\n",
|
||||
@@ -652,10 +648,10 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"logits_flat = logits.view(-1, logits.shape[-1])\n",
|
||||
"print(\"Flattened logits:\", logits_flat.shape)\n",
|
||||
"logits_flat = logits.flatten(0, 1)\n",
|
||||
"targets_flat = targets.flatten()\n",
|
||||
"\n",
|
||||
"targets_flat = targets.view(-1)\n",
|
||||
"print(\"Flattened logits:\", logits_flat.shape)\n",
|
||||
"print(\"Flattened targets:\", targets_flat.shape)"
|
||||
]
|
||||
},
|
||||
@@ -1114,8 +1110,8 @@
|
||||
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
|
||||
"\n",
|
||||
" logits = model(input_batch)\n",
|
||||
" logits = logits.view(-1, logits.size(-1))\n",
|
||||
" loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))\n",
|
||||
" logits = logits.flatten(0, 1)\n",
|
||||
" loss = torch.nn.functional.cross_entropy(logits, target_batch.flatten())\n",
|
||||
" return loss\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@@ -2496,7 +2492,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -25,10 +25,8 @@ def token_ids_to_text(token_ids, tokenizer):
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device):
    """Compute the cross-entropy loss for one (input, target) batch.

    Args:
        input_batch: Tensor fed to ``model`` after being moved to ``device``.
        target_batch: Tensor of target class indices, moved to ``device``
            and flattened to 1-D before the loss is computed.
        model: Callable returning logits for ``input_batch``. Assumed to
            return (batch_size, num_tokens, vocab_size) -- TODO confirm.
        device: Target device for both tensors.

    Returns:
        The result of ``torch.nn.functional.cross_entropy`` called with
        default arguments on the flattened logits and targets.
    """
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    logits = model(input_batch)
    # One reshape only: flatten(0, 1) folds the two leading dimensions
    # together and keeps the last (class) dimension, replacing the older
    # view-based reshape rather than stacking on top of it.
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss
|
||||
|
||||
|
||||
|
||||
@@ -244,10 +244,8 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
|
||||
|
||||
def calc_loss_batch(input_batch, target_batch, model, device):
    """Return the cross-entropy loss of ``model`` on one batch.

    Args:
        input_batch: Tensor of model inputs; moved to ``device`` first.
        target_batch: Tensor of target class indices; moved to ``device``
            and flattened to a 1-D vector.
        model: Callable that returns logits for ``input_batch``. Presumably
            shaped (batch_size, num_tokens, vocab_size) -- verify against
            the model.
        device: Device both tensors are moved to.

    Returns:
        The tensor returned by ``torch.nn.functional.cross_entropy`` with
        default arguments.
    """
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)

    logits = model(input_batch)
    # flatten(0, 1) merges only the two leading dimensions, so each row still
    # holds the scores for every class. flatten(0, -1) would also collapse
    # the class dimension into a 1-D tensor, which cross_entropy cannot match
    # against the flattened targets.
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user