feat(rf3): chiral gradient zeroing and tracking#182
Conversation
📝 WalkthroughWalkthroughAdds chiral-feature handling to the RF3 wrapper: new config flags to disable or track chiral inputs, featurization changes to capture/zero chiral tensors and persist model AtomArray, per-step chiral gradient diagnostics, CLI flags, guidance-script integration, and tests for chiral behavior. Changes
Sequence Diagram(s)sequenceDiagram
participant User
participant GuidanceScript as Guidance Script
participant RF3Wrapper
participant EDMModule as EDM Diffusion Module
participant ChiralCalc as Chiral Gradient Calculator
User->>GuidanceScript: Run with chiral flags
GuidanceScript->>GuidanceScript: annotate_structure_for_rf3(disable/track)
GuidanceScript->>RF3Wrapper: featurize(structure, config)
activate RF3Wrapper
RF3Wrapper->>RF3Wrapper: Reset _chiral_grad_stats and tracking state
RF3Wrapper->>RF3Wrapper: Capture original chiral tensors
alt disable_chiral_features enabled
RF3Wrapper->>RF3Wrapper: Zero chiral_centers and chiral_center_dihedral_angles
end
RF3Wrapper->>RF3Wrapper: Persist model_atom_array in conditioning
deactivate RF3Wrapper
loop for each denoising step
GuidanceScript->>RF3Wrapper: step(x_t, t, ...)
activate RF3Wrapper
alt track_chiral_features enabled and original chiral present
RF3Wrapper->>EDMModule: Convert x_t to EDM scale
RF3Wrapper->>ChiralCalc: calc_chiral_grads_flat_impl(coords)
ChiralCalc-->>RF3Wrapper: per-batch L2 norms
RF3Wrapper->>RF3Wrapper: Append {t, l2_norm} to _chiral_grad_stats
end
deactivate RF3Wrapper
end
GuidanceScript->>GuidanceScript: Save outputs
alt _chiral_grad_stats populated
GuidanceScript->>GuidanceScript: Write chiral_grad_stats.json
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment Tip You can get early access to new features in CodeRabbit.Enable the |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
tests/models/test_rf3_chiral_features.py (1)
87-138: Consider addingstrict=Truetozip()calls.While the lengths are asserted equal at line 122, adding
strict=Truewould make the invariant explicit at the call sites and satisfy the static analysis warning. This is a minor improvement.💡 Suggested fix
- for e, d in zip(enabled_stats, disabled_stats): + for e, d in zip(enabled_stats, disabled_stats, strict=True): assert e["t"] == d["t"] assert enabled_stats[0]["l2_norm"] == pytest.approx(disabled_stats[0]["l2_norm"], rel=1e-4) - for i, (e, d) in enumerate(zip(enabled_stats, disabled_stats)): + for e, d in zip(enabled_stats, disabled_stats, strict=True): assert e["l2_norm"] > 0.0 assert d["l2_norm"] > 0.0 - for i, (e, d) in enumerate(zip(enabled_stats[1:], disabled_stats[1:]), start=1): + for i, (e, d) in enumerate(zip(enabled_stats[1:], disabled_stats[1:], strict=True), start=1): ratio = e["l2_norm"] / d["l2_norm"]🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/models/test_rf3_chiral_features.py` around lines 87 - 138, The test function test_tracking_consistent_across_disabled_and_enabled uses three zip() calls to iterate paired lists (zip(enabled_stats, disabled_stats) twice and zip(enabled_stats[1:], disabled_stats[1:], start=1)); although lengths are asserted equal earlier, update each zip call to include strict=True so the iterator will raise if lengths differ at runtime and silence the static analysis warning—i.e., change the three zip(...) invocations inside that function to zip(..., strict=True).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/models/test_rf3_chiral_features.py`:
- Around line 87-138: The test function
test_tracking_consistent_across_disabled_and_enabled uses three zip() calls to
iterate paired lists (zip(enabled_stats, disabled_stats) twice and
zip(enabled_stats[1:], disabled_stats[1:], start=1)); although lengths are
asserted equal earlier, update each zip call to include strict=True so the
iterator will raise if lengths differ at runtime and silence the static analysis
warning—i.e., change the three zip(...) invocations inside that function to
zip(..., strict=True).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 00e543f3-6a55-4a79-a329-35e57c503cbd
📒 Files selected for processing (4)
src/sampleworks/models/rf3/wrapper.pysrc/sampleworks/utils/guidance_script_arguments.pysrc/sampleworks/utils/guidance_script_utils.pytests/models/test_rf3_chiral_features.py
There was a problem hiding this comment.
Pull request overview
This PR adds RF3-specific support for disabling chiral gradient input features during sampling/guidance and for tracking chiral gradient magnitudes over denoising steps, to help diagnose and mitigate geometry issues potentially driven by chiral gradients.
Changes:
- Extend RF3 structure annotation/config to include
disable_chiral_featuresandtrack_chiral_features. - Implement chiral feature zeroing in
RF3Wrapper.featurize()and chiral gradient L2 tracking inRF3Wrapper.step(). - Wire new RF3 guidance CLI flags into the guidance script and persist collected stats to
chiral_grad_stats.json; add GPU/slow tests covering expected behavior.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| tests/models/test_rf3_chiral_features.py | Adds integration-style GPU tests validating chiral feature zeroing and per-step tracking behavior. |
| src/sampleworks/utils/guidance_script_utils.py | Annotates RF3 structures with new chiral config flags; unifies step-size selection; saves chiral_grad_stats.json when present. |
| src/sampleworks/utils/guidance_script_arguments.py | Adds --disable-chiral-features and --track-chiral-features RF3 CLI flags; minor formatting tweaks. |
| src/sampleworks/models/rf3/wrapper.py | Introduces RF3Config fields + annotation plumbing; zeros chiral features when requested; tracks and stores chiral gradient L2 stats per step. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| chiral gradient (output of calc_chiral_grads_flat_impl on the | ||
| EDM scaled coordinates) is the input feature to the model's chiral | ||
| processing layer. Uses original features when disable_chiral_features is True to | ||
| determine if guidance would be breaking them. |
There was a problem hiding this comment.
More of a scientific question: do we have a hypothesis what they would do if guidance is breaking them? Exploding maybe? If they're just "out of domain" that is harder to detect, unless we have a very good idea what "in domain" is.
There was a problem hiding this comment.
Broader point: we should document, when possible, what the signal looks like when it is broken.
There was a problem hiding this comment.
I suppose we would need to compare guidance off to guidance on for this, but I would expect we would see the magnitudes increase, which would make the feature values go out of domain for the model.
| # Track chiral gradient statistics using original features to report what the chiral | ||
| # gradient would have been for diagnostics | ||
| if ( | ||
| self._track_chiral_features |
There was a problem hiding this comment.
Nit: if you've set self._track_chiral_features to True will these other statements be True in passing? I guess it is possible that you'd end up with no chiral centers in the molecule, but that wouldn't be a protein, would it? (Maybe poly-G?)
There was a problem hiding this comment.
The reason I say this is that testing more than you have to can leave the impression that the statements you're testing are at least partially independent, which I think these are not.
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/models/test_rf3_chiral_features.py (1)
13-31: Use NumPy-style docstrings and typed test helpers
_make_samplerand_run_stepsshould follow the repo-wide Python interface conventions (NumPy-style docstrings + explicit typed signatures; include jaxtyping shape annotations where tensors are part of the interface).Proposed refactor
-def _make_sampler(device: torch.device) -> AF3EDMSampler: +def _make_sampler(device: torch.device) -> AF3EDMSampler: + """Create a deterministic EDM sampler for RF3 tests. + + Parameters + ---------- + device : torch.device + Device used by the sampler. + + Returns + ------- + AF3EDMSampler + Configured sampler with augmentation and alignment disabled. + """ return AF3EDMSampler(EDMSamplerConfig(device=device, augmentation=False, align_to_input=False)) -def _run_steps(wrapper, features, num_steps: int, seed: int = 42): - """Run num_steps of denoising from seeded prior noise. Returns the final state.""" +def _run_steps( + wrapper: "RF3WrapperProtocol", + features: "RF3FeaturesProtocol", + num_steps: int, + seed: int = 42, +): + """Run RF3 denoising steps from seeded prior noise. + + Parameters + ---------- + wrapper : RF3WrapperProtocol + RF3 wrapper exposing initialization and step APIs. + features : RF3FeaturesProtocol + Precomputed RF3 conditioning/features bundle. + num_steps : int + Number of denoising steps to execute. + seed : int, optional + Seed for reproducible prior noise sampling. + + Returns + ------- + Any + Final wrapper state after `num_steps`. + """As per coding guidelines, “Always include NumPy-style docstrings for every function and class.” and “Use jaxtyping for array shape annotations in function signatures (e.g.,
Float[Tensor, "batch atoms 3"]).”🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/models/test_rf3_chiral_features.py` around lines 13 - 31, Update the helper functions _make_sampler and _run_steps to follow repo conventions by adding NumPy-style docstrings and explicit typed signatures: annotate _make_sampler(device: torch.device) -> AF3EDMSampler with a short description, Args/Returns sections; annotate _run_steps with precise types for wrapper, features (use jaxtyping for tensor shapes, e.g., Float[Tensor, "batch ..."] where applicable), num_steps: int, seed: int = 42 and return type (tensor or state type), and add a NumPy-style docstring describing parameters, shapes, and return value; ensure any torch.device usage remains typed and include jaxtyping imports if missing.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/models/test_rf3_chiral_features.py`:
- Line 136: The loop uses an unused enumerate index and non-strict zips; replace
the current "for i, (e, d) in enumerate(zip(enabled_stats, disabled_stats))"
pattern by removing enumerate and using strict zipping so that the loop becomes
a simple "for e, d in zip(enabled_stats, disabled_stats, strict=True):" to
eliminate the unused variable i and ensure enabled_stats and disabled_stats have
matching lengths.
- Line 131: Update the loop that iterates over enabled_stats and disabled_stats
in tests/models/test_rf3_chiral_features.py to use zip(..., strict=True) for
defensive iteration safety; specifically modify the for loop that reads "for e,
d in zip(enabled_stats, disabled_stats):" (which follows the len assertion
involving enabled_stats, disabled_stats, and num_steps) so it becomes "for e, d
in zip(enabled_stats, disabled_stats, strict=True):" to enforce equal-length
iteration and fail immediately on any mismatch.
---
Nitpick comments:
In `@tests/models/test_rf3_chiral_features.py`:
- Around line 13-31: Update the helper functions _make_sampler and _run_steps to
follow repo conventions by adding NumPy-style docstrings and explicit typed
signatures: annotate _make_sampler(device: torch.device) -> AF3EDMSampler with a
short description, Args/Returns sections; annotate _run_steps with precise types
for wrapper, features (use jaxtyping for tensor shapes, e.g., Float[Tensor,
"batch ..."] where applicable), num_steps: int, seed: int = 42 and return type
(tensor or state type), and add a NumPy-style docstring describing parameters,
shapes, and return value; ensure any torch.device usage remains typed and
include jaxtyping imports if missing.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f4cd4391-b9e3-4893-a04e-cd5a2ce7de9b
📒 Files selected for processing (1)
tests/models/test_rf3_chiral_features.py
|
Tests pass locally, something is up with the lockfile when I try to deploy to our testing server. |
Chiral gradients may be why we saw some geometry problems with RF3. This adds the ability to track them for guided and non-guided runs, as well as zero them out so that we can make sure we aren't breaking geometry with those.
Summary by CodeRabbit
New Features
Tests