mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Make datesets and loaders compatible with multiprocessing (#118)
This commit is contained in:
committed by
GitHub
parent
8fe63a9a0e
commit
bae4b0fb08
@@ -7,6 +7,8 @@ import matplotlib.pyplot as plt
|
||||
import os
|
||||
import torch
|
||||
import urllib.request
|
||||
import tiktoken
|
||||
|
||||
|
||||
# Import from local files
|
||||
from previous_chapters import GPTModel, create_dataloader_v1, generate_text_simple
|
||||
@@ -69,7 +71,7 @@ def generate_and_print_sample(model, tokenizer, device, start_context):
|
||||
|
||||
|
||||
def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
|
||||
eval_freq, eval_iter, start_context):
|
||||
eval_freq, eval_iter, start_context, tokenizer):
|
||||
# Initialize lists to track losses and tokens seen
|
||||
train_losses, val_losses, track_tokens_seen = [], [], []
|
||||
tokens_seen = 0
|
||||
@@ -99,7 +101,7 @@ def train_model_simple(model, train_loader, val_loader, optimizer, device, num_e
|
||||
|
||||
# Print a sample text after each epoch
|
||||
generate_and_print_sample(
|
||||
model, train_loader.dataset.tokenizer, device, start_context
|
||||
model, tokenizer, device, start_context
|
||||
)
|
||||
|
||||
return train_losses, val_losses, track_tokens_seen
|
||||
@@ -169,7 +171,8 @@ def main(gpt_config, settings):
|
||||
max_length=gpt_config["context_length"],
|
||||
stride=gpt_config["context_length"],
|
||||
drop_last=True,
|
||||
shuffle=True
|
||||
shuffle=True,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
val_loader = create_dataloader_v1(
|
||||
@@ -178,17 +181,20 @@ def main(gpt_config, settings):
|
||||
max_length=gpt_config["context_length"],
|
||||
stride=gpt_config["context_length"],
|
||||
drop_last=False,
|
||||
shuffle=False
|
||||
shuffle=False,
|
||||
num_workers=0
|
||||
)
|
||||
|
||||
##############################
|
||||
# Train model
|
||||
##############################
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_losses, val_losses, tokens_seen = train_model_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
num_epochs=settings["num_epochs"], eval_freq=5, eval_iter=1,
|
||||
start_context="Every effort moves you",
|
||||
start_context="Every effort moves you", tokenizer=tokenizer
|
||||
)
|
||||
|
||||
return train_losses, val_losses, tokens_seen, model
|
||||
|
||||
Reference in New Issue
Block a user