mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
simplify calc_loss_loader
This commit is contained in:
@@ -256,17 +256,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
|
|||||||
|
|
||||||
|
|
||||||
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||||
total_loss, batches_seen = 0., 0.
|
total_loss = 0.
|
||||||
if num_batches is None:
|
if num_batches is None:
|
||||||
num_batches = len(data_loader)
|
num_batches = len(data_loader)
|
||||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||||
if i < num_batches:
|
if i < num_batches:
|
||||||
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
batches_seen += 1
|
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
return total_loss / batches_seen
|
return total_loss / num_batches
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||||
|
|||||||
@@ -764,7 +764,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 23,
|
||||||
"id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff",
|
"id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -785,43 +785,6 @@
|
|||||||
" text_data = file.read()"
|
" text_data = file.read()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 18,
|
|
||||||
"id": "0959c855-f860-4358-8b98-bc654f047578",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"from previous_chapters import create_dataloader_v1\n",
|
|
||||||
"\n",
|
|
||||||
"# Train/validation ratio\n",
|
|
||||||
"train_ratio = 0.90\n",
|
|
||||||
"split_idx = int(train_ratio * len(text_data))\n",
|
|
||||||
"train_data = text_data[:split_idx]\n",
|
|
||||||
"val_data = text_data[split_idx:]\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"torch.manual_seed(123)\n",
|
|
||||||
"\n",
|
|
||||||
"train_loader = create_dataloader_v1(\n",
|
|
||||||
" train_data,\n",
|
|
||||||
" batch_size=2,\n",
|
|
||||||
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
||||||
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
||||||
" drop_last=True,\n",
|
|
||||||
" shuffle=True\n",
|
|
||||||
")\n",
|
|
||||||
"\n",
|
|
||||||
"val_loader = create_dataloader_v1(\n",
|
|
||||||
" val_data,\n",
|
|
||||||
" batch_size=2,\n",
|
|
||||||
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
||||||
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
|
||||||
" drop_last=False,\n",
|
|
||||||
" shuffle=False\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"id": "379330f1-80f4-4e34-8724-41d892b04cee",
|
"id": "379330f1-80f4-4e34-8724-41d892b04cee",
|
||||||
@@ -832,7 +795,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 24,
|
||||||
"id": "6kgJbe4ehI4q",
|
"id": "6kgJbe4ehI4q",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@@ -858,7 +821,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 25,
|
||||||
"id": "j2XPde_ThM_e",
|
"id": "j2XPde_ThM_e",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@@ -884,7 +847,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 27,
|
||||||
"id": "6b46a952-d50a-4837-af09-4095698f7fd1",
|
"id": "6b46a952-d50a-4837-af09-4095698f7fd1",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@@ -940,22 +903,24 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 22,
|
"execution_count": 35,
|
||||||
"id": "fd0e963b-d282-4c97-b004-8772f4b1bd8f",
|
"id": "0959c855-f860-4358-8b98-bc654f047578",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from previous_chapters import create_dataloader_v1\n",
|
"from previous_chapters import create_dataloader_v1\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
|
||||||
"# Train/validation ratio\n",
|
"# Train/validation ratio\n",
|
||||||
"train_ratio = 0.90\n",
|
"train_ratio = 0.90\n",
|
||||||
"split_idx = int(train_ratio * len(text_data))\n",
|
"split_idx = int(train_ratio * len(text_data))\n",
|
||||||
|
"train_data = text_data[:split_idx]\n",
|
||||||
|
"val_data = text_data[split_idx:]\n",
|
||||||
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"torch.manual_seed(123)\n",
|
"torch.manual_seed(123)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"train_loader = create_dataloader_v1(\n",
|
"train_loader = create_dataloader_v1(\n",
|
||||||
" text_data[:split_idx],\n",
|
" train_data,\n",
|
||||||
" batch_size=2,\n",
|
" batch_size=2,\n",
|
||||||
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
||||||
@@ -964,7 +929,7 @@
|
|||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"val_loader = create_dataloader_v1(\n",
|
"val_loader = create_dataloader_v1(\n",
|
||||||
" text_data[split_idx:],\n",
|
" val_data,\n",
|
||||||
" batch_size=2,\n",
|
" batch_size=2,\n",
|
||||||
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
" max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
||||||
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
" stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
|
||||||
@@ -975,7 +940,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 23,
|
"execution_count": 36,
|
||||||
"id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
|
"id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -1012,7 +977,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": 37,
|
||||||
"id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
|
"id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -1056,7 +1021,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 25,
|
"execution_count": 38,
|
||||||
"id": "eb860488-5453-41d7-9870-23b723f742a0",
|
"id": "eb860488-5453-41d7-9870-23b723f742a0",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"colab": {
|
"colab": {
|
||||||
@@ -1101,7 +1066,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 26,
|
"execution_count": 49,
|
||||||
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
|
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
|
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
|
||||||
@@ -1118,17 +1083,16 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
|
"def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
|
||||||
" total_loss, batches_seen = 0., 0.\n",
|
" total_loss = 0.\n",
|
||||||
" if num_batches is None:\n",
|
" if num_batches is None:\n",
|
||||||
" num_batches = len(data_loader)\n",
|
" num_batches = len(data_loader)\n",
|
||||||
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
|
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
|
||||||
" if i < num_batches:\n",
|
" if i < num_batches:\n",
|
||||||
" loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
|
" loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
|
||||||
" total_loss += loss.item()\n",
|
" total_loss += loss.item()\n",
|
||||||
" batches_seen += 1\n",
|
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" break\n",
|
" break\n",
|
||||||
" return total_loss / batches_seen"
|
" return total_loss / num_batches"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -1142,7 +1106,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 27,
|
"execution_count": 55,
|
||||||
"id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
|
"id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -1150,7 +1114,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Training loss: 11.002305030822754\n",
|
"Training loss: 10.98758347829183\n",
|
||||||
"Validation loss: 10.98110580444336\n"
|
"Validation loss: 10.98110580444336\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -1159,6 +1123,8 @@
|
|||||||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||||
"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
|
"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
|
||||||
"train_loss = calc_loss_loader(train_loader, model, device, num_batches=1)\n",
|
"train_loss = calc_loss_loader(train_loader, model, device, num_batches=1)\n",
|
||||||
"val_loss = calc_loss_loader(val_loader, model, device, num_batches=1)\n",
|
"val_loss = calc_loss_loader(val_loader, model, device, num_batches=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@@ -31,17 +31,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
|
|||||||
|
|
||||||
|
|
||||||
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||||
total_loss, batches_seen = 0., 0.
|
total_loss = 0.
|
||||||
if num_batches is None:
|
if num_batches is None:
|
||||||
num_batches = len(data_loader)
|
num_batches = len(data_loader)
|
||||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||||
if i < num_batches:
|
if i < num_batches:
|
||||||
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
batches_seen += 1
|
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
return total_loss / batches_seen
|
return total_loss / num_batches
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||||
|
|||||||
@@ -250,17 +250,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
|
|||||||
|
|
||||||
|
|
||||||
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||||
total_loss, batches_seen = 0., 0.
|
total_loss = 0.
|
||||||
if num_batches is None:
|
if num_batches is None:
|
||||||
num_batches = len(data_loader)
|
num_batches = len(data_loader)
|
||||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||||
if i < num_batches:
|
if i < num_batches:
|
||||||
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
batches_seen += 1
|
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
return total_loss / batches_seen
|
return total_loss / num_batches
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||||
|
|||||||
@@ -23,18 +23,17 @@ HPARAM_GRID = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def calc_loss_loader(data_loader, model, device, num_iters=None):
|
def calc_loss_loader(data_loader, model, device, num_batches=None):
|
||||||
total_loss, num_batches = 0., 0
|
total_loss = 0.
|
||||||
if num_iters is None:
|
if num_batches is None:
|
||||||
num_iters = len(data_loader)
|
num_batches = len(data_loader)
|
||||||
for i, (input_batch, target_batch) in enumerate(data_loader):
|
for i, (input_batch, target_batch) in enumerate(data_loader):
|
||||||
if i < num_iters:
|
if i < num_batches:
|
||||||
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
loss = calc_loss_batch(input_batch, target_batch, model, device)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
num_batches += 1
|
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
return total_loss
|
return total_loss / num_batches
|
||||||
|
|
||||||
|
|
||||||
def calc_loss_batch(input_batch, target_batch, model, device):
|
def calc_loss_batch(input_batch, target_batch, model, device):
|
||||||
|
|||||||
Reference in New Issue
Block a user