Make datesets and loaders compatible with multiprocessing (#118)

This commit is contained in:
Sebastian Raschka
2024-04-13 14:57:56 -04:00
committed by GitHub
parent 8fe63a9a0e
commit bae4b0fb08
17 changed files with 140 additions and 116 deletions

View File

@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
import os
import torch
import urllib.request
import tiktoken
# Import from local files
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
@@ -69,7 +71,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, start_context):
eval_freq, eval_iter, start_context, tokenizer):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, track_tokens_seen = [], [], []
tokens_seen = 0
@@ -99,7 +101,7 @@ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_e
# Print a sample text after each epoch
generate_and_print_sample(
model, train_loader.dataset.tokenizer, device, start_context
model, tokenizer, device, start_context
)
return train_losses, val_losses, track_tokens_seen
@@ -169,7 +171,8 @@ def main(gpt_config, settings):
max_length=gpt_config["context_length"],
stride=gpt_config["context_length"],
drop_last=True,
shuffle=True
shuffle=True,
num_workers=0
)
val_loader = create_dataloader_v1(
@@ -178,17 +181,20 @@ def main(gpt_config, settings):
max_length=gpt_config["context_length"],
stride=gpt_config["context_length"],
drop_last=False,
shuffle=False
shuffle=False,
num_workers=0
)
##############################
# Train model
##############################
tokenizer = tiktoken.get_encoding("gpt2")
train_losses, val_losses, tokens_seen = train_model_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
start_context="Every effort moves you",
start_context="Every effort moves you", tokenizer=tokenizer
)
return train_losses, val_losses, tokens_seen, model