add check for small validation sets

rasbt
2024-03-19 06:34:52 -05:00
parent ca96abac8a
commit 861a2788f3
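
For reference, the sanity check this commit adds reads as follows in plain Python, outside the notebook JSON. This is a minimal standalone sketch: `GPT_CONFIG_124M`, `train_ratio`, and `text_data` below are stand-in values for the variables the notebook defines itself.

    import tiktoken

    # Stand-in values; the notebook defines these earlier
    GPT_CONFIG_124M = {"ctx_len": 256}    # context length used by the data loaders
    train_ratio = 0.90                    # fraction of the text used for training
    text_data = "Every effort moves you"  # placeholder for the training corpus

    tokenizer = tiktoken.get_encoding("gpt2")
    total_tokens = len(tokenizer.encode(text_data))

    # Each split needs at least one full context window of tokens;
    # a smaller split would leave its data loader empty.
    if total_tokens * train_ratio < GPT_CONFIG_124M["ctx_len"]:
        print("Not enough tokens for the training loader. "
              "Try to lower the `GPT_CONFIG_124M['ctx_len']` or "
              "increase the `train_ratio`")

    if total_tokens * (1 - train_ratio) < GPT_CONFIG_124M["ctx_len"]:
        print("Not enough tokens for the validation loader. "
              "Try to lower the `GPT_CONFIG_124M['ctx_len']` or "
              "decrease the `train_ratio`")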

@@ -20,8 +20,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"numpy version: 1.26.0\n",
"matplotlib version: 3.8.2\n",
"numpy version: 1.25.2\n",
"matplotlib version: 3.7.2\n",
"numpy version: 1.25.2\n",
"tiktoken version: 0.5.1\n",
"torch version: 2.2.1\n"
]
@@ -916,8 +917,11 @@
}
],
"source": [
"print(\"Characters:\", len(text_data))\n",
"print(\"Tokens:\", len(tokenizer.encode(text_data)))"
"total_char = len(text_data)\n",
"total_tokens = len(tokenizer.encode(text_data))\n",
"\n",
"print(\"Characters:\", total_char)\n",
"print(\"Tokens:\", total_tokens)"
]
},
{
@@ -962,7 +966,6 @@
"train_ratio = 0.90\n",
"split_idx = int(train_ratio * len(text_data))\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"train_loader = create_dataloader_v1(\n",
@@ -984,6 +987,26 @@
")"
]
},
+{
+"cell_type": "code",
+"execution_count": 23,
+"id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Sanity check\n",
+"\n",
+"if total_tokens * (train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
+"    print(\"Not enough tokens for the training loader. \"\n",
+"          \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
+"          \"increase the `train_ratio`\")\n",
+"\n",
+"if total_tokens * (1-train_ratio) < GPT_CONFIG_124M[\"ctx_len\"]:\n",
+"    print(\"Not enough tokens for the validation loader. \"\n",
+"          \"Try to lower the `GPT_CONFIG_124M['ctx_len']` or \"\n",
+"          \"decrease the `train_ratio`\")"
+]
+},
{
"cell_type": "markdown",
"id": "e7ac3296-a4d1-4303-9ac5-376518960c33",
@@ -1003,7 +1026,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
"metadata": {},
"outputs": [
@@ -1047,7 +1070,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"id": "eb860488-5453-41d7-9870-23b723f742a0",
"metadata": {
"colab": {
@@ -1092,7 +1115,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
"metadata": {
"id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
@@ -1133,7 +1156,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
"metadata": {},
"outputs": [
@@ -1195,7 +1218,7 @@
},
"outputs": [],
"source": [
"def train_model_simple(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
"def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,\n",
" eval_freq, eval_iter, start_context):\n",
" # Initialize lists to track losses and tokens seen\n",
" train_losses, val_losses, track_tokens_seen = [], [], []\n",
@@ -1203,7 +1226,7 @@
" global_step = -1\n",
"\n",
" # Main training loop\n",
" for epoch in range(n_epochs):\n",
" for epoch in range(num_epochs):\n",
" model.train() # Set model to training mode\n",
" \n",
" for input_batch, target_batch in train_loader:\n",
@@ -1246,8 +1269,10 @@
" context_size = model.pos_emb.weight.shape[0]\n",
" encoded = text_to_token_ids(start_context, tokenizer).to(device)\n",
" with torch.no_grad():\n",
" token_ids = generate_text_simple(model=model, idx=encoded,\n",
" max_new_tokens=50, context_size=context_size)\n",
" token_ids = generate_text_simple(\n",
" model=model, idx=encoded,\n",
" max_new_tokens=50, context_size=context_size\n",
" )\n",
" decoded_text = token_ids_to_text(token_ids, tokenizer)\n",
" print(decoded_text.replace(\"\\n\", \" \")) # Compact print format\n",
" model.train()"
@@ -1314,10 +1339,10 @@
"model.to(device)\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)\n",
"\n",
"n_epochs = 10\n",
"num_epochs = 10\n",
"train_losses, val_losses, tokens_seen = train_model_simple(\n",
" model, train_loader, val_loader, optimizer, device,\n",
" n_epochs=n_epochs, eval_freq=5, eval_iter=1,\n",
" num_epochs=num_epochs, eval_freq=5, eval_iter=1,\n",
" start_context=\"Every effort moves you\",\n",
")"
]
@@ -1368,7 +1393,7 @@
" plt.show()\n",
"\n",
"\n",
"epochs_tensor = torch.linspace(0, n_epochs, len(train_losses))\n",
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
]
},
@@ -1959,7 +1984,7 @@
"metadata": {},
"outputs": [],
"source": [
"torch.save(model.state_dict(), 'model.pth')"
"torch.save(model.state_dict(), \"model.pth\")"
]
},
{
@@ -1978,7 +2003,7 @@
"outputs": [],
"source": [
"model = GPTModel(GPT_CONFIG_124M)\n",
"model.load_state_dict(torch.load('model.pth'))\n",
"model.load_state_dict(torch.load(\"model.pth\"))\n",
"model.eval();"
]
},
@@ -1999,10 +2024,10 @@
"outputs": [],
"source": [
"torch.save({\n",
" 'model_state_dict': model.state_dict(),\n",
" 'optimizer_state_dict': optimizer.state_dict(),\n",
" \"model_state_dict\": model.state_dict(),\n",
" \"optimizer_state_dict\": optimizer.state_dict(),\n",
" }, \n",
" 'model_and_optimizer.pth'\n",
" \"model_and_optimizer.pth\"\n",
")"
]
},
@@ -2013,9 +2038,9 @@
"metadata": {},
"outputs": [],
"source": [
"checkpoint = torch.load('model_and_optimizer.pth')\n",
"model.load_state_dict(checkpoint['model_state_dict'])\n",
"optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
"checkpoint = torch.load(\"model_and_optimizer.pth\")\n",
"model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
"optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
"model.train();"
]
},
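
Taken together, the two checkpoint cells above form the usual save-and-resume pattern; a compact sketch of resuming from the combined file (assuming the notebook's `GPTModel` class and the same optimizer settings used earlier):

    import torch

    checkpoint = torch.load("model_and_optimizer.pth")

    # Model and optimizer must be constructed the same way as before
    # their saved state dicts can be loaded
    model = GPTModel(GPT_CONFIG_124M)
    model.load_state_dict(checkpoint["model_state_dict"])

    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.1)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    model.train()  # back to training mode so training can continue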
@@ -2474,7 +2499,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.11.4"
}
},
"nbformat": 4,