diff --git a/scripts/inference.py b/scripts/inference.py
index e6ae9b2..8027a93 100644
--- a/scripts/inference.py
+++ b/scripts/inference.py
@@ -40,10 +40,6 @@
 )
 
 
-# Configure logging to work with tqdm progress bars
-setup_logging_for_tqdm()
-
-
 def parse_args():
     """
     Parse command-line arguments for inference configuration.
@@ -102,7 +98,7 @@ def parse_args():
         help="Include symmetry mate atoms as protein nodes",
     )
     p.add_argument(
-        "--geometry_cache",
+        "--geometry_cache_name",
         type=str,
         default=None,
         help="Subdirectory name within processed_dir specifying which water coordinate set to use. "
@@ -133,7 +129,7 @@
         help="Number of integration steps (default: 100)",
     )
     p.add_argument(
-        "--use_sc",
+        "--use_self_cond",
         action="store_true",
         help="Use self-conditioning during integration",
     )
@@ -157,6 +153,9 @@
         help="Device to run inference on (default: cuda)",
     )
 
+    p.add_argument("--log_level", type=str, default="INFO")
+    p.add_argument("--log_file", type=str, default=None)
+
     p.add_argument(
         "--batch_size",
         type=int,
@@ -313,7 +312,7 @@ def run_inference_batch(
     num_steps: int,
     use_sc: bool,
     device: str,
-    water_ratio: float = None,
+    water_ratio: float | None = None,
 ) -> list:
     """
     Run inference on a batch of graphs.
@@ -402,8 +401,8 @@ def save_plot(
 
 def main():
     """Run inference pipeline on a list of PDB structures."""
-    setup_logging_for_tqdm()
     args = parse_args()
+    setup_logging_for_tqdm(level=args.log_level, log_file=args.log_file)
 
     # setup paths
     run_dir = Path(args.run_dir)
@@ -441,9 +440,12 @@ def main():
     include_mates = args.include_mates or config.get("include_mates", False)
     encoder_type = config.get("encoder_type", "gvp")
 
-    # Use --geometry_cache if provided, otherwise use config's geometry_cache_name
-    geometry_cache_name = args.geometry_cache or config.get(
-        "geometry_cache_name", "geometry"
+    # Use --geometry_cache_name if provided, otherwise use config's geometry_cache_name.
+    # Fall back to the old "geometry_cache" key for backward compat with pre-rename run configs.
+    geometry_cache_name = (
+        args.geometry_cache_name
+        or config.get("geometry_cache_name")
+        or config.get("geometry_cache", "geometry")
     )
 
     # Extract dataset filter config from training config for consistency
@@ -465,7 +467,7 @@ def main():
 
     # run inference
     logger.info(f"Running inference with method={args.method}, steps={args.num_steps}")
-    logger.info(f"Self-conditioning: {args.use_sc}")
+    logger.info(f"Self-conditioning: {args.use_self_cond}")
     logger.info(f"Threshold for metrics: {args.threshold}Å")
     logger.info(f"Batch size: {args.batch_size}")
 
@@ -515,7 +517,7 @@ def main():
             batch_graphs,
             method=args.method,
             num_steps=args.num_steps,
-            use_sc=args.use_sc,
+            use_sc=args.use_self_cond,
             device=args.device,
             water_ratio=args.water_ratio,
         )
@@ -616,11 +618,11 @@ def main():
                 "checkpoint": args.checkpoint,
                 "method": args.method,
                 "num_steps": args.num_steps,
-                "use_sc": args.use_sc,
+                "use_self_cond": args.use_self_cond,
                 "threshold": args.threshold,
                 "include_mates": include_mates,
                 "water_ratio": args.water_ratio,
-                "geometry_cache": geometry_cache_name,
+                "geometry_cache_name": geometry_cache_name,
             },
         },
         f,
diff --git a/scripts/train.py b/scripts/train.py
index ee61acf..8abcbad 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -22,6 +22,7 @@
 import json
 from datetime import datetime
 from pathlib import Path
+from typing import cast
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -295,7 +296,7 @@ def parse_args():
     p.add_argument("--save_every", type=int, default=10)
     p.add_argument("--eval_every", type=int, default=5)
     p.add_argument("--n_eval_samples", type=int, default=3)
-    p.add_argument("--rk4_steps", type=int, default=100)
+    p.add_argument("--num_steps", type=int, default=100)
    p.add_argument(
         "--save_gifs", action="store_true", help="Save trajectory GIFs during eval"
     )
@@ -578,7 +579,7 @@ def run_eval_sampling(
 
         out = flow_matcher.rk4_integrate(
             graph,
-            num_steps=args.rk4_steps,
+            num_steps=args.num_steps,
             use_sc=args.use_self_cond,
             device=device,
             return_trajectory=True,
@@ -675,9 +676,10 @@ def train_epoch(
             accumulation_steps=args.grad_accum_steps,
         )
 
-        if metrics["per_sample_info"] is not None:
-            per_sample_losses = metrics["per_sample_info"]["losses"].cpu()
-            num_graphs = metrics["per_sample_info"]["num_graphs"]
+        per_sample_info = metrics["per_sample_info"]
+        if per_sample_info is not None and isinstance(per_sample_info, dict):
+            per_sample_losses = per_sample_info["losses"].cpu()
+            num_graphs = per_sample_info["num_graphs"]
 
             if hasattr(batch, "pdb_id"):
                 pdb_ids = (
@@ -693,8 +695,8 @@ def train_epoch(
                 logger.warning("=" * 60)
 
         processed_batches += 1
-        total_loss += metrics["loss"]
-        total_rmsd += metrics["rmsd"]
+        total_loss += cast(float, metrics["loss"])
+        total_rmsd += cast(float, metrics["rmsd"])
 
         # Step optimizer every grad_accum_steps
        if (step + 1) % args.grad_accum_steps == 0:
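Note on the logging change in scripts/inference.py: main() now parses arguments first and then forwards the new --log_level/--log_file flags to setup_logging_for_tqdm(). The helper itself is defined elsewhere in the repo; the sketch below is only an illustrative assumption of what a tqdm-aware implementation with that signature could look like, not the project's actual code.

    import logging

    from tqdm import tqdm


    class TqdmLoggingHandler(logging.Handler):
        """Emit log records via tqdm.write so active progress bars are not broken."""

        def emit(self, record: logging.LogRecord) -> None:
            try:
                tqdm.write(self.format(record))
            except Exception:
                self.handleError(record)


    def setup_logging_for_tqdm(level: str = "INFO", log_file: str | None = None) -> None:
        """Configure root logging with a tqdm-friendly console handler and optional file output."""
        handlers: list[logging.Handler] = [TqdmLoggingHandler()]
        if log_file is not None:
            handlers.append(logging.FileHandler(log_file))
        logging.basicConfig(
            level=getattr(logging, level.upper(), logging.INFO),
            format="%(asctime)s %(levelname)s %(name)s: %(message)s",
            handlers=handlers,
            force=True,  # replace any handlers installed earlier (e.g. at import time)
        )

Parsing arguments before configuring logging is what makes the console verbosity and optional log file user-selectable per run, which is why the module-level setup_logging_for_tqdm() call was removed from the import section.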