From c7a4362ca4c38b06ed9aa8b802b424db40741038 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Tue, 9 Sep 2025 20:19:00 -0500 Subject: [PATCH] Add defensive context trimming for multiturn (#815) * Add defensive context trimming for multiturn * add all mods --- .../qwen3-chat-interface-multiturn.py | 8 ++- pkg/llms_from_scratch/generate.py | 17 ++++++ pkg/llms_from_scratch/kv_cache/generate.py | 1 + .../kv_cache_batched/generate.py | 1 + pkg/llms_from_scratch/tests/test_generate.py | 54 +++++++++++++++++++ 5 files changed, 80 insertions(+), 1 deletion(-) create mode 100644 pkg/llms_from_scratch/generate.py create mode 100644 pkg/llms_from_scratch/tests/test_generate.py diff --git a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py index 37a4366..b9b067a 100644 --- a/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py +++ b/ch05/11_qwen3/qwen3-chat-interface/qwen3-chat-interface-multiturn.py @@ -15,7 +15,8 @@ from llms_from_scratch.kv_cache.qwen3 import ( load_weights_into_qwen ) from llms_from_scratch.kv_cache.generate import ( - generate_text_simple_stream + generate_text_simple_stream, + trim_input_tensor ) # ============================================================ @@ -141,6 +142,11 @@ async def main(message: chainlit.Message): prompt = build_prompt_from_history(history, add_assistant_header=True) input_ids = TOKENIZER.encode(prompt) input_ids_tensor = torch.tensor(input_ids, device=DEVICE).unsqueeze(0) + input_ids_tensor = trim_input_tensor( + input_ids_tensor=input_ids_tensor, + context_len=MODEL.cfg["context_length"], + max_new_tokens=MAX_NEW_TOKENS + ) # 2) Start an outgoing message we can stream into out_msg = chainlit.Message(content="") diff --git a/pkg/llms_from_scratch/generate.py b/pkg/llms_from_scratch/generate.py new file mode 100644 index 0000000..d7c741f --- /dev/null +++ b/pkg/llms_from_scratch/generate.py @@ -0,0 +1,17 @@ +# 
Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). +# Source for "Build a Large Language Model From Scratch" +# - https://www.manning.com/books/build-a-large-language-model-from-scratch +# Code: https://github.com/rasbt/LLMs-from-scratch + +# Additional utility and helper functions for text generation not covered +# in the main chapters + +def trim_input_tensor(input_ids_tensor, context_len, max_new_tokens): + assert max_new_tokens < context_len + keep_len = max(1, context_len - max_new_tokens) + + # If the prompt is too long, left-truncate to keep_len + if input_ids_tensor.shape[1] > keep_len: + input_ids_tensor = input_ids_tensor[:, -keep_len:] + + return input_ids_tensor diff --git a/pkg/llms_from_scratch/kv_cache/generate.py b/pkg/llms_from_scratch/kv_cache/generate.py index 0121986..1502a0f 100644 --- a/pkg/llms_from_scratch/kv_cache/generate.py +++ b/pkg/llms_from_scratch/kv_cache/generate.py @@ -3,6 +3,7 @@ # - https://www.manning.com/books/build-a-large-language-model-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch +from ..generate import trim_input_tensor # noqa: F401 from .utils import KVCache import torch diff --git a/pkg/llms_from_scratch/kv_cache_batched/generate.py b/pkg/llms_from_scratch/kv_cache_batched/generate.py index b3bf88b..9e8b0a4 100644 --- a/pkg/llms_from_scratch/kv_cache_batched/generate.py +++ b/pkg/llms_from_scratch/kv_cache_batched/generate.py @@ -3,6 +3,7 @@ # - https://www.manning.com/books/build-a-large-language-model-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch +from ..generate import trim_input_tensor # noqa: F401 from .utils import KVCache import torch diff --git a/pkg/llms_from_scratch/tests/test_generate.py b/pkg/llms_from_scratch/tests/test_generate.py new file mode 100644 index 0000000..11d8a52 --- /dev/null +++ b/pkg/llms_from_scratch/tests/test_generate.py @@ -0,0 +1,54 @@ +# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt). 
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+from llms_from_scratch.ch02 import create_dataloader_v1
+
+import os
+import urllib.request
+
+import pytest
+import torch
+
+
+@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
+def test_dataloader(tmp_path, file_name):
+
+    if not os.path.exists("the-verdict.txt"):
+        url = ("https://raw.githubusercontent.com/rasbt/"
+               "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
+               "the-verdict.txt")
+        file_path = "the-verdict.txt"
+        urllib.request.urlretrieve(url, file_path)
+
+    with open("the-verdict.txt", "r", encoding="utf-8") as f:
+        raw_text = f.read()
+
+    vocab_size = 50257
+    output_dim = 256
+    context_length = 1024
+
+    token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
+    pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)
+
+    batch_size = 8
+    max_length = 4
+    dataloader = create_dataloader_v1(
+        raw_text,
+        batch_size=batch_size,
+        max_length=max_length,
+        stride=max_length
+    )
+
+    for batch in dataloader:
+        x, y = batch
+
+        token_embeddings = token_embedding_layer(x)
+        pos_embeddings = pos_embedding_layer(torch.arange(max_length))
+
+        input_embeddings = token_embeddings + pos_embeddings
+
+        break
+
+    assert input_embeddings.shape == torch.Size([8, 4, 256])