Fix truncation issue in classify_review function (#373)

This commit is contained in:
Sebastian Raschka
2024-09-25 19:54:36 -05:00
committed by GitHub
parent b56d0b2942
commit 7ef5129e18
3 changed files with 5 additions and 3 deletions

View File

@@ -2207,7 +2207,9 @@
"\n",
" # Prepare inputs to the model\n",
" input_ids = tokenizer.encode(text)\n",
" supported_context_length = model.pos_emb.weight.shape[1]\n",
" supported_context_length = model.pos_emb.weight.shape[0]\n",
" # Note: In the book, this was originally written as pos_emb.weight.shape[1] by mistake\n",
" # It didn't break the code but would have caused unnecessary truncation (to 768 instead of 1024)\n",
"\n",
" # Truncate sequences if they too long\n",
" input_ids = input_ids[:min(max_length, supported_context_length)]\n",