rename create_dataloader to create_dataloader_v1

This commit is contained in:
rasbt
2024-01-24 07:02:05 -06:00
parent f6896d17ef
commit f27c9e6135
2 changed files with 6 additions and 6 deletions

View File

@@ -1159,7 +1159,7 @@
"metadata": {},
"outputs": [],
"source": [
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
@@ -1206,7 +1206,7 @@
}
],
"source": [
"dataloader = create_dataloader(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)\n",
"dataloader = create_dataloader_v1(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)\n",
"\n",
"data_iter = iter(dataloader)\n",
"first_batch = next(data_iter)\n",
@@ -1274,7 +1274,7 @@
}
],
"source": [
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=4, stride=5, shuffle=False)\n",
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=4, stride=5, shuffle=False)\n",
"\n",
"data_iter = iter(dataloader)\n",
"inputs, targets = next(data_iter)\n",
@@ -1484,7 +1484,7 @@
"outputs": [],
"source": [
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5, shuffle=False)\n",
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=5, shuffle=False)\n",
"data_iter = iter(dataloader)\n",
"inputs, targets = next(data_iter)"
]

View File

@@ -78,7 +78,7 @@
" return self.input_ids[idx], self.target_ids[idx]\n",
"\n",
"\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
@@ -107,7 +107,7 @@
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=5)"
]
},
{