simplify calc_loss_loader

rasbt
2024-03-26 20:34:50 -05:00
parent c88e8edf72
commit 3cb5a52a1b
5 changed files with 33 additions and 71 deletions

View File

@@ -256,17 +256,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
-    total_loss, batches_seen = 0., 0.
+    total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            batches_seen += 1
         else:
             break
-    return total_loss / batches_seen
+    return total_loss / num_batches
 
 
 def evaluate_model(model, train_loader, val_loader, device, eval_iter):
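
Note: the old and new versions agree whenever `num_batches` does not exceed the number of batches the loader yields: the loop then runs exactly `num_batches` times, so the separate `batches_seen` counter was redundant. (If `num_batches` were larger than `len(data_loader)`, the new version would divide by the requested count rather than the count actually seen.) Below is a minimal self-contained sketch of the simplified function; the toy model and random token data are illustrative stand-ins, not code from this repository:

import torch
from torch.utils.data import DataLoader, TensorDataset


def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)
    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())
    return loss


def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.
    if num_batches is None:
        num_batches = len(data_loader)
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    # The loop breaks after num_batches iterations, so dividing by num_batches
    # yields the mean (assumes num_batches <= len(data_loader))
    return total_loss / num_batches


# Toy usage with a stand-in model and random token data (hypothetical)
vocab_size, emb_dim, seq_len = 10, 8, 5
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, emb_dim),  # (batch, seq) -> (batch, seq, emb)
    torch.nn.Linear(emb_dim, vocab_size),     # (batch, seq, emb) -> (batch, seq, vocab)
)
inputs = torch.randint(0, vocab_size, (16, seq_len))
targets = torch.randint(0, vocab_size, (16, seq_len))
loader = DataLoader(TensorDataset(inputs, targets), batch_size=4)

device = torch.device("cpu")
print(calc_loss_loader(loader, model, device, num_batches=2))  # mean loss over first 2 batches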

View File

@@ -764,7 +764,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 16,
+   "execution_count": 23,
    "id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff",
    "metadata": {},
    "outputs": [],
@@ -785,43 +785,6 @@
     "    text_data = file.read()"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": 18,
-   "id": "0959c855-f860-4358-8b98-bc654f047578",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "from previous_chapters import create_dataloader_v1\n",
-    "\n",
-    "# Train/validation ratio\n",
-    "train_ratio = 0.90\n",
-    "split_idx = int(train_ratio * len(text_data))\n",
-    "train_data = text_data[:split_idx]\n",
-    "val_data = text_data[split_idx:]\n",
-    "\n",
-    "\n",
-    "torch.manual_seed(123)\n",
-    "\n",
-    "train_loader = create_dataloader_v1(\n",
-    "    train_data,\n",
-    "    batch_size=2,\n",
-    "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    drop_last=True,\n",
-    "    shuffle=True\n",
-    ")\n",
-    "\n",
-    "val_loader = create_dataloader_v1(\n",
-    "    val_data,\n",
-    "    batch_size=2,\n",
-    "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
-    "    drop_last=False,\n",
-    "    shuffle=False\n",
-    ")"
-   ]
-  },
   {
    "cell_type": "markdown",
    "id": "379330f1-80f4-4e34-8724-41d892b04cee",
@@ -832,7 +795,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 19,
+   "execution_count": 24,
    "id": "6kgJbe4ehI4q",
    "metadata": {
     "colab": {
@@ -858,7 +821,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 20,
+   "execution_count": 25,
    "id": "j2XPde_ThM_e",
    "metadata": {
     "colab": {
@@ -884,7 +847,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 21,
+   "execution_count": 27,
    "id": "6b46a952-d50a-4837-af09-4095698f7fd1",
    "metadata": {
     "colab": {
@@ -940,22 +903,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
-   "id": "fd0e963b-d282-4c97-b004-8772f4b1bd8f",
+   "execution_count": 35,
+   "id": "0959c855-f860-4358-8b98-bc654f047578",
    "metadata": {},
    "outputs": [],
    "source": [
     "from previous_chapters import create_dataloader_v1\n",
     "\n",
+    "\n",
     "# Train/validation ratio\n",
     "train_ratio = 0.90\n",
     "split_idx = int(train_ratio * len(text_data))\n",
+    "train_data = text_data[:split_idx]\n",
+    "val_data = text_data[split_idx:]\n",
     "\n",
-    "\n",
     "torch.manual_seed(123)\n",
     "\n",
     "train_loader = create_dataloader_v1(\n",
-    "    text_data[:split_idx],\n",
+    "    train_data,\n",
     "    batch_size=2,\n",
     "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
     "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
@@ -964,7 +929,7 @@
     ")\n",
     "\n",
     "val_loader = create_dataloader_v1(\n",
-    "    text_data[split_idx:],\n",
+    "    val_data,\n",
     "    batch_size=2,\n",
     "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
     "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
@@ -975,7 +940,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 23,
+   "execution_count": 36,
    "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
    "metadata": {},
    "outputs": [],
@@ -1012,7 +977,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 37,
    "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
    "metadata": {},
    "outputs": [
@@ -1056,7 +1021,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 25,
+   "execution_count": 38,
    "id": "eb860488-5453-41d7-9870-23b723f742a0",
    "metadata": {
     "colab": {
@@ -1101,7 +1066,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 49,
    "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
    "metadata": {
     "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
@@ -1118,17 +1083,16 @@
     "\n",
     "\n",
     "def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
-    "    total_loss, batches_seen = 0., 0.\n",
+    "    total_loss = 0.\n",
     "    if num_batches is None:\n",
     "        num_batches = len(data_loader)\n",
     "    for i, (input_batch, target_batch) in enumerate(data_loader):\n",
     "        if i < num_batches:\n",
     "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
     "            total_loss += loss.item()\n",
-    "            batches_seen += 1\n",
     "        else:\n",
     "            break\n",
-    "    return total_loss / batches_seen"
+    "    return total_loss / num_batches"
    ]
   },
   {
@@ -1142,7 +1106,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 55,
    "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
    "metadata": {},
    "outputs": [
@@ -1150,7 +1114,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Training loss: 11.002305030822754\n",
+      "Training loss: 10.98758347829183\n",
       "Validation loss: 10.98110580444336\n"
      ]
     }
@@ -1159,6 +1123,8 @@
     "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
     "model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
     "\n",
+    "\n",
+    "torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
     "train_loss = calc_loss_loader(train_loader, model, device, num_batches=1)\n",
     "val_loss = calc_loss_loader(val_loader, model, device, num_batches=1)\n",
     "\n",

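Note: the added `torch.manual_seed(123)` call matters because `train_loader` shuffles. Without fixing the RNG state right before the estimate, which samples land in the first batch depends on everything executed earlier, so the one-batch estimate from `calc_loss_loader(..., num_batches=1)` would vary from run to run; that is presumably also why the reported training-loss value changed in the output above. A minimal sketch of the effect, using a hypothetical toy dataset:

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.arange(8))   # toy dataset: items 0..7
loader = DataLoader(data, batch_size=2, shuffle=True)

torch.manual_seed(123)
first_a = next(iter(loader))[0]         # first shuffled batch

torch.manual_seed(123)
first_b = next(iter(loader))[0]         # same seed -> same first batch

print(first_a, first_b)
assert torch.equal(first_a, first_b)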
View File

@@ -31,17 +31,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
-    total_loss, batches_seen = 0., 0.
+    total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            batches_seen += 1
         else:
             break
-    return total_loss / batches_seen
+    return total_loss / num_batches
 
 
 def evaluate_model(model, train_loader, val_loader, device, eval_iter):

View File

@@ -250,17 +250,16 @@ def calc_loss_batch(input_batch, target_batch, model, device):
 
 
 def calc_loss_loader(data_loader, model, device, num_batches=None):
-    total_loss, batches_seen = 0., 0.
+    total_loss = 0.
     if num_batches is None:
         num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
         if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            batches_seen += 1
         else:
             break
-    return total_loss / batches_seen
+    return total_loss / num_batches
 
 
 def evaluate_model(model, train_loader, val_loader, device, eval_iter):

View File

@@ -23,18 +23,17 @@ HPARAM_GRID = {
 }
 
 
-def calc_loss_loader(data_loader, model, device, num_iters=None):
-    total_loss, num_batches = 0., 0
-    if num_iters is None:
-        num_iters = len(data_loader)
+def calc_loss_loader(data_loader, model, device, num_batches=None):
+    total_loss = 0.
+    if num_batches is None:
+        num_batches = len(data_loader)
     for i, (input_batch, target_batch) in enumerate(data_loader):
-        if i < num_iters:
+        if i < num_batches:
             loss = calc_loss_batch(input_batch, target_batch, model, device)
             total_loss += loss.item()
-            num_batches += 1
         else:
             break
-    return total_loss
+    return total_loss / num_batches
 
 
 def calc_loss_batch(input_batch, target_batch, model, device):
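
Note: in this last file (the hyperparameter-search script) the change is behavioral rather than cosmetic: besides renaming `num_iters` to `num_batches`, the old version returned the summed loss (`return total_loss`), so scores computed over different numbers of batches were not on the same scale; the unified version returns the per-batch mean, matching the other files. A toy illustration with made-up loss values:

# Hypothetical per-batch losses for two evaluation runs of different lengths
losses_a = [2.0, 2.0]                   # evaluated over 2 batches
losses_b = [2.0, 2.0, 2.0, 2.0]         # evaluated over 4 batches

# Old behavior: the summed loss grows with the number of batches evaluated
print(sum(losses_a), sum(losses_b))                                  # 4.0 8.0

# New behavior: the mean loss is comparable across runs
print(sum(losses_a) / len(losses_a), sum(losses_b) / len(losses_b))  # 2.0 2.0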