Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Literal, Optional

import click
import requests
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.strategies import DDPStrategy
Expand Down Expand Up @@ -194,6 +195,27 @@ def download_boltz1(cache: Path) -> None:
continue


def download_file_with_progress(url: str, dest_path: str) -> None:
"""Download a file from url to dest_path with proper redirect handling.

Uses requests instead of urllib to properly handle HuggingFace redirects.

Parameters
----------
url : str
The URL to download from.
dest_path : str
The destination file path.
"""
response = requests.get(url, stream=True, allow_redirects=True, timeout=300)
response.raise_for_status()

with open(dest_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)


@rank_zero_only
def download_boltz2(cache: Path) -> None:
"""Download all the required data.
Expand All @@ -213,7 +235,7 @@ def download_boltz2(cache: Path) -> None:
"This may take a bit of time. You may change the cache directory "
"with the --cache flag."
)
urllib.request.urlretrieve(MOL_URL, str(tar_mols)) # noqa: S310
download_file_with_progress(MOL_URL, str(tar_mols))
if not mols.exists():
click.echo(
f"Extracting the CCD data to {mols}. "
Expand All @@ -232,7 +254,7 @@ def download_boltz2(cache: Path) -> None:
)
for i, url in enumerate(BOLTZ2_URL_WITH_FALLBACK):
try:
urllib.request.urlretrieve(url, str(model)) # noqa: S310
download_file_with_progress(url, str(model))
break
except Exception as e: # noqa: BLE001
if i == len(BOLTZ2_URL_WITH_FALLBACK) - 1:
Expand All @@ -249,7 +271,7 @@ def download_boltz2(cache: Path) -> None:
)
for i, url in enumerate(BOLTZ2_AFFINITY_URL_WITH_FALLBACK):
try:
urllib.request.urlretrieve(url, str(affinity_model)) # noqa: S310
download_file_with_progress(url, str(affinity_model))
break
except Exception as e: # noqa: BLE001
if i == len(BOLTZ2_AFFINITY_URL_WITH_FALLBACK) - 1:
Expand Down