Make datasets and loaders compatible with multiprocessing (#118)

This commit is contained in:
Sebastian Raschka
2024-04-13 14:57:56 -04:00
committed by GitHub
parent 8fe63a9a0e
commit bae4b0fb08
17 changed files with 140 additions and 116 deletions

View File

@@ -13,13 +13,12 @@ from torch.utils.data import Dataset, DataLoader
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.tokenizer = tokenizer
def __init__(self, txt, tokenizer, max_length, stride, num_workers=0):
self.input_ids = []
self.target_ids = []
# Tokenize the entire text
token_ids = self.tokenizer.encode(txt)
token_ids = tokenizer.encode(txt)
# Use a sliding window to chunk the book into overlapping sequences of max_length
for i in range(0, len(token_ids) - max_length, stride):
@@ -36,7 +35,7 @@ class GPTDatasetV1(Dataset):
def create_dataloader_v1(txt, batch_size=4, max_length=256,
stride=128, shuffle=True, drop_last=True):
stride=128, shuffle=True, drop_last=True, num_workers=0):
# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")

View File

@@ -11,7 +11,6 @@ from torch.utils.data import Dataset, DataLoader
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.tokenizer = tokenizer
self.input_ids = []
self.target_ids = []
@@ -33,7 +32,7 @@ class GPTDatasetV1(Dataset):
def create_dataloader_v1(txt, batch_size=4, max_length=256,
stride=128, shuffle=True, drop_last=True):
stride=128, shuffle=True, drop_last=True, num_workers=0):
# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
@@ -42,7 +41,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
return dataloader