mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Make datasets and loaders compatible with multiprocessing (#118)
This commit is contained in:
committed by
GitHub
parent
8fe63a9a0e
commit
bae4b0fb08
0
ch05/01_main-chapter-code/.gitignore
vendored
Normal file
0
ch05/01_main-chapter-code/.gitignore
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -473,7 +473,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=True,\n",
|
||||
" shuffle=True\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=0\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_loader = create_dataloader_v1(\n",
|
||||
@@ -482,7 +483,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=False,\n",
|
||||
" shuffle=False\n",
|
||||
" shuffle=False,\n",
|
||||
" num_workers=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -697,7 +699,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=True,\n",
|
||||
" shuffle=True\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=0\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_loader = create_dataloader_v1(\n",
|
||||
@@ -706,7 +709,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=False,\n",
|
||||
" shuffle=False\n",
|
||||
" shuffle=False,\n",
|
||||
" num_workers=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -945,7 +949,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
|
||||
import os
|
||||
import torch
|
||||
import urllib.request
|
||||
import tiktoken
|
||||
|
||||
|
||||
# Import from local files
|
||||
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
|
||||
@@ -69,7 +71,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
|
||||
|
||||
|
||||
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, start_context):
|
||||
eval_freq, eval_iter, start_context, tokenizer):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, track_tokens_seen = [], [], []
|
||||
tokens_seen = 0
|
||||
@@ -99,7 +101,7 @@ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_e
|
||||
|
||||
# Print a sample text after each epoch
|
||||
generate_and_print_sample(
|
||||
model, train_loader.dataset.tokenizer, device, start_context
|
||||
model, tokenizer, device, start_context
|
||||
)
|
||||
|
||||
return train_losses, val_losses, track_tokens_seen
|
||||
@@ -169,7 +171,8 @@ def main(gpt_config, settings):
|
||||
max_length=gpt_config["context_length"],
|
||||
stride=gpt_config["context_length"],
|
||||
drop_last=True,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
val_loader = create_dataloader_v1(
|
||||
@@ -178,17 +181,20 @@ def main(gpt_config, settings):
|
||||
max_length=gpt_config["context_length"],
|
||||
stride=gpt_config["context_length"],
|
||||
drop_last=False,
|
||||
shuffle=False
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
##############################
|
||||
# Train model
|
||||
##############################
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_losses, val_losses, tokens_seen = train_model_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
|
||||
start_context="Every effort moves you",
|
||||
start_context="Every effort moves you", tokenizer=tokenizer
|
||||
)
|
||||
|
||||
return train_losses, val_losses, tokens_seen, model
|
||||
|
||||
@@ -14,12 +14,11 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class GPTDatasetV1(Dataset):
|
||||
def __init__(self, txt, tokenizer, max_length, stride):
|
||||
self.tokenizer = tokenizer
|
||||
self.input_ids = []
|
||||
self.target_ids = []
|
||||
|
||||
# Tokenize the entire text
|
||||
token_ids = self.tokenizer.encode(txt)
|
||||
token_ids = tokenizer.encode(txt)
|
||||
|
||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||
for i in range(0, len(token_ids) - max_length, stride):
|
||||
@@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
|
||||
|
||||
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
stride=128, shuffle=True, drop_last=True):
|
||||
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
@@ -45,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
@@ -14,12 +14,11 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class GPTDatasetV1(Dataset):
|
||||
def __init__(self, txt, tokenizer, max_length, stride):
|
||||
self.tokenizer = tokenizer
|
||||
self.input_ids = []
|
||||
self.target_ids = []
|
||||
|
||||
# Tokenize the entire text
|
||||
token_ids = self.tokenizer.encode(txt)
|
||||
token_ids = tokenizer.encode(txt)
|
||||
|
||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||
for i in range(0, len(token_ids) - max_length, stride):
|
||||
@@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
|
||||
|
||||
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
stride=128, shuffle=True, drop_last=True):
|
||||
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
@@ -45,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
import time
|
||||
import tiktoken
|
||||
import torch
|
||||
from previous_chapters import (
|
||||
create_dataloader_v1,
|
||||
@@ -32,7 +33,7 @@ def read_text_file(file_path):
|
||||
return text_data
|
||||
|
||||
|
||||
def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
|
||||
def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride, num_workers=0):
|
||||
split_idx = int(train_ratio * len(text_data))
|
||||
train_loader = create_dataloader_v1(
|
||||
text_data[:split_idx],
|
||||
@@ -40,7 +41,8 @@ def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
drop_last=True,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=num_workers
|
||||
)
|
||||
val_loader = create_dataloader_v1(
|
||||
text_data[split_idx:],
|
||||
@@ -48,7 +50,8 @@ def create_dataloaders(text_data, train_ratio, batch_size, max_length, stride):
|
||||
max_length=max_length,
|
||||
stride=stride,
|
||||
drop_last=False,
|
||||
shuffle=False
|
||||
shuffle=False,
|
||||
num_workers=num_workers
|
||||
)
|
||||
return train_loader, val_loader
|
||||
|
||||
@@ -78,7 +81,7 @@ def print_eta(start_time, book_start_time, index, total_files):
|
||||
|
||||
def train_model_simple(model, optimizer, device, n_epochs,
|
||||
eval_freq, eval_iter, print_sample_iter, start_context,
|
||||
output_dir, save_ckpt_freq,
|
||||
output_dir, save_ckpt_freq, tokenizer,
|
||||
batch_size=1024, train_ratio=0.90):
|
||||
|
||||
train_losses, val_losses, track_tokens_seen = [], [], []
|
||||
@@ -101,7 +104,8 @@ def train_model_simple(model, optimizer, device, n_epochs,
|
||||
train_ratio=train_ratio,
|
||||
batch_size=batch_size,
|
||||
max_length=GPT_CONFIG_124M["context_length"],
|
||||
stride=GPT_CONFIG_124M["context_length"]
|
||||
stride=GPT_CONFIG_124M["context_length"],
|
||||
num_workers=0
|
||||
)
|
||||
print("Training ...")
|
||||
model.train()
|
||||
@@ -126,7 +130,7 @@ def train_model_simple(model, optimizer, device, n_epochs,
|
||||
# Generate text passage
|
||||
if global_step % print_sample_iter == 0:
|
||||
generate_and_print_sample(
|
||||
model, train_loader.dataset.tokenizer, device, start_context
|
||||
model, tokenizer, device, start_context
|
||||
)
|
||||
|
||||
if global_step % save_ckpt_freq:
|
||||
@@ -196,6 +200,7 @@ if __name__ == "__main__":
|
||||
model = GPTModel(GPT_CONFIG_124M)
|
||||
model.to(device)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.1)
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
data_dir = args.data_dir
|
||||
all_files = [os.path.join(path, name) for path, subdirs, files
|
||||
@@ -221,6 +226,7 @@ if __name__ == "__main__":
|
||||
output_dir=output_dir,
|
||||
save_ckpt_freq=args.save_ckpt_freq,
|
||||
start_context="Every effort moves you",
|
||||
tokenizer=tokenizer
|
||||
)
|
||||
|
||||
epochs_tensor = torch.linspace(0, args.n_epochs, len(train_losses))
|
||||
|
||||
@@ -21,11 +21,10 @@ import matplotlib.pyplot as plt
|
||||
|
||||
class GPTDatasetV1(Dataset):
|
||||
def __init__(self, txt, tokenizer, max_length, stride):
|
||||
self.tokenizer = tokenizer
|
||||
self.input_ids = []
|
||||
self.target_ids = []
|
||||
|
||||
token_ids = self.tokenizer.encode(txt, allowed_special={'<|endoftext|>'})
|
||||
token_ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})
|
||||
|
||||
for i in range(0, len(token_ids) - max_length, stride):
|
||||
input_chunk = token_ids[i:i + max_length]
|
||||
@@ -41,11 +40,11 @@ class GPTDatasetV1(Dataset):
|
||||
|
||||
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
stride=128, shuffle=True, drop_last=True):
|
||||
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
import tiktoken
|
||||
import torch
|
||||
from previous_chapters import GPTModel, create_dataloader_v1
|
||||
|
||||
@@ -58,7 +59,7 @@ def evaluate_model(model, train_loader, val_loader, device, eval_iter):
|
||||
|
||||
def train_model(model, train_loader, val_loader, optimizer, device,
|
||||
n_epochs, eval_freq, eval_iter,
|
||||
encoded_start_context, warmup_iters=10,
|
||||
encoded_start_context, tokenizer, warmup_iters=10,
|
||||
initial_lr=3e-05, min_lr=1e-6):
|
||||
global_step = 0
|
||||
|
||||
@@ -120,6 +121,7 @@ if __name__ == "__main__":
|
||||
with open(os.path.join(script_dir, "the-verdict.txt"), "r", encoding="utf-8") as file:
|
||||
text_data = file.read()
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
train_ratio = 0.95
|
||||
@@ -155,7 +157,8 @@ if __name__ == "__main__":
|
||||
max_length=GPT_CONFIG_124M["context_length"],
|
||||
stride=GPT_CONFIG_124M["context_length"],
|
||||
drop_last=True,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
val_loader = create_dataloader_v1(
|
||||
@@ -164,7 +167,8 @@ if __name__ == "__main__":
|
||||
max_length=GPT_CONFIG_124M["context_length"],
|
||||
stride=GPT_CONFIG_124M["context_length"],
|
||||
drop_last=False,
|
||||
shuffle=False
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
model = GPTModel(GPT_CONFIG_124M)
|
||||
@@ -176,7 +180,7 @@ if __name__ == "__main__":
|
||||
weight_decay=HPARAM_CONFIG["weight_decay"]
|
||||
)
|
||||
|
||||
encoded_start_context = train_loader.dataset.tokenizer.encode("Nevertheless")
|
||||
encoded_start_context = tokenizer.encode("Nevertheless")
|
||||
encoded_tensor = torch.tensor(encoded_start_context).unsqueeze(0)
|
||||
|
||||
train_loss, val_loss = train_model(
|
||||
@@ -184,6 +188,7 @@ if __name__ == "__main__":
|
||||
n_epochs=HPARAM_CONFIG["n_epochs"],
|
||||
eval_freq=5, eval_iter=1,
|
||||
encoded_start_context=encoded_tensor,
|
||||
tokenizer=tokenizer,
|
||||
warmup_iters=HPARAM_CONFIG["warmup_iters"],
|
||||
initial_lr=HPARAM_CONFIG["initial_lr"],
|
||||
min_lr=HPARAM_CONFIG["min_lr"]
|
||||
|
||||
@@ -19,12 +19,11 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class GPTDatasetV1(Dataset):
|
||||
def __init__(self, txt, tokenizer, max_length, stride):
|
||||
self.tokenizer = tokenizer
|
||||
self.input_ids = []
|
||||
self.target_ids = []
|
||||
|
||||
# Tokenize the entire text
|
||||
token_ids = self.tokenizer.encode(txt)
|
||||
token_ids = tokenizer.encode(txt)
|
||||
|
||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||
for i in range(0, len(token_ids) - max_length, stride):
|
||||
@@ -46,11 +45,11 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
# Create dataset
|
||||
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
|
||||
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride, num_workers=0)
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
Reference in New Issue
Block a user