Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2026-04-10 12:33:42 +00:00)
drop_last=True
This commit is contained in:
@@ -35,7 +35,8 @@ class GPTDatasetV1(Dataset):
         return self.input_ids[idx], self.target_ids[idx]
 
 
-def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
+def create_dataloader_v1(txt, batch_size=4, max_length=256,
+                         stride=128, shuffle=True, drop_last=True):
     # Initialize the tokenizer
     tokenizer = tiktoken.get_encoding("gpt2")
 
@@ -43,7 +44,8 @@ def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=Tru
     dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
 
     # Create dataloader
-    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+    dataloader = DataLoader(
+        dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
 
     return dataloader
 
Reference in New Issue
Block a user