From 9ce7b05a59e8ad3b47c57a7ac40e9aa6addac270 Mon Sep 17 00:00:00 2001
From: EC2 Default User
Date: Tue, 21 Oct 2025 11:45:41 +0000
Subject: [PATCH 1/3] Add more fixes

---
 src/boltz/data/module/inference.py   |  5 ++++-
 src/boltz/data/module/inferencev2.py |  4 +++-
 src/boltz/main.py                    | 23 +++++++++++++++++++----
 tests/test_kernels.py                |  2 +-
 tests/test_regression.py             |  6 +++---
 5 files changed, 30 insertions(+), 10 deletions(-)
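
Note on the mechanics: the new --inference_batch_size value is threaded from predict()
into both inference data modules, whose prediction DataLoaders previously hard-coded
batch_size=1. A minimal, self-contained sketch of that flow; the dataset and the value 4
are hypothetical, only the DataLoader keyword arguments mirror this patch:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    inference_batch_size = 4  # e.g. passed on the CLI as --inference_batch_size 4
    dataset = TensorDataset(torch.zeros(8, 3))  # stand-in for the prediction dataset

    loader = DataLoader(
        dataset,
        batch_size=inference_batch_size,  # previously fixed at 1
        num_workers=0,
        pin_memory=True,
        shuffle=False,
    )
    assert len(loader) == 2  # 8 samples with batch size 4
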
diff --git a/src/boltz/data/module/inference.py b/src/boltz/data/module/inference.py
index b09a6afc8..5762ec42a 100644
--- a/src/boltz/data/module/inference.py
+++ b/src/boltz/data/module/inference.py
@@ -232,6 +232,7 @@ def __init__(
         target_dir: Path,
         msa_dir: Path,
         num_workers: int,
+        inference_batch_size: int,
         constraints_dir: Optional[Path] = None,
     ) -> None:
         """Initialize the DataModule.
@@ -248,6 +249,8 @@
         self.target_dir = target_dir
         self.msa_dir = msa_dir
         self.constraints_dir = constraints_dir
+        self.inference_batch_size = inference_batch_size
+

     def predict_dataloader(self) -> DataLoader:
         """Get the training dataloader.
@@ -266,7 +269,7 @@
         )
         return DataLoader(
             dataset,
-            batch_size=1,
+            batch_size=self.inference_batch_size,
             num_workers=self.num_workers,
             pin_memory=True,
             shuffle=False,
diff --git a/src/boltz/data/module/inferencev2.py b/src/boltz/data/module/inferencev2.py
index 590297d26..762e94534 100644
--- a/src/boltz/data/module/inferencev2.py
+++ b/src/boltz/data/module/inferencev2.py
@@ -324,6 +324,7 @@ def __init__(
         msa_dir: Path,
         mol_dir: Path,
         num_workers: int,
+        inference_batch_size: int,
         constraints_dir: Optional[Path] = None,
         template_dir: Optional[Path] = None,
         extra_mols_dir: Optional[Path] = None,
@@ -365,6 +366,7 @@
         self.extra_mols_dir = extra_mols_dir
         self.override_method = override_method
         self.affinity = affinity
+        self.inference_batch_size = inference_batch_size

     def predict_dataloader(self) -> DataLoader:
         """Get the training dataloader.
@@ -388,7 +390,7 @@
         )
         return DataLoader(
             dataset,
-            batch_size=1,
+            batch_size=self.inference_batch_size,
             num_workers=self.num_workers,
             pin_memory=True,
             shuffle=False,
diff --git a/src/boltz/main.py b/src/boltz/main.py
index 4a3750fec..4c4322e57 100644
--- a/src/boltz/main.py
+++ b/src/boltz/main.py
@@ -522,7 +522,7 @@ def compute_msa(
         f.write("\n".join(csv_str))


-def process_input(  # noqa: C901, PLR0912, PLR0915, D103
+def process_single_input(  # noqa: C901, PLR0912, PLR0915, D103
     path: Path,
     ccd: dict,
     msa_dir: Path,
@@ -662,7 +662,7 @@


 @rank_zero_only
-def process_inputs(
+def process_multiple_inputs(
     data: list[Path],
     out_dir: Path,
     ccd_path: Path,
@@ -770,7 +770,7 @@

     # Create partial function
     process_input_partial = partial(
-        process_input,
+        process_single_input,
         ccd=ccd,
         msa_dir=msa_dir,
         mol_dir=mol_dir,
@@ -1039,6 +1039,12 @@ def cli() -> None:
     is_flag=True,
     help="Whether to dump the s and z embeddings into a npz file. Default is False.",
 )
+@click.option(
+    "--inference_batch_size",
+    type=int,
+    default=1,
+    help="Set custom batch size in DataLoaders downstream.",
+)
 def predict(  # noqa: C901, PLR0915, PLR0912
     data: str,
     out_dir: str,
@@ -1077,6 +1084,7 @@
     num_subsampled_msa: int = 1024,
     no_kernels: bool = False,
     write_embeddings: bool = False,
+    inference_batch_size: int = 2,
 ) -> None:
     """Run predictions with Boltz."""
     # If cpu, write a friendly warning
@@ -1089,6 +1097,10 @@
         warnings.filterwarnings(
             "ignore", ".*that has Tensor Cores. To properly utilize them.*"
         )

+    if inference_batch_size < 1:
+        click.echo("Received inference_batch_size < 1; resetting to 1.")
+        inference_batch_size = 1
+
     # Set no grad
     torch.set_grad_enabled(False)
@@ -1159,7 +1171,7 @@
     # Process inputs
     ccd_path = cache / "ccd.pkl"
     mol_dir = cache / "mols"
-    process_inputs(
+    process_multiple_inputs(
         data=data,
         out_dir=out_dir,
         ccd_path=ccd_path,
@@ -1279,6 +1291,7 @@
             template_dir=processed.template_dir,
             extra_mols_dir=processed.extra_mols_dir,
             override_method=method,
+            inference_batch_size=inference_batch_size
         )
     else:
         data_module = BoltzInferenceDataModule(
@@ -1287,6 +1300,7 @@
             msa_dir=processed.msa_dir,
             num_workers=num_workers,
             constraints_dir=processed.constraints_dir,
+            inference_batch_size=inference_batch_size
         )

     # Load model
@@ -1367,6 +1381,7 @@
             extra_mols_dir=processed.extra_mols_dir,
             override_method="other",
             affinity=True,
+            inference_batch_size=inference_batch_size
         )

         predict_affinity_args = {
diff --git a/tests/test_kernels.py b/tests/test_kernels.py
index b43c0a990..13568da96 100644
--- a/tests/test_kernels.py
+++ b/tests/test_kernels.py
@@ -17,7 +17,7 @@
 C_Z = 128
 BATCH_SIZE = 1
 INFERENCE = False
-SEQ_LEN = [128, 256, 384, 512, 768]
+SEQ_LEN = [128, 256, 384, 512]
 PRECISION = torch.bfloat16
 COMPILE = False
 device = "cuda:0"
diff --git a/tests/test_regression.py b/tests/test_regression.py
index 5478a47ff..94da9c98a 100644
--- a/tests/test_regression.py
+++ b/tests/test_regression.py
@@ -11,8 +11,8 @@
 from lightning_fabric import seed_everything

-from boltz.main import MODEL_URL
-from boltz.model.model import Boltz1
+from boltz.main import BOLTZ1_URL_WITH_FALLBACK
+from boltz.model.models.boltz1 import Boltz1

 import test_utils

@@ -26,7 +26,7 @@ class RegressionTester(unittest.TestCase):
     def setUpClass(cls):
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         cache = os.path.expanduser("~/.boltz")
-        checkpoint_url = MODEL_URL
+        checkpoint_url = BOLTZ1_URL_WITH_FALLBACK[0]
         model_name = checkpoint_url.split("/")[-1]
         checkpoint = os.path.join(cache, model_name)
         if not os.path.exists(checkpoint):

From 94206cfe6501662464e6168f050b4b05cb98dd12 Mon Sep 17 00:00:00 2001
From: EC2 Default User
Date: Tue, 21 Oct 2025 13:04:48 +0000
Subject: [PATCH 2/3] Add scoring function to main.py

---
 src/boltz/main.py | 63 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 1 deletion(-)
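
The heuristic added here scores each record as cost ~ L^2 * (1 + 0.05*(n_chains - 1)),
with L the total residue count, and derives a batch size as batch_cost_ceiling // mean(cost).
That candidate is applied when --inference_batch_size was left at its default of 1, or when
the candidate is smaller than the user-supplied value (an OOM guard). A small worked example
with made-up record sizes; the helper below is illustrative and is not the
_estimate_record_cost function added by this patch:

    # Illustrative sketch of the cost model and the selection rule.
    def estimate_cost(num_residues: int, num_chains: int) -> float:
        chain_factor = 1.0 + 0.05 * (num_chains - 1)
        return (num_residues ** 2) * chain_factor

    costs = [estimate_cost(300, 2), estimate_cost(1000, 1)]    # 94500.0 and 1000000.0
    avg_cost = sum(costs) / len(costs)                         # 547250.0
    batch_cost_ceiling = 15000000.0                            # CLI default added here
    candidate = int(batch_cost_ceiling // max(avg_cost, 1.0))  # 27

    inference_batch_size = 1                                   # 1 means "left at the CLI default"
    if inference_batch_size == 1 or candidate < inference_batch_size:
        inference_batch_size = candidate                       # -> 27
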
diff --git a/src/boltz/main.py b/src/boltz/main.py
index 4c4322e57..0131a882a 100644
--- a/src/boltz/main.py
+++ b/src/boltz/main.py
@@ -5,6 +5,7 @@
 import tarfile
 import urllib.request
 import warnings
+import numpy as np
 from dataclasses import asdict, dataclass
 from functools import partial
 from multiprocessing import Pool
@@ -807,6 +808,33 @@ def process_multiple_inputs(
     manifest = Manifest(records)
     manifest.dump(out_dir / "processed" / "manifest.json")

+def _estimate_record_cost(record: Record, processed_dir: Path) -> float:
+    """
+    Approximate cost ~ L^2 * chain_factor
+    L: total residues across valid chains
+    chain_factor: 1 + 0.05*(num_chains-1)
+    """
+    struct_path_v2 = processed_dir / "structures" / f"{record.id}.npz"
+    L_total = 0
+    try:
+        arr = np.load(struct_path_v2)
+        # Try residues array first
+        if "residues" in arr:
+            L_total = len(arr["residues"])
+        else:
+            # Fallback: count by chains (rough)
+            L_total = sum(len(c.sequence) if hasattr(c, "sequence") and c.sequence else 0 for c in record.chains)
+    except Exception:
+        # Fallback using record chains only
+        L_total = sum(
+            len(getattr(c, "sequence", "")) if getattr(c, "sequence", "") else 0
+            for c in record.chains
+        )
+    if L_total <= 0:
+        L_total = 1
+    num_chains = max(1, len(record.chains))
+    chain_factor = 1.0 + 0.05 * (num_chains - 1)
+    return (L_total * L_total) * chain_factor

 @click.group()
 def cli() -> None:
@@ -1046,6 +1074,17 @@
     default=1,
     help="Set custom batch size in DataLoaders downstream.",
 )
+@click.option(
+    "--auto_batch",
+    is_flag=True,
+    help="Automatically choose inference batch size based on approximate sequence costs.",
+)
+@click.option(
+    "--batch_cost_ceiling",
+    type=float,
+    default=15000000.0,
+    help="Approximate total cost ceiling per batch when --auto_batch is used (cost ~ L^2).",
+)
 def predict(  # noqa: C901, PLR0915, PLR0912
     data: str,
     out_dir: str,
@@ -1084,8 +1123,10 @@
     num_subsampled_msa: int = 1024,
     no_kernels: bool = False,
     write_embeddings: bool = False,
-    inference_batch_size: int = 2,
+    inference_batch_size: int = 0,
+    auto_batch: bool = True,
+    batch_cost_ceiling: float = 15000000.0,
 ) -> None:
     """Run predictions with Boltz."""
     # If cpu, write a friendly warning
     if accelerator == "cpu":
@@ -1198,6 +1241,24 @@
         override=override,
     )

+    if auto_batch and filtered_manifest.records:
+        processed_dir = out_dir / "processed"
+        costs = [_estimate_record_cost(r, processed_dir) for r in filtered_manifest.records]
+        avg_cost = float(np.mean(costs))
+        candidate = int(batch_cost_ceiling // max(avg_cost, 1.0))
+        print(f'{candidate = }')
+        # If user manually set inference_batch_size > 1, respect manual unless candidate smaller (OOM guard)
+        if inference_batch_size == 1 or candidate < inference_batch_size:
+            inference_batch_size = candidate
+        print(f'{candidate = }')
+        print(f'{inference_batch_size = }')
+        click.echo(
+            f"[auto_batch] avg_cost={avg_cost:.1f}, cost_ceiling={batch_cost_ceiling:.1f} -> "
+            f"inference_batch_size={inference_batch_size}"
+        )
+    elif auto_batch and not filtered_manifest.records:
+        click.echo("[auto_batch] No records to batch; skipping.")
+
     # Load processed data
     processed_dir = out_dir / "processed"
     processed = BoltzProcessedInput(

From b299fc0dbe1f9a1ba5427cbee8ba71938c85a49a Mon Sep 17 00:00:00 2001
From: EC2 Default User
Date: Tue, 21 Oct 2025 13:45:22 +0000
Subject: [PATCH 3/3] Add code

---
 src/boltz/data/module/inferencev2.py | 1 +
 src/boltz/main.py                    | 8 +++++---
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/boltz/data/module/inferencev2.py b/src/boltz/data/module/inferencev2.py
index 762e94534..bca40003c 100644
--- a/src/boltz/data/module/inferencev2.py
+++ b/src/boltz/data/module/inferencev2.py
@@ -388,6 +388,7 @@ def predict_dataloader(self) -> DataLoader:
             override_method=self.override_method,
             affinity=self.affinity,
         )
+        print(f'In Boltz2Model {self.inference_batch_size = }')
         return DataLoader(
             dataset,
             batch_size=self.inference_batch_size,
diff --git a/src/boltz/main.py b/src/boltz/main.py
index 0131a882a..62736fde7 100644
--- a/src/boltz/main.py
+++ b/src/boltz/main.py
@@ -1123,7 +1123,7 @@ def predict(  # noqa: C901, PLR0915, PLR0912
     num_subsampled_msa: int = 1024,
     no_kernels: bool = False,
     write_embeddings: bool = False,
-    inference_batch_size: int = 0,
+    inference_batch_size: int = 100,
     auto_batch: bool = True,
     batch_cost_ceiling: float = 15000000.0,
 ) -> None:
@@ -1142,7 +1142,7 @@

     if inference_batch_size < 1:
         click.echo("Received inference_batch_size < 1; resetting to 1.")
-        inference_batch_size = 1
+        inference_batch_size = 10000

     # Set no grad
     torch.set_grad_enabled(False)
@@ -1249,7 +1249,7 @@
         print(f'{candidate = }')
         # If user manually set inference_batch_size > 1, respect manual unless candidate smaller (OOM guard)
         if inference_batch_size == 1 or candidate < inference_batch_size:
-            inference_batch_size = candidate
+            inference_batch_size = 100
         print(f'{candidate = }')
         print(f'{inference_batch_size = }')
         click.echo(
@@ -1259,6 +1259,8 @@
     elif auto_batch and not filtered_manifest.records:
         click.echo("[auto_batch] No records to batch; skipping.")

+    print(f'{inference_batch_size = }')
+    inference_batch_size = 100
     # Load processed data
     processed_dir = out_dir / "processed"
     processed = BoltzProcessedInput(
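
Note on the structure lookup in _estimate_record_cost (patch 2): the function first reads the
processed structure file at <processed>/structures/<record_id>.npz and uses the length of its
"residues" array, then falls back to summing chain sequence lengths from the record, and
finally to 1. A standalone sketch of that primary path; the directory, record id and helper
name below are hypothetical:

    from pathlib import Path

    import numpy as np

    def residue_count(processed_dir: Path, record_id: str) -> int:
        """Return the length of the residues array in a processed structure file, else 1."""
        struct_path = processed_dir / "structures" / f"{record_id}.npz"
        try:
            arr = np.load(struct_path)
            if "residues" in arr:
                return len(arr["residues"])
        except Exception:
            pass
        return 1

    # e.g. residue_count(Path("out/processed"), "my_target")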