From d7c7393af7650c324ca434e024ca429f5fb4dc47 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Mon, 23 Jun 2025 08:21:55 -0500
Subject: [PATCH] Ch06 classifier function asserts (#703)

Validate max_length before it is used: min(None, n) raises TypeError,
so the asserts must precede the truncation line for their messages to
ever be reachable.

---
 ch06/01_main-chapter-code/ch06.ipynb | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/ch06/01_main-chapter-code/ch06.ipynb b/ch06/01_main-chapter-code/ch06.ipynb
index d0dabbd..d7b723a 100644
--- a/ch06/01_main-chapter-code/ch06.ipynb
+++ b/ch06/01_main-chapter-code/ch06.ipynb
@@ -2353,7 +2353,17 @@
     "\n",
+    "    # Validate max_length up front: min(None, n) below would raise a\n",
+    "    # TypeError, so these checks must run before max_length is used.\n",
+    "    assert max_length is not None, (\n",
+    "        \"max_length must be specified. If you want to use the full model context, \"\n",
+    "        \"pass max_length=model.pos_emb.weight.shape[0].\"\n",
+    "    )\n",
+    "    assert max_length <= supported_context_length, (\n",
+    "        f\"max_length ({max_length}) exceeds model's supported context length ({supported_context_length}).\"\n",
+    "    )\n",
+    "\n",
     "    # Truncate sequences if they too long\n",
     "    input_ids = input_ids[:min(max_length, supported_context_length)]\n",
     "\n",
     "    # Pad sequences to the longest sequence\n",
     "    input_ids += [pad_token_id] * (max_length - len(input_ids))\n",
     "    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0) # add batch dimension\n",