drop_last=True

This commit is contained in:
rasbt
2024-02-25 07:23:38 -06:00
parent 6243726ab3
commit cdcd73ba7f
4 changed files with 21 additions and 14 deletions

View File

@@ -35,7 +35,8 @@ class GPTDatasetV1(Dataset):
return self.input_ids[idx], self.target_ids[idx]
-def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
+def create_dataloader_v1(txt, batch_size=4, max_length=256,
+stride=128, shuffle=True, drop_last=True):
# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
@@ -43,7 +44,8 @@ def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=Tru
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
# Create dataloader
-dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+dataloader = DataLoader(
+dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
return dataloader