Make datesets and loaders compatible with multiprocessing (#118)

This commit is contained in:
Sebastian Raschka
2024-04-13 14:57:56 -04:00
committed by GitHub
parent 8fe63a9a0e
commit bae4b0fb08
17 changed files with 140 additions and 116 deletions

View File

@@ -37,7 +37,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.2.1\n",
"torch version: 2.2.2\n",
"tiktoken version: 0.5.1\n"
]
}
@@ -724,7 +724,7 @@
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[16], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m SimpleTokenizerV1(vocab)\n\u001b[1;32m 3\u001b[0m text \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mHello, do you like tea. Is this-- a test?\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mtokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mencode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstr_to_int[s] \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36mSimpleTokenizerV1.encode\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43ms\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpreprocessed\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
"Cell \u001b[0;32mIn[12], line 9\u001b[0m, in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 7\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m re\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m([,.?_!\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m()\u001b[39m\u001b[38;5;130;01m\\'\u001b[39;00m\u001b[38;5;124m]|--|\u001b[39m\u001b[38;5;124m\\\u001b[39m\u001b[38;5;124ms)\u001b[39m\u001b[38;5;124m'\u001b[39m, text)\n\u001b[1;32m 8\u001b[0m preprocessed \u001b[38;5;241m=\u001b[39m [item\u001b[38;5;241m.\u001b[39mstrip() \u001b[38;5;28;01mfor\u001b[39;00m item \u001b[38;5;129;01min\u001b[39;00m preprocessed \u001b[38;5;28;01mif\u001b[39;00m item\u001b[38;5;241m.\u001b[39mstrip()]\n\u001b[0;32m----> 9\u001b[0m ids \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstr_to_int\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m s \u001b[38;5;129;01min\u001b[39;00m preprocessed]\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m ids\n",
"\u001b[0;31mKeyError\u001b[0m: 'Hello'"
]
@@ -957,7 +957,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 24,
"id": "ede1d41f-934b-4bf4-8184-54394a257a94",
"metadata": {},
"outputs": [],
@@ -967,7 +967,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 25,
"id": "48967a77-7d17-42bf-9e92-fc619d63a59e",
"metadata": {},
"outputs": [
@@ -988,7 +988,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 26,
"id": "6ad3312f-a5f7-4efc-9d7d-8ea09d7b5128",
"metadata": {},
"outputs": [],
@@ -998,7 +998,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 27,
"id": "5ff2cd85-7cfb-4325-b390-219938589428",
"metadata": {},
"outputs": [
@@ -1020,7 +1020,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 28,
"id": "d26a48bb-f82e-41a8-a955-a1c9cf9d50ab",
"metadata": {},
"outputs": [
@@ -1080,7 +1080,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 29,
"id": "848d5ade-fd1f-46c3-9e31-1426e315c71b",
"metadata": {},
"outputs": [
@@ -1111,7 +1111,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 30,
"id": "e84424a7-646d-45b6-99e3-80d15fb761f2",
"metadata": {},
"outputs": [],
@@ -1121,7 +1121,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 31,
"id": "dfbff852-a92f-48c8-a46d-143a0f109f40",
"metadata": {},
"outputs": [
@@ -1154,7 +1154,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 32,
"id": "d97b031e-ed55-409d-95f2-aeb38c6fe366",
"metadata": {},
"outputs": [
@@ -1179,7 +1179,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 33,
"id": "f57bd746-dcbf-4433-8e24-ee213a8c34a1",
"metadata": {},
"outputs": [
@@ -1221,7 +1221,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 34,
"id": "e1770134-e7f3-4725-a679-e04c3be48cac",
"metadata": {},
"outputs": [
@@ -1229,7 +1229,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.1.0\n"
"PyTorch version: 2.2.2\n"
]
}
],
@@ -1258,7 +1258,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 35,
"id": "74b41073-4c9f-46e2-a1bd-d38e4122b375",
"metadata": {},
"outputs": [],
@@ -1268,12 +1268,11 @@
"\n",
"class GPTDatasetV1(Dataset):\n",
" def __init__(self, txt, tokenizer, max_length, stride):\n",
" self.tokenizer = tokenizer\n",
" self.input_ids = []\n",
" self.target_ids = []\n",
"\n",
" # Tokenize the entire text\n",
" token_ids = self.tokenizer.encode(txt, allowed_special={'<|endoftext|>'})\n",
" token_ids = tokenizer.encode(txt, allowed_special={\"<|endoftext|>\"})\n",
"\n",
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
" for i in range(0, len(token_ids) - max_length, stride):\n",
@@ -1291,12 +1290,12 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 36,
"id": "5eb30ebe-97b3-43c5-9ff1-a97d621b3c4e",
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True):\n",
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, num_workers=0):\n",
"\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
@@ -1306,7 +1305,12 @@
"\n",
" # Create dataloader\n",
" dataloader = DataLoader(\n",
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
" dataset,\n",
" batch_size=batch_size,\n",
" shuffle=shuffle,\n",
" drop_last=drop_last,\n",
" num_workers=0\n",
" )\n",
"\n",
" return dataloader"
]
@@ -1321,7 +1325,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 37,
"id": "df31d96c-6bfd-4564-a956-6192242d7579",
"metadata": {},
"outputs": [],
@@ -1332,7 +1336,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 38,
"id": "9226d00c-ad9a-4949-a6e4-9afccfc7214f",
"metadata": {},
"outputs": [
@@ -1354,7 +1358,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 39,
"id": "10deb4bc-4de1-4d20-921e-4b1c7a0e1a6d",
"metadata": {},
"outputs": [
@@ -1398,7 +1402,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 40,
"id": "1916e7a6-f03d-4f09-91a6-d0bdbac5a58c",
"metadata": {},
"outputs": [
@@ -1473,7 +1477,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 41,
"id": "15a6304c-9474-4470-b85d-3991a49fa653",
"metadata": {},
"outputs": [],
@@ -1491,7 +1495,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 42,
"id": "93cb2cee-9aa6-4bb8-8977-c65661d16eda",
"metadata": {},
"outputs": [],
@@ -1513,7 +1517,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 43,
"id": "a686eb61-e737-4351-8f1c-222913d47468",
"metadata": {},
"outputs": [
@@ -1554,7 +1558,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 44,
"id": "e43600ba-f287-4746-8ddf-d0f71a9023ca",
"metadata": {},
"outputs": [
@@ -1581,7 +1585,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 45,
"id": "50280ead-0363-44c8-8c35-bb885d92c8b7",
"metadata": {},
"outputs": [
@@ -1874,7 +1878,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
"metadata": {},
"outputs": [],
@@ -43,12 +43,11 @@
"\n",
"class GPTDatasetV1(Dataset):\n",
" def __init__(self, txt, tokenizer, max_length, stride):\n",
" self.tokenizer = tokenizer\n",
" self.input_ids = []\n",
" self.target_ids = []\n",
"\n",
" # Tokenize the entire text\n",
" token_ids = self.tokenizer.encode(txt, allowed_special={'<|endoftext|>'})\n",
" token_ids = tokenizer.encode(txt, allowed_special={\"<|endoftext|>\"})\n",
"\n",
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
" for i in range(0, len(token_ids) - max_length, stride):\n",
@@ -65,7 +64,7 @@
"\n",
"\n",
"def create_dataloader_v1(txt, batch_size=4, max_length=256, \n",
" stride=128, shuffle=True, drop_last=True):\n",
" stride=128, shuffle=True, drop_last=True, num_workers=0):\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
@@ -74,7 +73,7 @@
"\n",
" # Create dataloader\n",
" dataloader = DataLoader(\n",
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)\n",
" dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=0)\n",
"\n",
" return dataloader\n",
"\n",
@@ -99,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
"metadata": {},
"outputs": [],
@@ -117,7 +116,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
"metadata": {},
"outputs": [
@@ -150,7 +149,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.11.4"
}
},
"nbformat": 4,