simplify .view code

This commit is contained in:
rasbt
2024-03-25 08:09:31 -05:00
parent d4989e01c5
commit de576296de
4 changed files with 60 additions and 70 deletions

View File

@@ -250,10 +250,8 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)
logits = logits.view(-1, logits.size(-1))
loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
return loss

View File

@@ -18,7 +18,7 @@
"id": "66dd524e-864c-4012-b0a2-ccfc56e80024"
},
"source": [
"# Chapter 5: Pretraining on Unlabeled Data\n"
"# Chapter 5: Pretraining on Unlabeled Data"
]
},
{
@@ -260,7 +260,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f",
"metadata": {
"colab": {
@@ -290,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "e7b6ec51-6f8c-49bd-a349-95ba38b46fb6",
"metadata": {},
"outputs": [
@@ -345,7 +345,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "34ebd76a-16ec-4c17-8958-8a135735cc1c",
"metadata": {
"colab": {
@@ -385,23 +385,22 @@
},
{
"cell_type": "code",
"execution_count": 8,
"id": "5e777e53-0a1d-4929-bcd5-b0ea2469bac6",
"execution_count": 7,
"id": "c990ead6-53cd-49a7-a6d1-14d8c1518249",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' Armed heNetflix pressuring empoweredfaith'"
"name": "stdout",
"output_type": "stream",
"text": [
"Targets batch 1: effort moves you\n",
"Outputs batch 1: Armed heNetflix\n"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(token_ids_to_text(token_ids.flatten(), tokenizer))"
"print(f\"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}\")\n",
"print(f\"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}\")"
]
},
{
@@ -426,12 +425,12 @@
"id": "c7251bf5-a079-4782-901d-68c9225d3157",
"metadata": {},
"source": [
"- In the first input batch, the token probabilities corresponding to the target indices are as follows:"
"- The token probabilities corresponding to the target indices are as follows:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 19,
"id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a",
"metadata": {
"colab": {
@@ -445,48 +444,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n"
"Batch 1: tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n",
"Batch 2: tensor([3.9836e-05, 1.6783e-05, 4.7559e-06])\n"
]
}
],
"source": [
"batch_idx = 0\n",
"target_probas_1 = probas[batch_idx, [0, 1, 2], targets[batch_idx]]\n",
"print(target_probas_1)"
]
},
{
"cell_type": "markdown",
"id": "5305e2af-e973-4fc3-b717-a065a0f6ceec",
"metadata": {},
"source": [
"- And in the second input batch, the token probabilities corresponding to the target indices are as follows:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ea8ee24d-04f9-45d4-bef1-088c83cb8e0d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ea8ee24d-04f9-45d4-bef1-088c83cb8e0d",
"outputId": "5a439322-6e4f-49e9-b462-6d28b92415ab"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([3.9836e-05, 1.6783e-05, 4.7559e-06])\n"
]
}
],
"source": [
"print(\"Batch 1:\", target_probas_1)\n",
"\n",
"batch_idx = 1\n",
"target_probas_2 = probas[1, [0, 1, 2], targets[1]]\n",
"print(target_probas_2)"
"print(\"Batch 2:\", target_probas_2)"
]
},
{
@@ -500,7 +470,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 22,
"id": "31402a67-a16e-4aeb-977e-70abb9c9949b",
"metadata": {
"colab": {
@@ -534,7 +504,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 23,
"id": "9b003797-161b-4d98-81dc-e68320e09fec",
"metadata": {
"colab": {
@@ -573,7 +543,33 @@
"metadata": {},
"source": [
"- In deep learning, instead of maximizing the average log-probability, it's a standard convention to minimize the *negative* average log-probability value; in our case, instead of maximizing -10.7722 so that it approaches 0, in deep learning, we would minimize 10.7722 so that it approaches 0\n",
"- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning\n",
"- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "176ddf35-1c5f-4d7c-bf17-70f3e7069bd4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(10.7722)\n"
]
}
],
"source": [
"neg_avg_log_probas = avg_log_probas * -1\n",
"print(neg_avg_log_probas)"
]
},
{
"cell_type": "markdown",
"id": "84eeb868-abd8-4028-82db-107546bf7c2c",
"metadata": {},
"source": [
"- PyTorch already implements a `cross_entropy` function that carries out the previous steps"
]
},
@@ -615,7 +611,7 @@
}
],
"source": [
"# Logits have shape (batch_size, num_tokens, vocab_size\n",
"# Logits have shape (batch_size, num_tokens, vocab_size)\n",
"print(\"Logits shape:\", logits.shape)\n",
"\n",
"# Targets have shape (batch_size, num_tokens)\n",
@@ -652,10 +648,10 @@
}
],
"source": [
"logits_flat = logits.view(-1, logits.shape[-1])\n",
"print(\"Flattened logits:\", logits_flat.shape)\n",
"logits_flat = logits.flatten(0, 1)\n",
"targets_flat = targets.flatten()\n",
"\n",
"targets_flat = targets.view(-1)\n",
"print(\"Flattened logits:\", logits_flat.shape)\n",
"print(\"Flattened targets:\", targets_flat.shape)"
]
},
@@ -1114,8 +1110,8 @@
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
"\n",
" logits = model(input_batch)\n",
" logits = logits.view(-1, logits.size(-1))\n",
" loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))\n",
" logits = logits.flatten(0, 1)\n",
" loss = torch.nn.functional.cross_entropy(logits, target_batch.flatten())\n",
" return loss\n",
"\n",
"\n",
@@ -2496,7 +2492,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -25,10 +25,8 @@ def token_ids_to_text(token_ids, tokenizer):
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)
logits = logits.view(-1, logits.size(-1))
loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
return loss

View File

@@ -244,10 +244,8 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)
logits = logits.view(-1, logits.size(-1))
loss = torch.nn.functional.cross_entropy(logits, target_batch.view(-1))
loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
return loss