-
Notifications
You must be signed in to change notification settings - Fork 0
Making the CLI args for train and inference consistent #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
0da6fab
d6a1df4
2bb978f
5936135
9d6004e
f958c5f
82690bd
8bfa060
ab2c538
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ | |
| import json | ||
| from datetime import datetime | ||
| from pathlib import Path | ||
| from typing import cast | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
|
|
@@ -295,7 +296,7 @@ def parse_args(): | |
| p.add_argument("--save_every", type=int, default=10) | ||
| p.add_argument("--eval_every", type=int, default=5) | ||
| p.add_argument("--n_eval_samples", type=int, default=3) | ||
| p.add_argument("--rk4_steps", type=int, default=100) | ||
| p.add_argument("--num_steps", type=int, default=100) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vratins please add an issue to add help strings to each argument.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, is this worth an issue given that we will be switching to hydra configs as in #70
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, probably not the same issue, but the config items need to be documented somewhere. |
||
| p.add_argument( | ||
|
vratins marked this conversation as resolved.
|
||
| "--save_gifs", action="store_true", help="Save trajectory GIFs during eval" | ||
| ) | ||
|
|
@@ -578,7 +579,7 @@ def run_eval_sampling( | |
|
|
||
| out = flow_matcher.rk4_integrate( | ||
| graph, | ||
| num_steps=args.rk4_steps, | ||
| num_steps=args.num_steps, | ||
| use_sc=args.use_self_cond, | ||
| device=device, | ||
| return_trajectory=True, | ||
|
|
@@ -675,9 +676,10 @@ def train_epoch( | |
| accumulation_steps=args.grad_accum_steps, | ||
| ) | ||
|
|
||
| if metrics["per_sample_info"] is not None: | ||
| per_sample_losses = metrics["per_sample_info"]["losses"].cpu() | ||
| num_graphs = metrics["per_sample_info"]["num_graphs"] | ||
| per_sample_info = metrics["per_sample_info"] | ||
| if per_sample_info is not None and isinstance(per_sample_info, dict): | ||
| per_sample_losses = per_sample_info["losses"].cpu() | ||
| num_graphs = per_sample_info["num_graphs"] | ||
|
|
||
| if hasattr(batch, "pdb_id"): | ||
| pdb_ids = ( | ||
|
|
@@ -693,8 +695,8 @@ def train_epoch( | |
| logger.warning("=" * 60) | ||
|
|
||
| processed_batches += 1 | ||
| total_loss += metrics["loss"] | ||
| total_rmsd += metrics["rmsd"] | ||
| total_loss += cast(float, metrics["loss"]) | ||
| total_rmsd += cast(float, metrics["rmsd"]) | ||
|
|
||
| # Step optimizer every grad_accum_steps | ||
| if (step + 1) % args.grad_accum_steps == 0: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these new names reflected in the Dockerfile?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docker file only exposes the input and output arguments, everything else is taken in as arguments by the use (and hydra configs soon)