Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/boltz/data/module/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(
target_dir: Path,
msa_dir: Path,
num_workers: int,
inference_batch_size: int,
constraints_dir: Optional[Path] = None,
) -> None:
"""Initialize the DataModule.
Expand All @@ -248,6 +249,8 @@ def __init__(
self.target_dir = target_dir
self.msa_dir = msa_dir
self.constraints_dir = constraints_dir
self.inference_batch_size = inference_batch_size


def predict_dataloader(self) -> DataLoader:
"""Get the training dataloader.
Expand All @@ -266,7 +269,7 @@ def predict_dataloader(self) -> DataLoader:
)
return DataLoader(
dataset,
batch_size=1,
batch_size=self.inference_batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=False,
Expand Down
5 changes: 4 additions & 1 deletion src/boltz/data/module/inferencev2.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def __init__(
msa_dir: Path,
mol_dir: Path,
num_workers: int,
inference_batch_size: int,
constraints_dir: Optional[Path] = None,
template_dir: Optional[Path] = None,
extra_mols_dir: Optional[Path] = None,
Expand Down Expand Up @@ -365,6 +366,7 @@ def __init__(
self.extra_mols_dir = extra_mols_dir
self.override_method = override_method
self.affinity = affinity
self.inference_batch_size = inference_batch_size

def predict_dataloader(self) -> DataLoader:
"""Get the training dataloader.
Expand All @@ -386,9 +388,10 @@ def predict_dataloader(self) -> DataLoader:
override_method=self.override_method,
affinity=self.affinity,
)
print(f'In Boltz2Model {self.inference_batch_size = }')
return DataLoader(
dataset,
batch_size=1,
batch_size=self.inference_batch_size,
num_workers=self.num_workers,
pin_memory=True,
shuffle=False,
Expand Down
86 changes: 82 additions & 4 deletions src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tarfile
import urllib.request
import warnings
import numpy as np
from dataclasses import asdict, dataclass
from functools import partial
from multiprocessing import Pool
Expand Down Expand Up @@ -522,7 +523,7 @@ def compute_msa(
f.write("\n".join(csv_str))


def process_input( # noqa: C901, PLR0912, PLR0915, D103
def process_single_input( # noqa: C901, PLR0912, PLR0915, D103
path: Path,
ccd: dict,
msa_dir: Path,
Expand Down Expand Up @@ -662,7 +663,7 @@ def process_input( # noqa: C901, PLR0912, PLR0915, D103


@rank_zero_only
def process_inputs(
def process_multiple_inputs(
data: list[Path],
out_dir: Path,
ccd_path: Path,
Expand Down Expand Up @@ -770,7 +771,7 @@ def process_inputs(

# Create partial function
process_input_partial = partial(
process_input,
process_single_input,
ccd=ccd,
msa_dir=msa_dir,
mol_dir=mol_dir,
Expand Down Expand Up @@ -807,6 +808,33 @@ def process_inputs(
manifest = Manifest(records)
manifest.dump(out_dir / "processed" / "manifest.json")

def _chain_sequence_residues(record: Record) -> int:
    """Sum ``len(chain.sequence)`` over every chain of *record*.

    Chains that have no ``sequence`` attribute, or an empty/``None`` one,
    contribute zero.
    """
    # NOTE(review): this relies on chain objects exposing `sequence`; if they
    # do not, the fallback is always 0 and the caller clamps it to 1 — confirm
    # against the Record/chain definitions.
    return sum(
        len(getattr(chain, "sequence", None) or "") for chain in record.chains
    )


def _estimate_record_cost(record: Record, processed_dir: Path) -> float:
    """Return a rough, relative compute cost for one record.

    The estimate is ``L**2 * chain_factor`` where

    * ``L`` is the length of the ``residues`` array stored in
      ``<processed_dir>/structures/<record.id>.npz``. If that file cannot be
      read, or has no ``residues`` entry, ``L`` falls back to the summed
      ``sequence`` lengths of ``record.chains``. ``L`` is clamped to at
      least 1.
    * ``chain_factor`` is ``1 + 0.05 * (num_chains - 1)`` with ``num_chains``
      clamped to at least 1.

    Parameters
    ----------
    record : Record
        Record whose ``id`` locates the structure file and whose ``chains``
        are counted.
    processed_dir : Path
        Directory containing the ``structures`` sub-directory.

    Returns
    -------
    float
        The estimated cost; always ``>= 1.0``.

    """
    structure_path = processed_dir / "structures" / f"{record.id}.npz"

    total_residues = None
    try:
        # `np.load` on an .npz returns a lazily-read archive holding an open
        # file handle; the context manager guarantees it is released.
        with np.load(structure_path) as archive:
            if "residues" in archive:
                total_residues = len(archive["residues"])
    except Exception:  # noqa: BLE001
        # Deliberately best-effort: this is only a batching heuristic, so a
        # missing, corrupt or unexpected file must never abort prediction.
        total_residues = None

    if total_residues is None:
        total_residues = _chain_sequence_residues(record)

    # Guard against empty structures so the cost never collapses to zero.
    total_residues = max(total_residues, 1)
    num_chains = max(1, len(record.chains))
    chain_factor = 1.0 + 0.05 * (num_chains - 1)
    return (total_residues * total_residues) * chain_factor

@click.group()
def cli() -> None:
Expand Down Expand Up @@ -1039,6 +1067,24 @@ def cli() -> None:
is_flag=True,
help=" to dump the s and z embeddings into a npz file. Default is False.",
)
@click.option(
"--inference_batch_size",
is_flag=True,
type=int,
default=1,
help="Set custom batch size in DataLoaders downstream.",
)
@click.option(
"--auto_batch",
is_flag=True,
help="Automatically choose inference batch size based on approximate sequence costs.",
)
@click.option(
"--batch_cost_ceiling",
type=float,
default=15000000.0,
help="Approximate total cost ceiling per batch when --auto_batch is used (cost ~ L^2).",
)
def predict( # noqa: C901, PLR0915, PLR0912
data: str,
out_dir: str,
Expand Down Expand Up @@ -1077,7 +1123,12 @@ def predict( # noqa: C901, PLR0915, PLR0912
num_subsampled_msa: int = 1024,
no_kernels: bool = False,
write_embeddings: bool = False,
inference_batch_size: int = 100,
auto_batch: bool = True,
batch_cost_ceiling: float = 15000000.0,
) -> None:
# (existing code above unchanged) ...

"""Run predictions with Boltz."""
# If cpu, write a friendly warning
if accelerator == "cpu":
Expand All @@ -1089,6 +1140,10 @@ def predict( # noqa: C901, PLR0915, PLR0912
"ignore", ".*that has Tensor Cores. To properly utilize them.*"
)

if inference_batch_size < 1:
click.echo("Received inference_batch_size < 1; resetting to 1.")
inference_batch_size = 10000

# Set no grad
torch.set_grad_enabled(False)

Expand Down Expand Up @@ -1159,7 +1214,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
# Process inputs
ccd_path = cache / "ccd.pkl"
mol_dir = cache / "mols"
process_inputs(
process_multiple_inputs(
data=data,
out_dir=out_dir,
ccd_path=ccd_path,
Expand All @@ -1186,6 +1241,26 @@ def predict( # noqa: C901, PLR0915, PLR0912
override=override,
)

if auto_batch and filtered_manifest.records:
processed_dir = out_dir / "processed"
costs = [_estimate_record_cost(r, processed_dir) for r in filtered_manifest.records]
avg_cost = float(np.mean(costs))
candidate = int(batch_cost_ceiling // max(avg_cost, 1.0))
print(f'{candidate = }')
# If user manually set inference_batch_size > 1, respect manual unless candidate smaller (OOM guard)
if inference_batch_size == 1 or candidate < inference_batch_size:
inference_batch_size = 100
print(f'{candidate = }')
print(f'{inference_batch_size = }')
click.echo(
f"[auto_batch] avg_cost={avg_cost:.1f}, cost_ceiling={batch_cost_ceiling:.1f} -> "
f"inference_batch_size={inference_batch_size}"
)
elif auto_batch and not filtered_manifest.records:
click.echo("[auto_batch] No records to batch; skipping.")

print(f'{inference_batch_size = }')
inference_batch_size = 100
# Load processed data
processed_dir = out_dir / "processed"
processed = BoltzProcessedInput(
Expand Down Expand Up @@ -1279,6 +1354,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
template_dir=processed.template_dir,
extra_mols_dir=processed.extra_mols_dir,
override_method=method,
inference_batch_size=inference_batch_size
)
else:
data_module = BoltzInferenceDataModule(
Expand All @@ -1287,6 +1363,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
msa_dir=processed.msa_dir,
num_workers=num_workers,
constraints_dir=processed.constraints_dir,
inference_batch_size=inference_batch_size
)

# Load model
Expand Down Expand Up @@ -1367,6 +1444,7 @@ def predict( # noqa: C901, PLR0915, PLR0912
extra_mols_dir=processed.extra_mols_dir,
override_method="other",
affinity=True,
inference_batch_size=inference_batch_size
)

predict_affinity_args = {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
C_Z = 128
BATCH_SIZE = 1
INFERENCE = False
SEQ_LEN = [128, 256, 384, 512, 768]
SEQ_LEN = [128, 256, 384, 512]
PRECISION = torch.bfloat16
COMPILE = False
device = "cuda:0"
Expand Down
6 changes: 3 additions & 3 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

from lightning_fabric import seed_everything

from boltz.main import MODEL_URL
from boltz.model.model import Boltz1
from boltz.main import BOLTZ1_URL_WITH_FALLBACK
from boltz.model.models.boltz1 import Boltz1

import test_utils

Expand All @@ -26,7 +26,7 @@ class RegressionTester(unittest.TestCase):
def setUpClass(cls):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache = os.path.expanduser("~/.boltz")
checkpoint_url = MODEL_URL
checkpoint_url = BOLTZ1_URL_WITH_FALLBACK[0]
model_name = checkpoint_url.split("/")[-1]
checkpoint = os.path.join(cache, model_name)
if not os.path.exists(checkpoint):
Expand Down