fixed num_workers (#229)

* fixed num_workers

* ch06 & ch07: added num_workers to create_dataloader_v1
This commit is contained in:
Daniel Kleine
2024-06-20 00:36:46 +02:00
committed by GitHub
parent 24523bd34d
commit bbb2a0c3d5
15 changed files with 20 additions and 20 deletions

View File

@@ -13,7 +13,7 @@ from torch.utils.data import Dataset, DataLoader
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride, num_workers=0):
def __init__(self, txt, tokenizer, max_length, stride):
self.input_ids = []
self.target_ids = []
@@ -44,7 +44,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
return dataloader

View File

@@ -41,7 +41,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
return dataloader

View File

@@ -49,7 +49,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
return dataloader