add toggle for qkv_bias

This commit is contained in:
rasbt
2024-01-17 07:50:57 -06:00
parent 0074c98968
commit 92896d817c
4 changed files with 114 additions and 108 deletions

View File

@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "93804da5-372b-45ff-9ef4-8398ba1dd78e",
"metadata": {},
"outputs": [
@@ -28,7 +28,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.0.1\n",
"torch version: 2.1.0\n",
"tiktoken version: 0.5.1\n"
]
}
@@ -78,7 +78,7 @@
" return self.input_ids[idx], self.target_ids[idx]\n",
"\n",
"\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
@@ -86,11 +86,12 @@
" dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
"\n",
" # Create dataloader\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
" dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)\n",
"\n",
" return dataloader\n",
"\n",
"\n",
"\n",
"with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" raw_text = f.read()\n",
"\n",
@@ -144,14 +145,6 @@
"source": [
"print(input_embeddings.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2773c09d-c136-4372-a2be-04b58d292842",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -170,7 +163,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.12"
}
},
"nbformat": 4,