Auto download DPO dataset if not already available in path (#479)

* Auto download DPO dataset if not already available in path

* update tests to account for latest HF transformers release in unit tests

* pep 8
This commit is contained in:
Sebastian Raschka
2025-01-12 12:27:28 -06:00
committed by GitHub
parent a48f9c7fe2
commit 4bfbcd069d
3 changed files with 66 additions and 89 deletions

View File

@@ -230,13 +230,34 @@
],
"source": [
"import json\n",
"import os\n",
"import urllib\n",
"\n",
"\n",
"def download_and_load_file(file_path, url):\n",
"\n",
" if not os.path.exists(file_path):\n",
" with urllib.request.urlopen(url) as response:\n",
" text_data = response.read().decode(\"utf-8\")\n",
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
" file.write(text_data)\n",
" else:\n",
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" text_data = file.read()\n",
"\n",
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" data = json.load(file)\n",
"\n",
" return data\n",
"\n",
"\n",
"file_path = \"instruction-data-with-preference.json\"\n",
"url = (\n",
" \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch\"\n",
" \"/main/ch07/04_preference-tuning-with-dpo/instruction-data-with-preference.json\"\n",
")\n",
"\n",
"with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
" data = json.load(file)\n",
"\n",
"data = download_and_load_file(file_path, url)\n",
"print(\"Number of entries:\", len(data))"
]
},
@@ -1546,7 +1567,6 @@
},
"outputs": [],
"source": [
"import os\n",
"from pathlib import Path\n",
"import shutil\n",
"\n",