mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Make datesets and loaders compatible with multiprocessing (#118)
This commit is contained in:
committed by
GitHub
parent
8fe63a9a0e
commit
bae4b0fb08
@@ -47,7 +47,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch version: 2.2.1\n"
|
||||
"torch version: 2.2.2\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -130,7 +130,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=True,\n",
|
||||
" shuffle=True\n",
|
||||
" shuffle=True,\n",
|
||||
" num_workers=0\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"val_loader = create_dataloader_v1(\n",
|
||||
@@ -139,7 +140,8 @@
|
||||
" max_length=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" stride=GPT_CONFIG_124M[\"context_length\"],\n",
|
||||
" drop_last=False,\n",
|
||||
" shuffle=False\n",
|
||||
" shuffle=False,\n",
|
||||
" num_workers=0\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -500,7 +502,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"def train_model(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
|
||||
" eval_freq, eval_iter, start_context, warmup_steps=10,\n",
|
||||
" eval_freq, eval_iter, start_context, tokenizer, warmup_steps=10,\n",
|
||||
" initial_lr=3e-05, min_lr=1e-6):\n",
|
||||
"\n",
|
||||
" train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []\n",
|
||||
@@ -562,8 +564,7 @@
|
||||
"\n",
|
||||
" # Generate and print a sample from the model to monitor progress\n",
|
||||
" generate_and_print_sample(\n",
|
||||
" model, train_loader.dataset.tokenizer,\n",
|
||||
" device, start_context\n",
|
||||
" model, tokenizer, device, start_context\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return train_losses, val_losses, track_tokens_seen, track_lrs"
|
||||
@@ -625,18 +626,21 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import tiktoken\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"model = GPTModel(GPT_CONFIG_124M)\n",
|
||||
"model.to(device)\n",
|
||||
"\n",
|
||||
"peak_lr = 5e-4\n",
|
||||
"optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)\n",
|
||||
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||
"\n",
|
||||
"n_epochs = 15\n",
|
||||
"train_losses, val_losses, tokens_seen, lrs = train_model(\n",
|
||||
" model, train_loader, val_loader, optimizer, device, n_epochs=n_epochs,\n",
|
||||
" eval_freq=5, eval_iter=1, start_context=\"Every effort moves you\",\n",
|
||||
" warmup_steps=10, initial_lr=1e-5, min_lr=1e-5\n",
|
||||
" tokenizer=tokenizer, warmup_steps=10, initial_lr=1e-5, min_lr=1e-5\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
@@ -705,7 +709,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/var/folders/jg/tpqyh1fd5js5wsr1d138k3n40000gn/T/ipykernel_34986/3589549395.py:5: UserWarning: The figure layout has changed to tight\n",
|
||||
"/var/folders/jg/tpqyh1fd5js5wsr1d138k3n40000gn/T/ipykernel_9436/3589549395.py:5: UserWarning: The figure layout has changed to tight\n",
|
||||
" plt.tight_layout(); plt.savefig(\"3.pdf\")\n"
|
||||
]
|
||||
},
|
||||
@@ -755,7 +759,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -20,12 +20,11 @@ import matplotlib.pyplot as plt
|
||||
|
||||
class GPTDatasetV1(Dataset):
|
||||
def __init__(self, txt, tokenizer, max_length, stride):
|
||||
self.tokenizer = tokenizer
|
||||
self.input_ids = []
|
||||
self.target_ids = []
|
||||
|
||||
# Tokenize the entire text
|
||||
token_ids = self.tokenizer.encode(txt)
|
||||
token_ids = tokenizer.encode(txt)
|
||||
|
||||
# Use a sliding window to chunk the book into overlapping sequences of max_length
|
||||
for i in range(0, len(token_ids) - max_length, stride):
|
||||
@@ -42,7 +41,7 @@ class GPTDatasetV1(Dataset):
|
||||
|
||||
|
||||
def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
stride=128, shuffle=True, drop_last=True):
|
||||
stride=128, shuffle=True, drop_last=True, num_workers=0):
|
||||
# Initialize the tokenizer
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
@@ -51,7 +50,7 @@ def create_dataloader_v1(txt, batch_size=4, max_length=256,
|
||||
|
||||
# Create dataloader
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
|
||||
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)
|
||||
|
||||
return dataloader
|
||||
|
||||
|
||||
Reference in New Issue
Block a user