diff --git a/src/boltz/main.py b/src/boltz/main.py index 4a3750fec..aef771d85 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -12,6 +12,7 @@ from typing import Literal, Optional import click +import requests import torch from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.strategies import DDPStrategy @@ -194,6 +195,27 @@ def download_boltz1(cache: Path) -> None: continue +def download_file_with_progress(url: str, dest_path: str) -> None: + """Download a file from url to dest_path with proper redirect handling. + + Uses requests instead of urllib to properly handle HuggingFace redirects. + + Parameters + ---------- + url : str + The URL to download from. + dest_path : str + The destination file path. + """ + response = requests.get(url, stream=True, allow_redirects=True, timeout=300) + response.raise_for_status() + + with open(dest_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + @rank_zero_only def download_boltz2(cache: Path) -> None: """Download all the required data. @@ -213,7 +235,7 @@ def download_boltz2(cache: Path) -> None: "This may take a bit of time. You may change the cache directory " "with the --cache flag." ) - urllib.request.urlretrieve(MOL_URL, str(tar_mols)) # noqa: S310 + download_file_with_progress(MOL_URL, str(tar_mols)) if not mols.exists(): click.echo( f"Extracting the CCD data to {mols}. " @@ -232,7 +254,7 @@ def download_boltz2(cache: Path) -> None: ) for i, url in enumerate(BOLTZ2_URL_WITH_FALLBACK): try: - urllib.request.urlretrieve(url, str(model)) # noqa: S310 + download_file_with_progress(url, str(model)) break except Exception as e: # noqa: BLE001 if i == len(BOLTZ2_URL_WITH_FALLBACK) - 1: @@ -249,7 +271,7 @@ def download_boltz2(cache: Path) -> None: ) for i, url in enumerate(BOLTZ2_AFFINITY_URL_WITH_FALLBACK): try: - urllib.request.urlretrieve(url, str(affinity_model)) # noqa: S310 + download_file_with_progress(url, str(affinity_model)) break except Exception as e: # noqa: BLE001 if i == len(BOLTZ2_AFFINITY_URL_WITH_FALLBACK) - 1: