Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2026-04-10 12:33:42 +00:00.
Add defensive context trimming for multiturn (#815)
* Add defensive context trimming for multiturn * add all mods
This commit is contained in:
committed by
GitHub
parent
215abdbcdd
commit
c7a4362ca4
@@ -15,7 +15,8 @@ from llms_from_scratch.kv_cache.qwen3 import (
|
|||||||
load_weights_into_qwen
|
load_weights_into_qwen
|
||||||
)
|
)
|
||||||
from llms_from_scratch.kv_cache.generate import (
|
from llms_from_scratch.kv_cache.generate import (
|
||||||
generate_text_simple_stream
|
generate_text_simple_stream,
|
||||||
|
trim_input_tensor
|
||||||
)
|
)
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -141,6 +142,11 @@ async def main(message: chainlit.Message):
|
|||||||
prompt = build_prompt_from_history(history, add_assistant_header=True)
|
prompt = build_prompt_from_history(history, add_assistant_header=True)
|
||||||
input_ids = TOKENIZER.encode(prompt)
|
input_ids = TOKENIZER.encode(prompt)
|
||||||
input_ids_tensor = torch.tensor(input_ids, device=DEVICE).unsqueeze(0)
|
input_ids_tensor = torch.tensor(input_ids, device=DEVICE).unsqueeze(0)
|
||||||
|
input_ids_tensor = trim_input_tensor(
|
||||||
|
input_ids_tensor=input_ids_tensor,
|
||||||
|
context_len=MODEL.cfg["context_length"],
|
||||||
|
max_new_tokens=MAX_NEW_TOKENS
|
||||||
|
)
|
||||||
|
|
||||||
# 2) Start an outgoing message we can stream into
|
# 2) Start an outgoing message we can stream into
|
||||||
out_msg = chainlit.Message(content="")
|
out_msg = chainlit.Message(content="")
|
||||||
|
|||||||
17
pkg/llms_from_scratch/generate.py
Normal file
17
pkg/llms_from_scratch/generate.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||||
|
# Source for "Build a Large Language Model From Scratch"
|
||||||
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
# Additional utility and helper functions for text generation not covered
|
||||||
|
# in the main chapters
|
||||||
|
|
||||||
|
def trim_input_tensor(input_ids_tensor, context_len, max_new_tokens):
    """Left-truncate a prompt so prompt + generated tokens fit the context.

    Keeps the most recent ``context_len - max_new_tokens`` tokens (at least
    one), so that after generating ``max_new_tokens`` tokens the sequence
    never exceeds the model's context window.

    Args:
        input_ids_tensor: Token-ID tensor of shape (batch, seq_len).
        context_len: The model's maximum context length.
        max_new_tokens: Number of tokens that will be generated next.

    Returns:
        The input tensor, left-truncated along dim 1 if it was too long;
        otherwise returned unchanged.

    Raises:
        ValueError: If ``max_new_tokens`` is not smaller than ``context_len``.
    """
    # Explicit exception instead of `assert`: asserts are stripped under
    # `python -O`, and this is genuine input validation.
    if max_new_tokens >= context_len:
        raise ValueError(
            f"max_new_tokens ({max_new_tokens}) must be smaller than "
            f"context_len ({context_len})"
        )

    # Leave room for the tokens to be generated; always keep at least
    # one prompt token.
    keep_len = max(1, context_len - max_new_tokens)

    # If the prompt is too long, left-truncate: the most recent tokens
    # matter most for continuing the conversation.
    if input_ids_tensor.shape[1] > keep_len:
        input_ids_tensor = input_ids_tensor[:, -keep_len:]

    return input_ids_tensor
||||||
@@ -3,6 +3,7 @@
|
|||||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
from ..generate import trim_input_tensor # noqa: F401
|
||||||
from .utils import KVCache
|
from .utils import KVCache
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
from ..generate import trim_input_tensor # noqa: F401
|
||||||
from .utils import KVCache
|
from .utils import KVCache
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
54
pkg/llms_from_scratch/tests/test_generate.py
Normal file
54
pkg/llms_from_scratch/tests/test_generate.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||||
|
# Source for "Build a Large Language Model From Scratch"
|
||||||
|
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||||
|
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||||
|
|
||||||
|
from llms_from_scratch.ch02 import create_dataloader_v1
|
||||||
|
|
||||||
|
import os
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
def test_dataloader(tmp_path, file_name):
    """Smoke test: build a dataloader over the sample text and check that
    token + positional embeddings combine to the expected (8, 4, 256) shape.
    """
    # Download the sample text once if it is not already cached locally.
    if not os.path.exists(file_name):
        url = ("https://raw.githubusercontent.com/rasbt/"
               "LLMs-from-scratch/main/ch02/01_main-chapter-code/"
               "the-verdict.txt")
        urllib.request.urlretrieve(url, file_name)

    with open(file_name, "r", encoding="utf-8") as f:
        raw_text = f.read()

    vocab_size = 50257
    output_dim = 256
    context_length = 1024

    token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
    pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

    batch_size = 8
    max_length = 4
    dataloader = create_dataloader_v1(
        raw_text,
        batch_size=batch_size,
        max_length=max_length,
        stride=max_length
    )

    # Only the first batch is needed for the shape check.
    for batch in dataloader:
        x, y = batch

        token_embeddings = token_embedding_layer(x)
        pos_embeddings = pos_embedding_layer(torch.arange(max_length))

        input_embeddings = token_embeddings + pos_embeddings
        break

    # BUG FIX: the original final line was a bare `==` comparison whose
    # boolean result was silently discarded, so the shape check could
    # never fail. It must be asserted to actually test anything.
    assert input_embeddings.shape == torch.Size([8, 4, 256])
Reference in New Issue
Block a user