mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
rename create_dataloader to create_dataloader_v1
This commit is contained in:
@@ -1159,7 +1159,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
|
||||
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
|
||||
" # Initialize the tokenizer\n",
|
||||
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||
"\n",
|
||||
@@ -1206,7 +1206,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataloader = create_dataloader(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)\n",
|
||||
"dataloader = create_dataloader_v1(raw_text, batch_size=1, max_length=4, stride=1, shuffle=False)\n",
|
||||
"\n",
|
||||
"data_iter = iter(dataloader)\n",
|
||||
"first_batch = next(data_iter)\n",
|
||||
@@ -1274,7 +1274,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=4, stride=5, shuffle=False)\n",
|
||||
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=4, stride=5, shuffle=False)\n",
|
||||
"\n",
|
||||
"data_iter = iter(dataloader)\n",
|
||||
"inputs, targets = next(data_iter)\n",
|
||||
@@ -1484,7 +1484,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"max_length = 4\n",
|
||||
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5, shuffle=False)\n",
|
||||
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=5, shuffle=False)\n",
|
||||
"data_iter = iter(dataloader)\n",
|
||||
"inputs, targets = next(data_iter)"
|
||||
]
|
||||
|
||||
@@ -78,7 +78,7 @@
|
||||
" return self.input_ids[idx], self.target_ids[idx]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
|
||||
"def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
|
||||
" # Initialize the tokenizer\n",
|
||||
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
|
||||
"\n",
|
||||
@@ -107,7 +107,7 @@
|
||||
"pos_embedding_layer = torch.nn.Embedding(block_size, output_dim)\n",
|
||||
"\n",
|
||||
"max_length = 4\n",
|
||||
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
|
||||
"dataloader = create_dataloader_v1(raw_text, batch_size=8, max_length=max_length, stride=5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user