Add configurable recycling steps and diffusion step count parameters#205
Add configurable recycling steps and diffusion step count parameters#205marcuscollins merged 3 commits intomainfrom
Conversation
… via GuidanceConfig and annotate_structure_for_* methods
…py and in GuidanceConfig
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughAdds CLI flags for recycling and diffusion steps, extends GuidanceConfig with those fields, threads them into guidance orchestration which dispatches preprocessing for Protenix, RF3, and Boltz (each receiving recycling_steps), and normalizes Boltz recycling default; small comment/type-annotation cleanups in wrappers. Changes
Sequence Diagram(s)sequenceDiagram
participant CLI as CLI
participant Parser as ArgParser
participant Guidance as GuidanceUtils
participant Config as GuidanceConfig
participant Wrapper as ModelWrapper
CLI->>Parser: invoke with --recycling-steps, --num-diffusion-steps
Parser->>Guidance: pass parsed args
Guidance->>Config: build GuidanceConfig(recycling_steps, num_diffusion_steps)
Guidance->>Wrapper: dispatch preprocessing (Protenix / RF3 / Boltz) with recycling_steps
Wrapper-->>Guidance: processed_structure
Guidance->>Wrapper: run diffusion with num_diffusion_steps
Wrapper-->>Guidance: results
Guidance-->>CLI: complete/report
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
src/sampleworks/utils/guidance_script_utils.py (1)
235-235: RedundantPath()wrapping.
safe_structure_pathis already aPathobject (assigned on line 233 viaPath(structure_path)). The additionalPath()wrapping is harmless but unnecessary.♻️ Optional simplification
structure = parse( - Path(safe_structure_path), + safe_structure_path, hydrogen_policy="remove", add_missing_atoms=False, ccd_mirror_path=None, )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/sampleworks/utils/guidance_script_utils.py` at line 235, The code unnecessarily wraps an existing Path object with Path() when passing safe_structure_path; remove the redundant Path(...) call and pass safe_structure_path directly (update the call site that currently does Path(safe_structure_path)), leaving the originally constructed Path(structure_path) assignment intact and using safe_structure_path as-is.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/sampleworks/utils/guidance_script_utils.py`:
- Around line 28-33: The import block in guidance_script_utils.py duplicates
NoScalingScaler; open the import statement that currently lists
DataSpaceDPSScaler, NoScalingScaler, NoiseSpaceDPSScaler, NoScalingScaler and
remove the redundant NoScalingScaler entry so the import only includes each
scaler once (e.g., DataSpaceDPSScaler, NoiseSpaceDPSScaler, NoScalingScaler).
- Around line 430-453: The call site passes recycling_steps (possibly None) into
process_structure_for_boltz, but process_structure_for_boltz currently requires
an int (default 3) and does not accept None; either make
process_structure_for_boltz accept Optional[int] and handle None by using the
existing default internally, or normalize the value at the call site before
calling process_structure_for_boltz (e.g., compute a fallback int when
getattr(args, "recycling_steps", None) is None) so the function always receives
an int; update the function signature in sampleworks.models.boltz.wrapper
(process_structure_for_boltz) or adjust the call in guidance_script_utils.py
where process_structure_for_boltz is invoked to ensure types align.
---
Nitpick comments:
In `@src/sampleworks/utils/guidance_script_utils.py`:
- Line 235: The code unnecessarily wraps an existing Path object with Path()
when passing safe_structure_path; remove the redundant Path(...) call and pass
safe_structure_path directly (update the call site that currently does
Path(safe_structure_path)), leaving the originally constructed
Path(structure_path) assignment intact and using safe_structure_path as-is.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6e287086-7506-4b89-bbc4-1cb6180711aa
📒 Files selected for processing (4)
run_grid_search.pysrc/sampleworks/models/rf3/wrapper.pysrc/sampleworks/utils/guidance_script_arguments.pysrc/sampleworks/utils/guidance_script_utils.py
💤 Files with no reviewable changes (1)
- src/sampleworks/models/rf3/wrapper.py
k-chrispens
left a comment
There was a problem hiding this comment.
Generally lgtm after addressing ty comments, there are a little comment things that might be useful though!
| from sampleworks.core.scalers.pure_guidance import PureGuidance | ||
| from sampleworks.core.scalers.step_scalers import ( | ||
| DataSpaceDPSScaler, | ||
| NoScalingScaler, |
There was a problem hiding this comment.
yeah this appears to be a duplicate import
| elif "Boltz" in wrapper_class_name: | ||
| from sampleworks.models.boltz.wrapper import process_structure_for_boltz | ||
| structure = process_structure_for_boltz( | ||
| structure, ensemble_size=args.ensemble_size, recycling_steps=recycling_steps |
There was a problem hiding this comment.
Presumably you will address this in getting the ty updates in? But if not this should be addressed, was an oversight on my part
| "--recycling-steps", | ||
| type=int, | ||
| default=None, | ||
| help="Number of recycling steps for model inference (if not specified, uses model default)", |
There was a problem hiding this comment.
We could point to where users could find the model default here?
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/sampleworks/utils/guidance_script_utils.py (1)
491-498:⚠️ Potential issue | 🟠 MajorGuard
num_diffusion_stepsbefore using it.This is now user-configurable, so
0will raise on the divisions below, and plainargparse.Namespacecallers can also miss the field entirely even though_run_guidance()accepts them.Suggested fix
- num_steps = args.num_diffusion_steps + num_steps = getattr(args, "num_diffusion_steps", 200) + if num_steps <= 0: + raise ValueError("num_diffusion_steps must be > 0")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/sampleworks/utils/guidance_script_utils.py` around lines 491 - 498, The code uses args.num_diffusion_steps (assigned to num_steps) without guarding for missing or zero values, causing divisions by zero or AttributeError when callers omit the field; update _run_guidance (or the caller) to validate num_steps before using it: ensure args has attribute num_diffusion_steps and that num_steps is a positive int > 0 (or set a sensible default), and if invalid either set num_steps to a default value or raise a clear ValueError; then compute guidance_t_start and t_start only after this guard so the divisions (in guidance_t_start and t_start calculations) cannot divide by zero or None.src/sampleworks/models/boltz/wrapper.py (1)
358-367:⚠️ Potential issue | 🟠 MajorReject non-positive
recycling_steps.
recycling_stepsis now user-controlled, but0/negative values still flow intoBoltzConfig. Both_pairformer_pass()implementations later dofor _ in range(recycling_steps), so those values skip the Pairformer trunk entirely and produce bogus conditioning instead of a valid “no recycling” mode.Suggested fix
if recycling_steps is None: recycling_steps = 3 + if recycling_steps < 1: + raise ValueError("recycling_steps must be >= 1") config = BoltzConfig( out_dir=out_dir or structure.get("metadata", {}).get("id", "boltz_output"),🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/sampleworks/models/boltz/wrapper.py` around lines 358 - 367, The code accepts user-controlled recycling_steps but allows 0/negative values which break the Pairformer loops; add a validation just before constructing BoltzConfig to reject non-positive inputs (recycling_steps <= 0) by raising a clear ValueError (mention recycling_steps and expected positive integer) so invalid inputs don't flow into BoltzConfig; keep the existing None->3 default logic but ensure the check runs after that normalization and before BoltzConfig is instantiated (reference recycling_steps, BoltzConfig, and the _pairformer_pass() usage).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@src/sampleworks/models/boltz/wrapper.py`:
- Around line 358-367: The code accepts user-controlled recycling_steps but
allows 0/negative values which break the Pairformer loops; add a validation just
before constructing BoltzConfig to reject non-positive inputs (recycling_steps
<= 0) by raising a clear ValueError (mention recycling_steps and expected
positive integer) so invalid inputs don't flow into BoltzConfig; keep the
existing None->3 default logic but ensure the check runs after that
normalization and before BoltzConfig is instantiated (reference recycling_steps,
BoltzConfig, and the _pairformer_pass() usage).
In `@src/sampleworks/utils/guidance_script_utils.py`:
- Around line 491-498: The code uses args.num_diffusion_steps (assigned to
num_steps) without guarding for missing or zero values, causing divisions by
zero or AttributeError when callers omit the field; update _run_guidance (or the
caller) to validate num_steps before using it: ensure args has attribute
num_diffusion_steps and that num_steps is a positive int > 0 (or set a sensible
default), and if invalid either set num_steps to a default value or raise a
clear ValueError; then compute guidance_t_start and t_start only after this
guard so the divisions (in guidance_t_start and t_start calculations) cannot
divide by zero or None.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2514165b-7bcd-4d1a-a2a7-8bac4a7c4cf3
📒 Files selected for processing (3)
run_grid_search.pysrc/sampleworks/models/boltz/wrapper.pysrc/sampleworks/utils/guidance_script_utils.py
🚧 Files skipped from review as they are similar to previous changes (1)
- run_grid_search.py
68f0198 to
1c0bb41
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@run_grid_search.py`:
- Around line 471-473: The parser is registering the same option string
"--num-diffusion-steps" twice which causes argparse.ArgumentError; locate the
duplicate parser.add_argument call that uses "--num-diffusion-steps" and remove
or rename the redundant registration (or consolidate defaults/help) so the
option is only added once (look for the parser.add_argument usage that defines
"--num-diffusion-steps" and eliminate the duplicate).
- Around line 449-462: The CLI flags --num-diffusion-steps and --recycling-steps
must be validated at parse time: ensure --num-diffusion-steps
(parser.add_argument with name "num_diffusion_steps") is a positive integer (>0)
and that --recycling-steps (name "recycling_steps") is not negative (>=0) before
passing them into model wrappers or the schedule math; add validation
immediately after parsing (or use a custom argparse type) and call
parser.error(...) with a clear message when values are invalid so malformed
values never reach guidance_script_utils.py.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: a5a35423-e599-4f70-a980-e29058ec49f1
📒 Files selected for processing (5)
run_grid_search.pysrc/sampleworks/models/boltz/wrapper.pysrc/sampleworks/models/protenix/wrapper.pysrc/sampleworks/models/rf3/wrapper.pysrc/sampleworks/utils/guidance_script_utils.py
✅ Files skipped from review due to trivial changes (2)
- src/sampleworks/models/rf3/wrapper.py
- src/sampleworks/models/protenix/wrapper.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/sampleworks/models/boltz/wrapper.py
1c0bb41 to
1313b6e
Compare
1313b6e to
ec53661
Compare
ec53661 to
58191fe
Compare
There was a problem hiding this comment.
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/sampleworks/utils/guidance_script_utils.py`:
- Around line 434-435: Validate that args.guidance_start and
args.partial_diffusion_step (and if present args.guidance_end) are within the
bounds of args.num_diffusion_steps: they must be >= 0 and <=
args.num_diffusion_steps so the computed fractions (used later) never exceed
1.0; if args.num_diffusion_steps is provided, add checks near the existing
args.num_diffusion_steps validation to raise a ValueError when
guidance_start/guidance_end/partial_diffusion_step are out of range or logically
inconsistent (e.g., guidance_end < guidance_start), referencing the argument
names guidance_start, guidance_end, partial_diffusion_step, and
num_diffusion_steps to locate and fix the checks.
- Around line 431-435: The code raises AttributeError by accessing
args.num_diffusion_steps directly; replace any direct accesses (e.g., the check
in the block that currently reads if args.num_diffusion_steps is not None and
args.num_diffusion_steps <= 0) with a safe getattr call like num_diffusion_steps
= getattr(args, "num_diffusion_steps", 200) and then perform the validation
against that local variable (raise ValueError if num_diffusion_steps <= 0) so it
matches GuidanceConfig's default and keeps argparse-based callers working;
update all occurrences that reference args.num_diffusion_steps to use this local
num_diffusion_steps variable.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 49bf5a4b-619d-4e01-8cb7-90d9cd5fe0dd
📒 Files selected for processing (4)
run_grid_search.pysrc/sampleworks/models/boltz/wrapper.pysrc/sampleworks/models/protenix/wrapper.pysrc/sampleworks/utils/guidance_script_utils.py
✅ Files skipped from review due to trivial changes (1)
- src/sampleworks/models/protenix/wrapper.py
🚧 Files skipped from review as they are similar to previous changes (2)
- run_grid_search.py
- src/sampleworks/models/boltz/wrapper.py
| recycling_steps = getattr(args, "recycling_steps", None) | ||
| if recycling_steps is not None and recycling_steps <= 0: | ||
| raise ValueError("recycling_steps must be > 0") | ||
| if args.num_diffusion_steps is not None and args.num_diffusion_steps <= 0: | ||
| raise ValueError("num_diffusion_steps must be > 0") |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Inspect call sites that may pass plain namespaces into run_guidance/_run_guidance
rg -n -C2 '\b(run_guidance|_run_guidance)\s*\(' --type py
rg -n -C2 'argparse\.Namespace\s*\(' --type py
rg -n -C2 'SimpleNamespace\s*\(' --type pyRepository: diff-use/sampleworks
Length of output: 4725
🏁 Script executed:
#!/bin/bash
# Check one of the scripts to see how num_diffusion_steps is defined in argparse
head -50 scripts/protenix_fk_steering.py
cat scripts/rf3_pure_guidance.py
# Look for GuidanceConfig definition
rg -n "class GuidanceConfig" --type py -A 20
# View the actual problematic code
sed -n '429,440p' src/sampleworks/utils/guidance_script_utils.py
sed -n '495,505p' src/sampleworks/utils/guidance_script_utils.pyRepository: diff-use/sampleworks
Length of output: 4296
🏁 Script executed:
#!/bin/bash
# Get the full GuidanceConfig definition
sed -n '116,160p' src/sampleworks/utils/guidance_script_arguments.py
# Search for where num_diffusion_steps is added to parsers
rg -n "num_diffusion_steps" --type py -B 2 -A 2
# Check the parser functions
rg -n "def parse_" --type py -A 30 src/sampleworks/utils/guidance_script_arguments.py | head -100Repository: diff-use/sampleworks
Length of output: 6721
🏁 Script executed:
#!/bin/bash
# Find add_generic_args definition
rg -n "def add_generic_args" --type py -A 50 src/sampleworks/utils/guidance_script_arguments.pyRepository: diff-use/sampleworks
Length of output: 2079
🏁 Script executed:
#!/bin/bash
# Search for num_diffusion_steps in all argument setup functions
rg -n "num.diffusion.steps\|num_diffusion_steps" --type py src/sampleworks/utils/guidance_script_arguments.py
# Check if it's added in any of the model-specific or guidance-specific arg functions
rg -n "def add_" --type py src/sampleworks/utils/guidance_script_arguments.pyRepository: diff-use/sampleworks
Length of output: 710
🏁 Script executed:
#!/bin/bash
# Check all the add_* functions to be thorough
rg -n "def add_pure_guidance_args" --type py -A 20 src/sampleworks/utils/guidance_script_arguments.py
rg -n "def add_fk_steering_args" --type py -A 20 src/sampleworks/utils/guidance_script_arguments.py
# Also check if there are any tests that might pass Namespace directly
rg -n "_run_guidance\|run_guidance" --type py tests/ 2>/dev/null || echo "No tests dir or no matches"Repository: diff-use/sampleworks
Length of output: 1584
🏁 Script executed:
#!/bin/bash
# Double-check: search for num_diffusion_steps anywhere in the entire file
rg "num.diffusion.steps\|num_diffusion_steps" --type py src/sampleworks/utils/guidance_script_arguments.py
# And check what the scripts would actually get
python3 << 'EOF'
import sys
sys.path.insert(0, 'src')
# Try simulating what one of the parsers returns
import argparse
from sampleworks.utils.guidance_script_arguments import add_generic_args, add_pure_guidance_args
parser = argparse.ArgumentParser()
add_generic_args(parser)
add_pure_guidance_args(parser)
# Simulate parsing with minimal required args
try:
args = parser.parse_args(['--structure', 'test.pdb', '--density', 'test.mrc', '--resolution', '2.0'])
print("Parsed args:", vars(args))
print("Has num_diffusion_steps?", hasattr(args, 'num_diffusion_steps'))
except Exception as e:
print(f"Error: {e}")
EOFRepository: diff-use/sampleworks
Length of output: 467
🏁 Script executed:
#!/bin/bash
# Get exact lines with context
sed -n '430,440p' src/sampleworks/utils/guidance_script_utils.py
echo "---"
sed -n '495,505p' src/sampleworks/utils/guidance_script_utils.py
# Check if num_diffusion_steps is set somewhere before _run_guidance call
grep -n "num_diffusion_steps" src/sampleworks/utils/guidance_script_utils.pyRepository: diff-use/sampleworks
Length of output: 1285
Use getattr() to safely access num_diffusion_steps from args.
Line 434 and 499 assume num_diffusion_steps always exists on the args object. However, the argparse parsers (used by all scripts in the repo) do not define this argument, so argparse.Namespace instances lack this attribute. This causes an immediate AttributeError when the code runs. Use getattr(args, "num_diffusion_steps", 200) to match the default in GuidanceConfig and maintain backward compatibility with existing argparse-based callers.
Proposed fix
recycling_steps = getattr(args, "recycling_steps", None)
if recycling_steps is not None and recycling_steps <= 0:
raise ValueError("recycling_steps must be > 0")
- if args.num_diffusion_steps is not None and args.num_diffusion_steps <= 0:
+ num_diffusion_steps = getattr(args, "num_diffusion_steps", 200)
+ if num_diffusion_steps is None:
+ num_diffusion_steps = 200
+ if num_diffusion_steps <= 0:
raise ValueError("num_diffusion_steps must be > 0")
@@
- num_steps = args.num_diffusion_steps
+ num_steps = num_diffusion_steps📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| recycling_steps = getattr(args, "recycling_steps", None) | |
| if recycling_steps is not None and recycling_steps <= 0: | |
| raise ValueError("recycling_steps must be > 0") | |
| if args.num_diffusion_steps is not None and args.num_diffusion_steps <= 0: | |
| raise ValueError("num_diffusion_steps must be > 0") | |
| recycling_steps = getattr(args, "recycling_steps", None) | |
| if recycling_steps is not None and recycling_steps <= 0: | |
| raise ValueError("recycling_steps must be > 0") | |
| num_diffusion_steps = getattr(args, "num_diffusion_steps", 200) | |
| if num_diffusion_steps is None: | |
| num_diffusion_steps = 200 | |
| if num_diffusion_steps <= 0: | |
| raise ValueError("num_diffusion_steps must be > 0") |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/sampleworks/utils/guidance_script_utils.py` around lines 431 - 435, The
code raises AttributeError by accessing args.num_diffusion_steps directly;
replace any direct accesses (e.g., the check in the block that currently reads
if args.num_diffusion_steps is not None and args.num_diffusion_steps <= 0) with
a safe getattr call like num_diffusion_steps = getattr(args,
"num_diffusion_steps", 200) and then perform the validation against that local
variable (raise ValueError if num_diffusion_steps <= 0) so it matches
GuidanceConfig's default and keeps argparse-based callers working; update all
occurrences that reference args.num_diffusion_steps to use this local
num_diffusion_steps variable.
| if args.num_diffusion_steps is not None and args.num_diffusion_steps <= 0: | ||
| raise ValueError("num_diffusion_steps must be > 0") |
There was a problem hiding this comment.
Guard guidance_start and partial_diffusion_step against a smaller num_steps.
Once num_diffusion_steps is user-configurable, values that used to be safe against the old hard-coded 200 can now exceed the new ceiling. That makes the fractions computed at Lines 505-506 and 534-536 greater than 1.0, which is an invalid diffusion schedule. Reject those combinations up front.
💡 Proposed fix
if num_steps <= 0:
raise ValueError("num_diffusion_steps must be > 0")
+ guidance_start = getattr(args, "guidance_start", -1)
+ if guidance_start > num_steps:
+ raise ValueError("guidance_start must be <= num_diffusion_steps")
+ partial_diffusion_step = getattr(args, "partial_diffusion_step", None)
+ if partial_diffusion_step is not None and partial_diffusion_step > num_steps:
+ raise ValueError("partial_diffusion_step must be <= num_diffusion_steps")Also applies to: 499-499
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/sampleworks/utils/guidance_script_utils.py` around lines 434 - 435,
Validate that args.guidance_start and args.partial_diffusion_step (and if
present args.guidance_end) are within the bounds of args.num_diffusion_steps:
they must be >= 0 and <= args.num_diffusion_steps so the computed fractions
(used later) never exceed 1.0; if args.num_diffusion_steps is provided, add
checks near the existing args.num_diffusion_steps validation to raise a
ValueError when guidance_start/guidance_end/partial_diffusion_step are out of
range or logically inconsistent (e.g., guidance_end < guidance_start),
referencing the argument names guidance_start, guidance_end,
partial_diffusion_step, and num_diffusion_steps to locate and fix the checks.
Add --recycling-steps and --num-diffusion-steps to add_generic_args() and _DYNAMIC_ATTRS so they are available via sampleworks-guidance. These were added in #205 but not wired to the CLI.
This PR adds two options to control diffusion: the total number of steps, and the number of recycling steps. These were used previously but with hard-coded values which are now surfaced to the user. They are passed to the respective models by use of the annotate* methods which append keyword config values to the structure dictionary.
Summary by CodeRabbit
New Features
Bug Fixes
Chores