Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions scripts/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@
)


# Configure logging to work with tqdm progress bars
setup_logging_for_tqdm()


def parse_args():
"""
Parse command-line arguments for inference configuration.
Expand Down Expand Up @@ -102,7 +98,7 @@ def parse_args():
help="Include symmetry mate atoms as protein nodes",
)
p.add_argument(
"--geometry_cache",
"--geometry_cache_name",
type=str,
default=None,
help="Subdirectory name within processed_dir specifying which water coordinate set to use. "
Expand Down Expand Up @@ -133,7 +129,7 @@ def parse_args():
help="Number of integration steps (default: 100)",
)
p.add_argument(
"--use_sc",
"--use_self_cond",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these new names reflected in the Dockerfile?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Dockerfile only exposes the input and output arguments; everything else is taken in as arguments by the user (and hydra configs soon)

action="store_true",
help="Use self-conditioning during integration",
Comment thread
vratins marked this conversation as resolved.
)
Expand All @@ -157,6 +153,9 @@ def parse_args():
help="Device to run inference on (default: cuda)",
)

p.add_argument("--log_level", type=str, default="INFO")
p.add_argument("--log_file", type=str, default=None)

p.add_argument(
"--batch_size",
type=int,
Expand Down Expand Up @@ -313,7 +312,7 @@ def run_inference_batch(
num_steps: int,
use_sc: bool,
device: str,
water_ratio: float = None,
water_ratio: float | None = None,
) -> list:
"""
Run inference on a batch of graphs.
Expand Down Expand Up @@ -402,8 +401,8 @@ def save_plot(

def main():
"""Run inference pipeline on a list of PDB structures."""
setup_logging_for_tqdm()
args = parse_args()
setup_logging_for_tqdm(level=args.log_level, log_file=args.log_file)

# setup paths
run_dir = Path(args.run_dir)
Expand Down Expand Up @@ -441,9 +440,12 @@ def main():
include_mates = args.include_mates or config.get("include_mates", False)
encoder_type = config.get("encoder_type", "gvp")

# Use --geometry_cache if provided, otherwise use config's geometry_cache_name
geometry_cache_name = args.geometry_cache or config.get(
"geometry_cache_name", "geometry"
# Use --geometry_cache_name if provided, otherwise use config's geometry_cache_name.
# Fall back to old "geometry_cache" key for backward compat with pre-rename run configs.
geometry_cache_name = (
args.geometry_cache_name
or config.get("geometry_cache_name")
or config.get("geometry_cache", "geometry")
)

# Extract dataset filter config from training config for consistency
Expand All @@ -465,7 +467,7 @@ def main():

# run inference
logger.info(f"Running inference with method={args.method}, steps={args.num_steps}")
logger.info(f"Self-conditioning: {args.use_sc}")
logger.info(f"Self-conditioning: {args.use_self_cond}")
logger.info(f"Threshold for metrics: {args.threshold}Å")
logger.info(f"Batch size: {args.batch_size}")

Expand Down Expand Up @@ -515,7 +517,7 @@ def main():
batch_graphs,
method=args.method,
num_steps=args.num_steps,
use_sc=args.use_sc,
use_sc=args.use_self_cond,
device=args.device,
water_ratio=args.water_ratio,
)
Expand Down Expand Up @@ -616,11 +618,11 @@ def main():
"checkpoint": args.checkpoint,
"method": args.method,
"num_steps": args.num_steps,
"use_sc": args.use_sc,
"use_self_cond": args.use_self_cond,
"threshold": args.threshold,
"include_mates": include_mates,
"water_ratio": args.water_ratio,
"geometry_cache": geometry_cache_name,
"geometry_cache_name": geometry_cache_name,
},
},
f,
Expand Down
16 changes: 9 additions & 7 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import cast

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -295,7 +296,7 @@ def parse_args():
p.add_argument("--save_every", type=int, default=10)
p.add_argument("--eval_every", type=int, default=5)
p.add_argument("--n_eval_samples", type=int, default=3)
p.add_argument("--rk4_steps", type=int, default=100)
p.add_argument("--num_steps", type=int, default=100)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vratins please add an issue to add help strings to each argument.

Copy link
Copy Markdown
Contributor Author

@vratins vratins Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, is this worth an issue, given that we will be switching to hydra configs as in #70?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, probably not the same issue, but the config items need to be documented somewhere.

p.add_argument(
Comment thread
vratins marked this conversation as resolved.
"--save_gifs", action="store_true", help="Save trajectory GIFs during eval"
)
Expand Down Expand Up @@ -578,7 +579,7 @@ def run_eval_sampling(

out = flow_matcher.rk4_integrate(
graph,
num_steps=args.rk4_steps,
num_steps=args.num_steps,
use_sc=args.use_self_cond,
device=device,
return_trajectory=True,
Expand Down Expand Up @@ -675,9 +676,10 @@ def train_epoch(
accumulation_steps=args.grad_accum_steps,
)

if metrics["per_sample_info"] is not None:
per_sample_losses = metrics["per_sample_info"]["losses"].cpu()
num_graphs = metrics["per_sample_info"]["num_graphs"]
per_sample_info = metrics["per_sample_info"]
if per_sample_info is not None and isinstance(per_sample_info, dict):
per_sample_losses = per_sample_info["losses"].cpu()
num_graphs = per_sample_info["num_graphs"]

if hasattr(batch, "pdb_id"):
pdb_ids = (
Expand All @@ -693,8 +695,8 @@ def train_epoch(
logger.warning("=" * 60)

processed_batches += 1
total_loss += metrics["loss"]
total_rmsd += metrics["rmsd"]
total_loss += cast(float, metrics["loss"])
total_rmsd += cast(float, metrics["rmsd"])

# Step optimizer every grad_accum_steps
if (step + 1) % args.grad_accum_steps == 0:
Expand Down
Loading