mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Add defensive context trimming for multiturn (#815)
* Add defensive context trimming for multiturn * add all mods
This commit is contained in:
committed by
GitHub
parent
215abdbcdd
commit
c7a4362ca4
17
pkg/llms_from_scratch/generate.py
Normal file
17
pkg/llms_from_scratch/generate.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
# Additional utility and helper functions for text generation not covered
|
||||
# in the main chapters
|
||||
|
||||
def trim_input_tensor(input_ids_tensor, context_len, max_new_tokens):
|
||||
assert max_new_tokens < context_len
|
||||
keep_len = max(1, context_len - max_new_tokens)
|
||||
|
||||
# If the prompt is too long, left-truncate to keep_len
|
||||
if input_ids_tensor.shape[1] > keep_len:
|
||||
input_ids_tensor = input_ids_tensor[:, -keep_len:]
|
||||
|
||||
return input_ids_tensor
|
||||
@@ -3,6 +3,7 @@
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from ..generate import trim_input_tensor # noqa: F401
|
||||
from .utils import KVCache
|
||||
import torch
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from ..generate import trim_input_tensor # noqa: F401
|
||||
from .utils import KVCache
|
||||
import torch
|
||||
|
||||
|
||||
54
pkg/llms_from_scratch/tests/test_generate.py
Normal file
54
pkg/llms_from_scratch/tests/test_generate.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
|
||||
# Source for "Build a Large Language Model From Scratch"
|
||||
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
|
||||
# Code: https://github.com/rasbt/LLMs-from-scratch
|
||||
|
||||
from llms_from_scratch.ch02 import create_dataloader_v1
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
@pytest.mark.parametrize("file_name", ["the-verdict.txt"])
|
||||
def test_dataloader(tmp_path, file_name):
|
||||
|
||||
if not os.path.exists("the-verdict.txt"):
|
||||
url = ("https://raw.githubusercontent.com/rasbt/"
|
||||
"LLMs-from-scratch/main/ch02/01_main-chapter-code/"
|
||||
"the-verdict.txt")
|
||||
file_path = "the-verdict.txt"
|
||||
urllib.request.urlretrieve(url, file_path)
|
||||
|
||||
with open("the-verdict.txt", "r", encoding="utf-8") as f:
|
||||
raw_text = f.read()
|
||||
|
||||
vocab_size = 50257
|
||||
output_dim = 256
|
||||
context_length = 1024
|
||||
|
||||
token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)
|
||||
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)
|
||||
|
||||
batch_size = 8
|
||||
max_length = 4
|
||||
dataloader = create_dataloader_v1(
|
||||
raw_text,
|
||||
batch_size=batch_size,
|
||||
max_length=max_length,
|
||||
stride=max_length
|
||||
)
|
||||
|
||||
for batch in dataloader:
|
||||
x, y = batch
|
||||
|
||||
token_embeddings = token_embedding_layer(x)
|
||||
pos_embeddings = pos_embedding_layer(torch.arange(max_length))
|
||||
|
||||
input_embeddings = token_embeddings + pos_embeddings
|
||||
|
||||
break
|
||||
|
||||
input_embeddings.shape == torch.Size([8, 4, 256])
|
||||
Reference in New Issue
Block a user