mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2026-04-10 12:33:42 +00:00
Switch from urllib to requests to improve reliability (#867)
* Switch from urllib to requests to improve reliability * Keep ruff linter-specific * update * update * update
This commit is contained in:
committed by
GitHub
parent
8552565bda
commit
7bd263144e
@@ -169,10 +169,33 @@
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import urllib\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def download_and_load_file(file_path, url):\n",
|
||||
" if not os.path.exists(file_path):\n",
|
||||
" response = requests.get(url, timeout=30)\n",
|
||||
" response.raise_for_status()\n",
|
||||
" text_data = response.text\n",
|
||||
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
|
||||
" file.write(text_data)\n",
|
||||
"\n",
|
||||
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||
" data = json.load(file)\n",
|
||||
"\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# The book originally used the code below.\n",
|
||||
"# However, urllib uses older protocol settings that\n",
|
||||
"# can cause problems for some readers using a VPN.\n",
|
||||
"# The `requests` version above is more robust\n",
|
||||
"# in that regard.\n",
|
||||
"\n",
|
||||
"\"\"\"\n",
|
||||
"import urllib\n",
|
||||
"\n",
|
||||
"def download_and_load_file(file_path, url):\n",
|
||||
"\n",
|
||||
" if not os.path.exists(file_path):\n",
|
||||
" with urllib.request.urlopen(url) as response:\n",
|
||||
@@ -180,15 +203,15 @@
|
||||
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
|
||||
" file.write(text_data)\n",
|
||||
"\n",
|
||||
" # The book originally contained this unnecessary \"else\" clause:\n",
|
||||
" #else:\n",
|
||||
" # with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||
" # text_data = file.read()\n",
|
||||
" else:\n",
|
||||
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||
" text_data = file.read()\n",
|
||||
"\n",
|
||||
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||
" data = json.load(file)\n",
|
||||
"\n",
|
||||
" return data\n",
|
||||
"\"\"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"file_path = \"instruction-data.json\"\n",
|
||||
@@ -2490,7 +2513,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import urllib.request\n",
|
||||
"import requests # noqa: F811\n",
|
||||
"# import urllib.request\n",
|
||||
"\n",
|
||||
"def query_model(\n",
|
||||
" prompt,\n",
|
||||
@@ -2512,7 +2536,8 @@
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" \"\"\"\n",
|
||||
" # Convert the dictionary to a JSON formatted string and encode it to bytes\n",
|
||||
" payload = json.dumps(data).encode(\"utf-8\")\n",
|
||||
"\n",
|
||||
@@ -2536,6 +2561,26 @@
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
"\n",
|
||||
" return response_data\n",
|
||||
" \"\"\"\n",
|
||||
"\n",
|
||||
" # The book originally used the commented-out code above, which is based\n",
|
||||
" # on urllib. It works generally fine, but some readers reported\n",
|
||||
" # issues with using urllib when using a (company) VPN.\n",
|
||||
" # The code below uses the requests library, which doesn't seem\n",
|
||||
" # to have these issues.\n",
|
||||
"\n",
|
||||
" # Send the POST request\n",
|
||||
" with requests.post(url, json=data, stream=True, timeout=30) as r:\n",
|
||||
" r.raise_for_status()\n",
|
||||
" response_data = \"\"\n",
|
||||
" for line in r.iter_lines(decode_unicode=True):\n",
|
||||
" if not line:\n",
|
||||
" continue\n",
|
||||
" response_json = json.loads(line)\n",
|
||||
" if \"message\" in response_json:\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
"\n",
|
||||
" return response_data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"model = \"llama3\"\n",
|
||||
|
||||
@@ -12,10 +12,10 @@ import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import urllib
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
import requests
|
||||
import tiktoken
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@@ -234,17 +234,17 @@ def custom_collate_with_masking_fn(
|
||||
|
||||
|
||||
def download_and_load_file(file_path, url):
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
with urllib.request.urlopen(url) as response:
|
||||
text_data = response.read().decode("utf-8")
|
||||
response = requests.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
text_data = response.text
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(text_data)
|
||||
else:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
text_data = file.read()
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
|
||||
return data
|
||||
|
||||
@@ -5,11 +5,10 @@
|
||||
|
||||
|
||||
import os
|
||||
import urllib.request
|
||||
|
||||
# import requests
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf
|
||||
from tqdm import tqdm
|
||||
|
||||
@@ -48,44 +47,40 @@ def download_and_load_gpt2(model_size, models_dir):
|
||||
|
||||
def download_file(url, destination, backup_url=None):
|
||||
def _attempt_download(download_url):
|
||||
with urllib.request.urlopen(download_url) as response:
|
||||
# Get the total file size from headers, defaulting to 0 if not present
|
||||
file_size = int(response.headers.get("Content-Length", 0))
|
||||
response = requests.get(download_url, stream=True, timeout=60)
|
||||
response.raise_for_status()
|
||||
|
||||
# Check if file exists and has the same size
|
||||
if os.path.exists(destination):
|
||||
file_size_local = os.path.getsize(destination)
|
||||
if file_size == file_size_local:
|
||||
print(f"File already exists and is up-to-date: {destination}")
|
||||
return True # Indicate success without re-downloading
|
||||
file_size = int(response.headers.get("Content-Length", 0))
|
||||
|
||||
block_size = 1024 # 1 Kilobyte
|
||||
# Check if file exists and has same size
|
||||
if os.path.exists(destination):
|
||||
file_size_local = os.path.getsize(destination)
|
||||
if file_size and file_size == file_size_local:
|
||||
print(f"File already exists and is up-to-date: {destination}")
|
||||
return True
|
||||
|
||||
# Initialize the progress bar with total file size
|
||||
progress_bar_description = os.path.basename(download_url)
|
||||
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
|
||||
with open(destination, "wb") as file:
|
||||
while True:
|
||||
chunk = response.read(block_size)
|
||||
if not chunk:
|
||||
break
|
||||
block_size = 1024 # 1 KB
|
||||
desc = os.path.basename(download_url)
|
||||
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=desc) as progress_bar:
|
||||
with open(destination, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=block_size):
|
||||
if chunk:
|
||||
file.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
return True
|
||||
return True
|
||||
|
||||
try:
|
||||
if _attempt_download(url):
|
||||
return
|
||||
except (urllib.error.HTTPError, urllib.error.URLError):
|
||||
except requests.exceptions.RequestException:
|
||||
if backup_url is not None:
|
||||
print(f"Primary URL ({url}) failed. Attempting backup URL: {backup_url}")
|
||||
try:
|
||||
if _attempt_download(backup_url):
|
||||
return
|
||||
except urllib.error.HTTPError:
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
# If we reach here, both attempts have failed
|
||||
error_message = (
|
||||
f"Failed to download from both primary URL ({url})"
|
||||
f"{' and backup URL (' + backup_url + ')' if backup_url else ''}."
|
||||
@@ -97,37 +92,6 @@ def download_file(url, destination, backup_url=None):
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
|
||||
# Alternative way using `requests`
|
||||
"""
|
||||
def download_file(url, destination):
|
||||
# Send a GET request to download the file in streaming mode
|
||||
response = requests.get(url, stream=True)
|
||||
|
||||
# Get the total file size from headers, defaulting to 0 if not present
|
||||
file_size = int(response.headers.get("content-length", 0))
|
||||
|
||||
# Check if file exists and has the same size
|
||||
if os.path.exists(destination):
|
||||
file_size_local = os.path.getsize(destination)
|
||||
if file_size == file_size_local:
|
||||
print(f"File already exists and is up-to-date: {destination}")
|
||||
return
|
||||
|
||||
# Define the block size for reading the file
|
||||
block_size = 1024 # 1 Kilobyte
|
||||
|
||||
# Initialize the progress bar with total file size
|
||||
progress_bar_description = url.split("/")[-1] # Extract filename from URL
|
||||
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
|
||||
# Open the destination file in binary write mode
|
||||
with open(destination, "wb") as file:
|
||||
# Iterate over the file data in chunks
|
||||
for chunk in response.iter_content(block_size):
|
||||
progress_bar.update(len(chunk)) # Update progress bar
|
||||
file.write(chunk) # Write the chunk to the file
|
||||
"""
|
||||
|
||||
|
||||
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
|
||||
# Initialize parameters dictionary with empty blocks for each layer
|
||||
params = {"blocks": [{} for _ in range(settings["n_layer"])]}
|
||||
|
||||
@@ -11,9 +11,9 @@ import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import urllib
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import requests
|
||||
import tiktoken
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
@@ -97,14 +97,14 @@ def custom_collate_fn(
|
||||
|
||||
|
||||
def download_and_load_file(file_path, url):
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
with urllib.request.urlopen(url) as response:
|
||||
text_data = response.read().decode("utf-8")
|
||||
response = requests.get(url, timeout=30)
|
||||
response.raise_for_status()
|
||||
text_data = response.text
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(text_data)
|
||||
|
||||
with open(file_path, "r") as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
data = json.load(file)
|
||||
|
||||
return data
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
import json
|
||||
import psutil
|
||||
from tqdm import tqdm
|
||||
import urllib.request
|
||||
import requests
|
||||
|
||||
|
||||
def query_model(prompt, model="llama3", url="http://localhost:11434/api/chat"):
|
||||
@@ -25,23 +25,16 @@ def query_model(prompt, model="llama3", url="http://localhost:11434/api/chat"):
|
||||
}
|
||||
}
|
||||
|
||||
# Convert the dictionary to a JSON formatted string and encode it to bytes
|
||||
payload = json.dumps(data).encode("utf-8")
|
||||
|
||||
# Create a request object, setting the method to POST and adding necessary headers
|
||||
request = urllib.request.Request(url, data=payload, method="POST")
|
||||
request.add_header("Content-Type", "application/json")
|
||||
|
||||
# Send the request and capture the response
|
||||
response_data = ""
|
||||
with urllib.request.urlopen(request) as response:
|
||||
# Read and decode the response
|
||||
while True:
|
||||
line = response.readline().decode("utf-8")
|
||||
# Send the POST request
|
||||
with requests.post(url, json=data, stream=True, timeout=30) as r:
|
||||
r.raise_for_status()
|
||||
response_data = ""
|
||||
for line in r.iter_lines(decode_unicode=True):
|
||||
if not line:
|
||||
break
|
||||
continue
|
||||
response_json = json.loads(line)
|
||||
response_data += response_json["message"]["content"]
|
||||
if "message" in response_json:
|
||||
response_data += response_json["message"]["content"]
|
||||
|
||||
return response_data
|
||||
|
||||
|
||||
@@ -215,8 +215,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import urllib.request\n",
|
||||
"import json\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def query_model(prompt, model=\"llama3\", url=\"http://localhost:11434/api/chat\"):\n",
|
||||
@@ -236,27 +236,19 @@
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" # Convert the dictionary to a JSON formatted string and encode it to bytes\n",
|
||||
" payload = json.dumps(data).encode(\"utf-8\")\n",
|
||||
"\n",
|
||||
" # Create a request object, setting the method to POST and adding necessary headers\n",
|
||||
" request = urllib.request.Request(url, data=payload, method=\"POST\")\n",
|
||||
" request.add_header(\"Content-Type\", \"application/json\")\n",
|
||||
"\n",
|
||||
" # Send the request and capture the response\n",
|
||||
" response_data = \"\"\n",
|
||||
" with urllib.request.urlopen(request) as response:\n",
|
||||
" # Read and decode the response\n",
|
||||
" while True:\n",
|
||||
" line = response.readline().decode(\"utf-8\")\n",
|
||||
" # Send the POST request\n",
|
||||
" with requests.post(url, json=data, stream=True, timeout=30) as r:\n",
|
||||
" r.raise_for_status()\n",
|
||||
" response_data = \"\"\n",
|
||||
" for line in r.iter_lines(decode_unicode=True):\n",
|
||||
" if not line:\n",
|
||||
" break\n",
|
||||
" continue\n",
|
||||
" response_json = json.loads(line)\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
" if \"message\" in response_json:\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
"\n",
|
||||
" return response_data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"result = query_model(\"What do Llamas eat?\")\n",
|
||||
"print(result)"
|
||||
]
|
||||
@@ -640,7 +632,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -274,8 +274,8 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import urllib.request\n",
|
||||
"import json\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def query_model(prompt, model=\"llama3.1:70b\", url=\"http://localhost:11434/api/chat\"):\n",
|
||||
@@ -294,23 +294,16 @@
|
||||
" }\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" # Convert the dictionary to a JSON formatted string and encode it to bytes\n",
|
||||
" payload = json.dumps(data).encode(\"utf-8\")\n",
|
||||
"\n",
|
||||
" # Create a request object, setting the method to POST and adding necessary headers\n",
|
||||
" request = urllib.request.Request(url, data=payload, method=\"POST\")\n",
|
||||
" request.add_header(\"Content-Type\", \"application/json\")\n",
|
||||
"\n",
|
||||
" # Send the request and capture the response\n",
|
||||
" response_data = \"\"\n",
|
||||
" with urllib.request.urlopen(request) as response:\n",
|
||||
" # Read and decode the response\n",
|
||||
" while True:\n",
|
||||
" line = response.readline().decode(\"utf-8\")\n",
|
||||
" # Send the POST request\n",
|
||||
" with requests.post(url, json=data, stream=True, timeout=30) as r:\n",
|
||||
" r.raise_for_status()\n",
|
||||
" response_data = \"\"\n",
|
||||
" for line in r.iter_lines(decode_unicode=True):\n",
|
||||
" if not line:\n",
|
||||
" break\n",
|
||||
" continue\n",
|
||||
" response_json = json.loads(line)\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
" if \"message\" in response_json:\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
"\n",
|
||||
" return response_data\n",
|
||||
"\n",
|
||||
@@ -587,7 +580,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -231,23 +231,21 @@
|
||||
"source": [
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import urllib\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def download_and_load_file(file_path, url):\n",
|
||||
"\n",
|
||||
" if not os.path.exists(file_path):\n",
|
||||
" with urllib.request.urlopen(url) as response:\n",
|
||||
" text_data = response.read().decode(\"utf-8\")\n",
|
||||
" response = requests.get(url, timeout=30)\n",
|
||||
" response.raise_for_status()\n",
|
||||
" text_data = response.text\n",
|
||||
" with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
|
||||
" file.write(text_data)\n",
|
||||
" else:\n",
|
||||
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||
" text_data = file.read()\n",
|
||||
"\n",
|
||||
" with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
|
||||
" data = json.load(file)\n",
|
||||
"\n",
|
||||
" data = json.loads(text_data)\n",
|
||||
" return data\n",
|
||||
"\n",
|
||||
"\n",
|
||||
|
||||
@@ -194,8 +194,8 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import urllib.request\n",
|
||||
"import json\n",
|
||||
"import requests\n",
|
||||
"\n",
|
||||
"def query_model(prompt, model=\"llama3\", url=\"http://localhost:11434/api/chat\", role=\"user\"):\n",
|
||||
" # Create the data payload as a dictionary\n",
|
||||
@@ -209,25 +209,21 @@
|
||||
" ]\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" # Convert the dictionary to a JSON formatted string and encode it to bytes\n",
|
||||
" payload = json.dumps(data).encode(\"utf-8\")\n",
|
||||
"\n",
|
||||
" # Create a request object, setting the method to POST and adding necessary headers\n",
|
||||
" request = urllib.request.Request(url, data=payload, method=\"POST\")\n",
|
||||
" request.add_header(\"Content-Type\", \"application/json\")\n",
|
||||
"\n",
|
||||
" # Send the request and capture the response\n",
|
||||
" response_data = \"\"\n",
|
||||
" with urllib.request.urlopen(request) as response:\n",
|
||||
" # Read and decode the response\n",
|
||||
" while True:\n",
|
||||
" line = response.readline().decode(\"utf-8\")\n",
|
||||
" # Send the POST request\n",
|
||||
" with requests.post(url, json=data, stream=True, timeout=30) as r:\n",
|
||||
" r.raise_for_status()\n",
|
||||
" response_data = \"\"\n",
|
||||
" for line in r.iter_lines(decode_unicode=True):\n",
|
||||
" if not line:\n",
|
||||
" break\n",
|
||||
" continue\n",
|
||||
" response_json = json.loads(line)\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
" if \"message\" in response_json:\n",
|
||||
" response_data += response_json[\"message\"][\"content\"]\n",
|
||||
"\n",
|
||||
" return response_data"
|
||||
" return response_data\n",
|
||||
"\n",
|
||||
"result = query_model(\"What do Llamas eat?\")\n",
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -498,7 +494,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.16"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Reference in New Issue
Block a user