Make datasets and loaders compatible with multiprocessing (#118)

This commit is contained in:
Sebastian Raschka
2024-04-13 14:57:56 -04:00
committed by GitHub
parent 9f3f231ac7
commit dd51d4ad83
17 changed files with 140 additions and 116 deletions
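For context: with num_workers > 0, the DataLoader runs worker processes, and under the spawn start method (the default on Windows and macOS) each worker receives a pickled copy of the dataset. A dataset that keeps a tiktoken Encoding as an attribute may then fail to serialize, depending on the tiktoken version, which is presumably why this commit tokenizes everything in __init__ and stops storing the tokenizer on the dataset. A minimal sketch of the pattern follows; the class name, toy text, and parameter values are illustrative, not from the commit:

import tiktoken
import torch
from torch.utils.data import Dataset, DataLoader


class PicklableDatasetSketch(Dataset):
    """Holds only plain tensors, so worker processes can unpickle it."""

    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []

        # Tokenize once, up front; the tokenizer itself is NOT kept on self
        token_ids = tokenizer.encode(txt)

        # Sliding window of overlapping (input, target) chunks
        for i in range(0, len(token_ids) - max_length, stride):
            self.input_ids.append(torch.tensor(token_ids[i:i + max_length]))
            self.target_ids.append(torch.tensor(token_ids[i + 1:i + max_length + 1]))

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]


if __name__ == "__main__":  # guard required when workers use the spawn start method
    tokenizer = tiktoken.get_encoding("gpt2")
    dataset = PicklableDatasetSketch("some training text " * 500, tokenizer,
                                     max_length=8, stride=4)
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        drop_last=True, num_workers=2)
    inputs, targets = next(iter(loader))
    print(inputs.shape, targets.shape)

The notebook cells below pass num_workers=0 explicitly, which keeps the original single-process behavior while making the parameter visible.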

ch05/01_main-chapter-code/.gitignore (new file)

File diff suppressed because one or more lines are too long


@@ -473,7 +473,8 @@
 "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
 "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
 "    drop_last=True,\n",
-"    shuffle=True\n",
+"    shuffle=True,\n",
+"    num_workers=0\n",
 ")\n",
 "\n",
 "val_loader = create_dataloader_v1(\n",
@@ -482,7 +483,8 @@
 "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
 "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
 "    drop_last=False,\n",
-"    shuffle=False\n",
+"    shuffle=False,\n",
+"    num_workers=0\n",
 ")"
 ]
 },
@@ -697,7 +699,8 @@
 "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
 "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
 "    drop_last=True,\n",
-"    shuffle=True\n",
+"    shuffle=True,\n",
+"    num_workers=0\n",
 ")\n",
 "\n",
 "val_loader = create_dataloader_v1(\n",
@@ -706,7 +709,8 @@
 "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
 "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
 "    drop_last=False,\n",
-"    shuffle=False\n",
+"    shuffle=False,\n",
+"    num_workers=0\n",
 ")"
 ]
 },
@@ -945,7 +949,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.12"
+"version": "3.11.4"
 }
 },
 "nbformat": 4,


@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
 import os
 import torch
 import urllib.request
+import tiktoken
+
 # Import from local files
 from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
@@ -69,7 +71,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
 
 
 def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
-                       eval_freq, eval_iter, start_context):
+                       eval_freq, eval_iter, start_context, tokenizer):
     # Initialize lists to track losses and tokens seen
     train_losses, val_losses, track_tokens_seen = [], [], []
     tokens_seen = 0
@@ -99,7 +101,7 @@ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
 
         # Print a sample text after each epoch
         generate_and_print_sample(
-            model, train_loader.dataset.tokenizer, device, start_context
+            model, tokenizer, device, start_context
         )
 
     return train_losses, val_losses, track_tokens_seen
@@ -169,7 +171,8 @@ def main(gpt_config, settings):
         max_length=gpt_config["context_length"],
         stride=gpt_config["context_length"],
         drop_last=True,
-        shuffle=True
+        shuffle=True,
+        num_workers=0
     )
 
     val_loader = create_dataloader_v1(
@@ -178,17 +181,20 @@ def main(gpt_config, settings):
         max_length=gpt_config["context_length"],
         stride=gpt_config["context_length"],
         drop_last=False,
-        shuffle=False
+        shuffle=False,
+        num_workers=0
     )
 
     ##############################
     # Train model
     ##############################
 
+    tokenizer = tiktoken.get_encoding("gpt2")
+
     train_losses, val_losses, tokens_seen = train_model_simple(
         model, train_loader, val_loader, optimizer, device,
         num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
-        start_context="Every effort moves you",
+        start_context="Every effort moves you", tokenizer=tokenizer
     )
 
     return train_losses, val_losses, tokens_seen, model
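
Once self.tokenizer is removed from the dataset (next file), the old train_loader.dataset.tokenizer lookup has nothing to find, hence the explicit tokenizer parameter above. Whether a tiktoken Encoding pickles at all, and therefore whether a dataset holding one survives spawn-based workers, depends on the tiktoken version; a quick hypothetical check that prints either outcome:

import pickle

import tiktoken

enc = tiktoken.get_encoding("gpt2")
try:
    pickle.dumps(enc)
    print("This tiktoken build pickles cleanly.")
except Exception as err:
    print(f"Encoding does not pickle: {err!r}")

Passing the tokenizer as an argument sidesteps the question entirely: only the main process ever touches it.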


@@ -14,12 +14,11 @@ from torch.utils.data import Dataset, DataLoader
 
 class GPTDatasetV1(Dataset):
     def __init__(self, txt, tokenizer, max_length, stride):
-        self.tokenizer = tokenizer
         self.input_ids = []
         self.target_ids = []
 
         # Tokenize the entire text
-        token_ids = self.tokenizer.encode(txt)
+        token_ids = tokenizer.encode(txt)
 
         # Use a sliding window to chunk the book into overlapping sequences of max_length
         for i in range(0, len(token_ids) - max_length, stride):
@@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
 
 def create_dataloader_v1(txt, batch_size=4, max_length=256,
-                         stride=128, shuffle=True, drop_last=True):
+                         stride=128, shuffle=True, drop_last=True, num_workers=0):
 
     # Initialize the tokenizer
     tokenizer = tiktoken.get_encoding("gpt2")
@@ -45,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
 
     # Create dataloader
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
 
     return dataloader
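
Note the added line forwards the new parameter as num_workers=num_workers; hardcoding num_workers=0 in the DataLoader call would silently ignore the argument. With the parameter threaded through, callers can now opt into background loading. A hypothetical call site; the filename and worker count are illustrative, and previous_chapters is assumed importable from the working directory:

from previous_chapters import create_dataloader_v1

if __name__ == "__main__":  # guard required when workers use the spawn start method
    with open("the-verdict.txt", "r", encoding="utf-8") as f:
        raw_text = f.read()

    train_loader = create_dataloader_v1(
        raw_text,
        batch_size=4,
        max_length=256,
        stride=128,
        shuffle=True,
        drop_last=True,
        num_workers=2  # safe now that the dataset no longer stores the tokenizer
    )

    inputs, targets = next(iter(train_loader))
    print(inputs.shape, targets.shape)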