mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Make datesets and loaders compatible with multiprocessing (#118)
This commit is contained in:
committed by
GitHub
parent
8fe63a9a0e
commit
bae4b0fb08
0
ch05/01_main-chapter-code/.gitignore
vendored
Normal file
0
ch05/01_main-chapter-code/.gitignore
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -473,7 +473,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=True,\n",
|
||||
" shuffle=True\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=0\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_loader = create_dataloader_v1(\n",
|
||||
@@ -482,7 +483,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=False,\n",
|
||||
" shuffle=False\n",
|
||||
" shuffle=False,\n",
|
||||
" num_workers=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -697,7 +699,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=True,\n",
|
||||
" shuffle=True\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=0\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_loader = create_dataloader_v1(\n",
|
||||
@@ -706,7 +709,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=False,\n",
|
||||
" shuffle=False\n",
|
||||
" shuffle=False,\n",
|
||||
" num_workers=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -945,7 +949,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
|
||||
import os
|
||||
import torch
|
||||
import urllib.request
|
||||
import tiktoken
|
||||
|
||||
|
||||
# Import from local files
|
||||
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
|
||||
@@ -69,7 +71,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
|
||||
|
||||
|
||||
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, start_context):
|
||||
eval_freq, eval_iter, start_context, tokenizer):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, track_tokens_seen = [], [], []
|
||||
tokens_seen = 0
|
||||
@@ -99,7 +101,7 @@ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_e
|
||||
|
||||
# Print a sample text after each epoch
|
||||
generate_and_print_sample(
|
||||
model, train_loader.dataset.tokenizer, device, start_context
|
||||
model, tokenizer, device, start_context
|
||||
)
|
||||
|
||||
return train_losses, val_losses, track_tokens_seen
|
||||
@@ -169,7 +171,8 @@ def main(gpt_config, settings):
|
||||
max_length=gpt_config["context_length"],
|
||||
stride=gpt_config["context_length"],
|
||||
drop_last=True,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
val_loader = create_dataloader_v1(
|
||||
@@ -178,17 +181,20 @@ def main(gpt_config, settings):
|
||||
max_length=gpt_config["context_length"],
|
||||
stride=gpt_config["context_length"],
|
||||
drop_last=False,
|
||||
shuffle=False
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
##############################
|
||||
# Train model
|
||||
##############################
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_losses, val_losses, tokens_seen = train_model_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
|
||||
start_context="Every effort moves you",
|
||||
start_context="Every effort moves you", tokenizer=tokenizer
|
||||
)
|
||||
|
||||
return train_losses, val_losses, tokens_seen, model
|
||||
|
||||
@@ -14,12 +14,11 @@ from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class GPTDatasetV1(Dataset):
|
||||
def __init__(self, txt, tokenizer, max_length, stride):
|
||||
self.tokenizer = tokenizer
|
||||
self.input_ids = []
|
||||
self.target_ids = []
|
||||
|
||||
# Tokenize the entire text
|
||||
token_ids = self.tokenizer.encode(txt)
|
||||
token_ids = tokenizer.encode(txt)
|
||||
|
||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||
for i in range(0, len(token_ids) - max_length, stride):
|
||||
@@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
|
||||
|
||||
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
stride=128, shuffle=True, drop_last=True):
|
||||
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
@@ -45,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
Reference in New Issue
Block a user