Ch06 classifier function asserts (#703)

This commit is contained in:
Sebastian Raschka
2025-06-23 08:21:55 -05:00
committed by GitHub
parent f5bc863752
commit 4014bdd520

View File

@@ -2353,6 +2353,16 @@
"\n", "\n",
" # Truncate sequences if they too long\n", " # Truncate sequences if they too long\n",
" input_ids = input_ids[:min(max_length, supported_context_length)]\n", " input_ids = input_ids[:min(max_length, supported_context_length)]\n",
" assert max_length is not None, (\n",
" \"max_length must be specified. If you want to use the full model context, \"\n",
" \"pass max_length=model.pos_emb.weight.shape[0].\"\n",
" )\n",
" assert max_length <= supported_context_length, (\n",
" f\"max_length ({max_length}) exceeds model's supported context length ({supported_context_length}).\"\n",
" ) \n",
" # Alternatively, a more robust version is the following one, which handles the max_length=None case better\n",
" # max_len = min(max_length,supported_context_length) if max_length else supported_context_length\n",
" # input_ids = input_ids[:max_len]\n",
" \n", " \n",
" # Pad sequences to the longest sequence\n", " # Pad sequences to the longest sequence\n",
" input_ids += [pad_token_id] * (max_length - len(input_ids))\n", " input_ids += [pad_token_id] * (max_length - len(input_ids))\n",