def download_file(url, out_dir=".", *, timeout=30):
    """Download ``url`` into ``out_dir`` and return the local path.

    The local file name is the last component of the URL path. If a file
    with that name already exists in ``out_dir`` it is returned as-is and
    nothing is fetched.

    Args:
        url: Address to fetch; anything ``urllib.request.urlopen`` accepts.
        out_dir: Target directory, created (with parents) if missing.
        timeout: Seconds passed to ``urlopen`` for blocking socket
            operations; ``None`` falls back to the global socket default.

    Returns:
        ``pathlib.Path`` of the downloaded (or already present) file.

    Raises:
        ValueError: If no usable file name can be derived from ``url``.
        RuntimeError: If fetching or writing fails; the underlying error is
            attached as ``__cause__``.
    """
    # Function-scope imports keep the helper usable on its own.
    import shutil
    import urllib.parse
    import urllib.request
    from pathlib import Path

    filename = Path(urllib.parse.urlparse(url).path).name
    # An empty name (or "..") would make ``dest`` resolve to an existing
    # directory, which the cache check below would wrongly return.
    if filename in ("", ".."):
        raise ValueError(f"Cannot derive a file name from URL: {url!r}")

    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    dest = out_dir / filename

    if dest.exists():
        return dest

    # Write to a sibling ".part" file and rename only on success, so an
    # interrupted transfer never leaves a truncated file that the
    # ``dest.exists()`` shortcut would later treat as a finished download.
    partial = dest.with_name(f"{dest.name}.part")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            with open(partial, "wb") as f:
                # Stream in chunks instead of buffering the whole payload.
                shutil.copyfileobj(response, f)
        partial.replace(dest)
    except Exception as e:
        # Deliberately broad: every failure is reported as RuntimeError.
        try:
            partial.unlink()
        except OSError:
            pass  # Nothing was written, or cleanup itself failed.
        raise RuntimeError(f"Failed to download {url}: {e}") from e
    return dest