From 0da6fab22acf9d6dc20ff0420f5e8524cce3fd0e Mon Sep 17 00:00:00 2001 From: Vratin Srivastava Date: Thu, 19 Mar 2026 12:07:12 -0500 Subject: [PATCH 1/6] making the cli args for train and inference consistent --- scripts/inference.py | 33 ++++++++++++++++----------------- scripts/train.py | 18 +++++++++--------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 324ed4f..8585818 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -40,10 +40,6 @@ ) -# Configure logging to work with tqdm progress bars -setup_logging_for_tqdm() - - def parse_args(): """ Parse command-line arguments for inference configuration. @@ -102,7 +98,7 @@ def parse_args(): help="Include symmetry mate atoms as protein nodes", ) p.add_argument( - "--geometry_cache", + "--geometry_cache_name", type=str, default=None, help="Subdirectory name within processed_dir specifying which water coordinate set to use. " @@ -133,7 +129,7 @@ def parse_args(): help="Number of integration steps (default: 100)", ) p.add_argument( - "--use_sc", + "--use_self_cond", action="store_true", help="Use self-conditioning during integration", ) @@ -157,6 +153,9 @@ def parse_args(): help="Device to run inference on (default: cuda)", ) + p.add_argument("--log_level", type=str, default="INFO") + p.add_argument("--log_file", type=str, default=None) + p.add_argument( "--batch_size", type=int, @@ -257,8 +256,8 @@ def build_model_from_config(config: dict, device: torch.device) -> nn.Module: edge_scalar_dim=config.get("edge_scalar_dim") or NUM_RBF, layers=config.get("flow_layers") or 3, drop_rate=config.get("drop_rate", 0.1), - n_message_gvps=config.get("n_message_gvps", 2), - n_update_gvps=config.get("n_update_gvps", 2), + n_message_gvps=config.get("n_message_gvps", 2), # ty: ignore[unknown-argument] + n_update_gvps=config.get("n_update_gvps", 2), # ty: ignore[unknown-argument] k_pw=config.get("k_pw") or 16, k_ww=config.get("k_ww") or 16, ).to(device) @@ 
-298,7 +297,7 @@ def run_inference_batch( num_steps: int, use_sc: bool, device: str, - water_ratio: float = None, + water_ratio: float | None = None, ) -> list: """ Run inference on a batch of graphs. @@ -387,8 +386,8 @@ def save_plot( def main(): """Run inference pipeline on a list of PDB structures.""" - setup_logging_for_tqdm() args = parse_args() + setup_logging_for_tqdm(level=args.log_level, log_file=args.log_file) # setup paths run_dir = Path(args.run_dir) @@ -426,8 +425,8 @@ def main(): include_mates = args.include_mates or config.get("include_mates", False) encoder_type = config.get("encoder_type", "gvp") - # Use --geometry_cache if provided, otherwise use config's geometry_cache_name - geometry_cache_name = args.geometry_cache or config.get( + # Use --geometry_cache_name if provided, otherwise use config's geometry_cache_name + geometry_cache_name = args.geometry_cache_name or config.get( "geometry_cache_name", "geometry" ) @@ -435,9 +434,9 @@ def main(): pdb_list_file=args.pdb_list, processed_dir=args.processed_dir, base_pdb_dir=args.base_pdb_dir, - encoder_type=encoder_type, + encoder_type=encoder_type, # ty: ignore[unknown-argument] include_mates=include_mates, - geometry_cache_name=geometry_cache_name, + geometry_cache_name=geometry_cache_name, # ty: ignore[unknown-argument] preprocess=True, ) @@ -446,7 +445,7 @@ def main(): # run inference logger.info(f"Running inference with method={args.method}, steps={args.num_steps}") - logger.info(f"Self-conditioning: {args.use_sc}") + logger.info(f"Self-conditioning: {args.use_self_cond}") logger.info(f"Threshold for metrics: {args.threshold}Å") logger.info(f"Batch size: {args.batch_size}") @@ -496,7 +495,7 @@ def main(): batch_graphs, method=args.method, num_steps=args.num_steps, - use_sc=args.use_sc, + use_sc=args.use_self_cond, device=args.device, water_ratio=args.water_ratio, ) @@ -597,7 +596,7 @@ def main(): "checkpoint": args.checkpoint, "method": args.method, "num_steps": args.num_steps, - "use_sc": 
args.use_sc, + "use_sc": args.use_self_cond, "threshold": args.threshold, "include_mates": include_mates, "water_ratio": args.water_ratio, diff --git a/scripts/train.py b/scripts/train.py index fe50412..ce39d14 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -301,7 +301,7 @@ def parse_args(): p.add_argument("--save_every", type=int, default=10) p.add_argument("--eval_every", type=int, default=5) p.add_argument("--n_eval_samples", type=int, default=3) - p.add_argument("--rk4_steps", type=int, default=100) + p.add_argument("--num_steps", type=int, default=100) p.add_argument( "--save_gifs", action="store_true", help="Save trajectory GIFs during eval" ) @@ -543,8 +543,8 @@ def build_model( encoder=encoder, hidden_dims=(args.hidden_s, args.hidden_v), layers=args.flow_layers, - n_message_gvps=args.n_message_gvps, - n_update_gvps=args.n_update_gvps, + n_message_gvps=args.n_message_gvps, # ty: ignore[unknown-argument] + n_update_gvps=args.n_update_gvps, # ty: ignore[unknown-argument] drop_rate=args.drop_rate, k_pw=args.k_pw, k_ww=args.k_ww, @@ -572,7 +572,7 @@ def run_eval_sampling( out = flow_matcher.rk4_integrate( graph, - num_steps=args.rk4_steps, + num_steps=args.num_steps, use_sc=args.use_self_cond, device=device, return_trajectory=True, @@ -666,12 +666,12 @@ def train_epoch( metrics = flow_matcher.training_step( batch, use_self_conditioning=args.use_self_cond, - accumulation_steps=args.grad_accum_steps, + accumulation_steps=args.grad_accum_steps, # ty: ignore[unknown-argument] ) if metrics["per_sample_info"] is not None: - per_sample_losses = metrics["per_sample_info"]["losses"].cpu() - num_graphs = metrics["per_sample_info"]["num_graphs"] + per_sample_losses = metrics["per_sample_info"]["losses"].cpu() # ty: ignore[not-subscriptable] + num_graphs = metrics["per_sample_info"]["num_graphs"] # ty: ignore[not-subscriptable] if hasattr(batch, "pdb_id"): pdb_ids = ( @@ -687,8 +687,8 @@ def train_epoch( logger.warning("=" * 60) processed_batches += 1 - total_loss 
+= metrics["loss"] - total_rmsd += metrics["rmsd"] + total_loss += metrics["loss"] # ty: ignore[unsupported-operator] + total_rmsd += metrics["rmsd"] # ty: ignore[unsupported-operator] # Step optimizer every grad_accum_steps if (step + 1) % args.grad_accum_steps == 0: From d6a1df49510d35b1b4f5117d4152fece07127af1 Mon Sep 17 00:00:00 2001 From: Vratin Srivastava Date: Thu, 19 Mar 2026 12:07:12 -0500 Subject: [PATCH 2/6] making the cli args for train and inference consistent --- scripts/inference.py | 33 ++++++++++++++++----------------- scripts/train.py | 18 +++++++++--------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 324ed4f..8585818 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -40,10 +40,6 @@ ) -# Configure logging to work with tqdm progress bars -setup_logging_for_tqdm() - - def parse_args(): """ Parse command-line arguments for inference configuration. @@ -102,7 +98,7 @@ def parse_args(): help="Include symmetry mate atoms as protein nodes", ) p.add_argument( - "--geometry_cache", + "--geometry_cache_name", type=str, default=None, help="Subdirectory name within processed_dir specifying which water coordinate set to use. 
" @@ -133,7 +129,7 @@ def parse_args(): help="Number of integration steps (default: 100)", ) p.add_argument( - "--use_sc", + "--use_self_cond", action="store_true", help="Use self-conditioning during integration", ) @@ -157,6 +153,9 @@ def parse_args(): help="Device to run inference on (default: cuda)", ) + p.add_argument("--log_level", type=str, default="INFO") + p.add_argument("--log_file", type=str, default=None) + p.add_argument( "--batch_size", type=int, @@ -257,8 +256,8 @@ def build_model_from_config(config: dict, device: torch.device) -> nn.Module: edge_scalar_dim=config.get("edge_scalar_dim") or NUM_RBF, layers=config.get("flow_layers") or 3, drop_rate=config.get("drop_rate", 0.1), - n_message_gvps=config.get("n_message_gvps", 2), - n_update_gvps=config.get("n_update_gvps", 2), + n_message_gvps=config.get("n_message_gvps", 2), # ty: ignore[unknown-argument] + n_update_gvps=config.get("n_update_gvps", 2), # ty: ignore[unknown-argument] k_pw=config.get("k_pw") or 16, k_ww=config.get("k_ww") or 16, ).to(device) @@ -298,7 +297,7 @@ def run_inference_batch( num_steps: int, use_sc: bool, device: str, - water_ratio: float = None, + water_ratio: float | None = None, ) -> list: """ Run inference on a batch of graphs. 
@@ -387,8 +386,8 @@ def save_plot( def main(): """Run inference pipeline on a list of PDB structures.""" - setup_logging_for_tqdm() args = parse_args() + setup_logging_for_tqdm(level=args.log_level, log_file=args.log_file) # setup paths run_dir = Path(args.run_dir) @@ -426,8 +425,8 @@ def main(): include_mates = args.include_mates or config.get("include_mates", False) encoder_type = config.get("encoder_type", "gvp") - # Use --geometry_cache if provided, otherwise use config's geometry_cache_name - geometry_cache_name = args.geometry_cache or config.get( + # Use --geometry_cache_name if provided, otherwise use config's geometry_cache_name + geometry_cache_name = args.geometry_cache_name or config.get( "geometry_cache_name", "geometry" ) @@ -435,9 +434,9 @@ def main(): pdb_list_file=args.pdb_list, processed_dir=args.processed_dir, base_pdb_dir=args.base_pdb_dir, - encoder_type=encoder_type, + encoder_type=encoder_type, # ty: ignore[unknown-argument] include_mates=include_mates, - geometry_cache_name=geometry_cache_name, + geometry_cache_name=geometry_cache_name, # ty: ignore[unknown-argument] preprocess=True, ) @@ -446,7 +445,7 @@ def main(): # run inference logger.info(f"Running inference with method={args.method}, steps={args.num_steps}") - logger.info(f"Self-conditioning: {args.use_sc}") + logger.info(f"Self-conditioning: {args.use_self_cond}") logger.info(f"Threshold for metrics: {args.threshold}Å") logger.info(f"Batch size: {args.batch_size}") @@ -496,7 +495,7 @@ def main(): batch_graphs, method=args.method, num_steps=args.num_steps, - use_sc=args.use_sc, + use_sc=args.use_self_cond, device=args.device, water_ratio=args.water_ratio, ) @@ -597,7 +596,7 @@ def main(): "checkpoint": args.checkpoint, "method": args.method, "num_steps": args.num_steps, - "use_sc": args.use_sc, + "use_sc": args.use_self_cond, "threshold": args.threshold, "include_mates": include_mates, "water_ratio": args.water_ratio, diff --git a/scripts/train.py b/scripts/train.py index 
fe50412..ce39d14 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -301,7 +301,7 @@ def parse_args(): p.add_argument("--save_every", type=int, default=10) p.add_argument("--eval_every", type=int, default=5) p.add_argument("--n_eval_samples", type=int, default=3) - p.add_argument("--rk4_steps", type=int, default=100) + p.add_argument("--num_steps", type=int, default=100) p.add_argument( "--save_gifs", action="store_true", help="Save trajectory GIFs during eval" ) @@ -543,8 +543,8 @@ def build_model( encoder=encoder, hidden_dims=(args.hidden_s, args.hidden_v), layers=args.flow_layers, - n_message_gvps=args.n_message_gvps, - n_update_gvps=args.n_update_gvps, + n_message_gvps=args.n_message_gvps, # ty: ignore[unknown-argument] + n_update_gvps=args.n_update_gvps, # ty: ignore[unknown-argument] drop_rate=args.drop_rate, k_pw=args.k_pw, k_ww=args.k_ww, @@ -572,7 +572,7 @@ def run_eval_sampling( out = flow_matcher.rk4_integrate( graph, - num_steps=args.rk4_steps, + num_steps=args.num_steps, use_sc=args.use_self_cond, device=device, return_trajectory=True, @@ -666,12 +666,12 @@ def train_epoch( metrics = flow_matcher.training_step( batch, use_self_conditioning=args.use_self_cond, - accumulation_steps=args.grad_accum_steps, + accumulation_steps=args.grad_accum_steps, # ty: ignore[unknown-argument] ) if metrics["per_sample_info"] is not None: - per_sample_losses = metrics["per_sample_info"]["losses"].cpu() - num_graphs = metrics["per_sample_info"]["num_graphs"] + per_sample_losses = metrics["per_sample_info"]["losses"].cpu() # ty: ignore[not-subscriptable] + num_graphs = metrics["per_sample_info"]["num_graphs"] # ty: ignore[not-subscriptable] if hasattr(batch, "pdb_id"): pdb_ids = ( @@ -687,8 +687,8 @@ def train_epoch( logger.warning("=" * 60) processed_batches += 1 - total_loss += metrics["loss"] - total_rmsd += metrics["rmsd"] + total_loss += metrics["loss"] # ty: ignore[unsupported-operator] + total_rmsd += metrics["rmsd"] # ty: ignore[unsupported-operator] # Step 
optimizer every grad_accum_steps if (step + 1) % args.grad_accum_steps == 0: From 2bb978fa6f8a2676037ac5599481efe4691c2263 Mon Sep 17 00:00:00 2001 From: Vratin Srivastava Date: Mon, 23 Mar 2026 11:51:51 -0500 Subject: [PATCH 3/6] test commit --- scripts/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 8585818..0a9294a 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -434,9 +434,9 @@ def main(): pdb_list_file=args.pdb_list, processed_dir=args.processed_dir, base_pdb_dir=args.base_pdb_dir, - encoder_type=encoder_type, # ty: ignore[unknown-argument] + encoder_type=encoder_type, include_mates=include_mates, - geometry_cache_name=geometry_cache_name, # ty: ignore[unknown-argument] + geometry_cache_name=geometry_cache_name, preprocess=True, ) From 82690bd8e2bb9b9271fc9009a3b7f6b78903b12a Mon Sep 17 00:00:00 2001 From: Vratin Srivastava Date: Mon, 23 Mar 2026 16:23:06 -0500 Subject: [PATCH 4/6] getting rid of ty ignore comments --- scripts/inference.py | 8 ++++---- scripts/train.py | 18 ++++++++++-------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 40a0628..60f541e 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -271,8 +271,8 @@ def build_model_from_config(config: dict, device: torch.device) -> nn.Module: edge_scalar_dim=config.get("edge_scalar_dim") or NUM_RBF, layers=config.get("flow_layers") or 3, drop_rate=config.get("drop_rate", 0.1), - n_message_gvps=config.get("n_message_gvps", 2), # ty: ignore[unknown-argument] - n_update_gvps=config.get("n_update_gvps", 2), # ty: ignore[unknown-argument] + n_message_gvps=config.get("n_message_gvps", 2), + n_update_gvps=config.get("n_update_gvps", 2), k_pw=config.get("k_pw") or 16, k_ww=config.get("k_ww") or 16, ).to(device) @@ -452,9 +452,9 @@ def main(): pdb_list_file=args.pdb_list, processed_dir=args.processed_dir,
base_pdb_dir=args.base_pdb_dir, - encoder_type=encoder_type, # ty: ignore[unknown-argument] + encoder_type=encoder_type, include_mates=include_mates, - geometry_cache_name=geometry_cache_name, # ty: ignore[unknown-argument] + geometry_cache_name=geometry_cache_name, preprocess=True, **filter_config, ) diff --git a/scripts/train.py b/scripts/train.py index f47a431..0ac7eae 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -21,6 +21,7 @@ import argparse import json from datetime import datetime +from typing import cast from pathlib import Path import matplotlib.pyplot as plt @@ -549,8 +550,8 @@ def build_model( encoder=encoder, hidden_dims=(args.hidden_s, args.hidden_v), layers=args.flow_layers, - n_message_gvps=args.n_message_gvps, # ty: ignore[unknown-argument] - n_update_gvps=args.n_update_gvps, # ty: ignore[unknown-argument] + n_message_gvps=args.n_message_gvps, + n_update_gvps=args.n_update_gvps, drop_rate=args.drop_rate, k_pw=args.k_pw, k_ww=args.k_ww, @@ -672,12 +673,13 @@ def train_epoch( metrics = flow_matcher.training_step( batch, use_self_conditioning=args.use_self_cond, - accumulation_steps=args.grad_accum_steps, # ty: ignore[unknown-argument] + accumulation_steps=args.grad_accum_steps, ) - if metrics["per_sample_info"] is not None: - per_sample_losses = metrics["per_sample_info"]["losses"].cpu() # ty: ignore[not-subscriptable] - num_graphs = metrics["per_sample_info"]["num_graphs"] # ty: ignore[not-subscriptable] + per_sample_info = metrics["per_sample_info"] + if per_sample_info is not None and isinstance(per_sample_info, dict): + per_sample_losses = per_sample_info["losses"].cpu() + num_graphs = per_sample_info["num_graphs"] if hasattr(batch, "pdb_id"): pdb_ids = ( @@ -693,8 +695,8 @@ def train_epoch( logger.warning("=" * 60) processed_batches += 1 - total_loss += metrics["loss"] # ty: ignore[unsupported-operator] - total_rmsd += metrics["rmsd"] # ty: ignore[unsupported-operator] + total_loss += cast(float, metrics["loss"]) + total_rmsd += 
cast(float, metrics["rmsd"]) # Step optimizer every grad_accum_steps if (step + 1) % args.grad_accum_steps == 0: From 8bfa06081f7d654f1b10644be4133a042ce3d57b Mon Sep 17 00:00:00 2001 From: vratins <114123331+vratins@users.noreply.github.com> Date: Mon, 23 Mar 2026 21:23:23 +0000 Subject: [PATCH 5/6] Auto-commit ruff fixes [skip ci] --- scripts/train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 0ac7eae..8abcbad 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -21,8 +21,8 @@ import argparse import json from datetime import datetime -from typing import cast from pathlib import Path +from typing import cast import matplotlib.pyplot as plt import numpy as np @@ -679,7 +679,7 @@ def train_epoch( per_sample_info = metrics["per_sample_info"] if per_sample_info is not None and isinstance(per_sample_info, dict): per_sample_losses = per_sample_info["losses"].cpu() - num_graphs = per_sample_info["num_graphs"] + num_graphs = per_sample_info["num_graphs"] if hasattr(batch, "pdb_id"): pdb_ids = ( @@ -696,7 +696,7 @@ def train_epoch( processed_batches += 1 total_loss += cast(float, metrics["loss"]) - total_rmsd += cast(float, metrics["rmsd"]) + total_rmsd += cast(float, metrics["rmsd"]) # Step optimizer every grad_accum_steps if (step + 1) % args.grad_accum_steps == 0: From ab2c5389e86385340063b0161934d25f25012080 Mon Sep 17 00:00:00 2001 From: Vratin Srivastava Date: Thu, 26 Mar 2026 18:10:57 -0500 Subject: [PATCH 6/6] addressing CR --- scripts/inference.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/scripts/inference.py b/scripts/inference.py index 60f541e..8027a93 100644 --- a/scripts/inference.py +++ b/scripts/inference.py @@ -440,9 +440,12 @@ def main(): include_mates = args.include_mates or config.get("include_mates", False) encoder_type = config.get("encoder_type", "gvp") - # Use --geometry_cache_name if provided, otherwise use config's geometry_cache_name - 
geometry_cache_name = args.geometry_cache_name or config.get( - "geometry_cache_name", "geometry" + # Use --geometry_cache_name if provided, otherwise use config's geometry_cache_name. + # Fall back to old "geometry_cache" key for backward compat with pre-rename run configs. + geometry_cache_name = ( + args.geometry_cache_name + or config.get("geometry_cache_name") + or config.get("geometry_cache", "geometry") ) # Extract dataset filter config from training config for consistency @@ -615,11 +618,11 @@ def main(): "checkpoint": args.checkpoint, "method": args.method, "num_steps": args.num_steps, - "use_sc": args.use_self_cond, + "use_self_cond": args.use_self_cond, "threshold": args.threshold, "include_mates": include_mates, "water_ratio": args.water_ratio, - "geometry_cache": geometry_cache_name, + "geometry_cache_name": geometry_cache_name, }, }, f,