Add defensive context trimming for multiturn (#815)

* Add defensive context trimming for multiturn

* add all mods
This commit is contained in:
Sebastian Raschka
2025-09-09 20:19:00 -05:00
committed by GitHub
parent 215abdbcdd
commit c7a4362ca4
5 changed files with 80 additions and 1 deletions

View File

@@ -0,0 +1,17 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# Additional utility and helper functions for text generation not covered
# in the main chapters
def trim_input_tensor(input_ids_tensor, context_len, max_new_tokens):
assert max_new_tokens < context_len
keep_len = max(1, context_len - max_new_tokens)
# If the prompt is too long, left-truncate to keep_len
if input_ids_tensor.shape[1] > keep_len:
input_ids_tensor = input_ids_tensor[:, -keep_len:]
return input_ids_tensor

View File

@@ -3,6 +3,7 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..generate import trim_input_tensor # noqa: F401
from .utils import KVCache
import torch

View File

@@ -3,6 +3,7 @@
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from ..generate import trim_input_tensor # noqa: F401
from .utils import KVCache
import torch

View File

@@ -0,0 +1,54 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
from llms_from_scratch.ch02 import create_dataloader_v1
import os
import urllib.request
import pytest
import torch
@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
def test_dataloader(tmp_path, file_name):
if not os.path.exists("the-verdict.txt"):
url = ("https://raw.githubusercontent.com/rasbt/"
"LLMs-from-scratch/main/ch02/01_main-chapter-code/"
"the-verdict.txt")
file_path = "the-verdict.txt"
urllib.request.urlretrieve(url, file_path)
with open("the-verdict.txt", "r", encoding="utf-8") as f:
raw_text = f.read()
vocab_size = 50257
output_dim = 256
context_length = 1024
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)
batch_size = 8
max_length = 4
dataloader = create_dataloader_v1(
raw_text,
batch_size=batch_size,
max_length=max_length,
stride=max_length
)
for batch in dataloader:
x, y = batch
token_embeddings = token_embedding_layer(x)
pos_embeddings = pos_embedding_layer(torch.arange(max_length))
input_embeddings = token_embeddings + pos_embeddings
break
input_embeddings.shape == torch.Size([8, 4, 256])