Add assertion that the dataset's max_length does not exceed the model's context length

This commit is contained in:
rasbt
2024-05-23 06:50:43 -05:00
parent ec70194d19
commit 18e729643d
3 changed files with 20 additions and 2 deletions

View File

@@ -837,7 +837,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 1,
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
"metadata": {
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
@@ -861,7 +861,13 @@
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"assert train_dataset.max_length <= BASE_CONFIG[\"context_length\"], (\n",
" f\"Dataset length {train_dataset.max_length} exceeds model's context \"\n",
" f\"length {BASE_CONFIG['context_length']}. Reinitialize data sets with \"\n",
" f\"`max_length={BASE_CONFIG['context_length']}`\"\n",
")"
]
},
{