
feat(rf3): chiral gradient zeroing and tracking#182

Merged
k-chrispens merged 2 commits into main from kmc/chiral-feat-tracking
Mar 23, 2026

Conversation

@k-chrispens (Collaborator) commented Mar 20, 2026

Chiral gradients may be why we saw some geometry problems with RF3. This adds the ability to track them for guided and non-guided runs, and to zero them out, so we can confirm that guidance isn't breaking geometry through them.

Summary by CodeRabbit

  • New Features

    • Added options to disable chiral inputs and to track per-step chiral diagnostics for RF3 models.
    • New CLI flags to enable/disable these behaviors.
    • Featurization now resets per-run chiral state, captures original chiral inputs, and preserves model–atom mapping for reconciliation.
    • When tracking is enabled, per-step chiral gradient metrics are recorded and written to chiral_grad_stats.json.
  • Tests

    • Added end-to-end tests covering chiral feature handling, disabling behavior, and diagnostics collection.
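As a rough illustration of the capture-and-zero behavior summarized above — a minimal sketch only: the real `featurize()` operates on torch tensors, and `capture_and_zero_chiral` and `CHIRAL_KEYS` are hypothetical names (the feature keys are taken from the walkthrough below):

```python
# Hypothetical sketch of the zeroing step described in this PR.
# Feature names follow the walkthrough (chiral_centers,
# chiral_center_dihedral_angles); real code uses torch tensors.

CHIRAL_KEYS = ("chiral_centers", "chiral_center_dihedral_angles")


def capture_and_zero_chiral(features: dict, disable_chiral_features: bool) -> dict:
    """Capture the original chiral inputs, then optionally zero them in place."""
    # keep a copy of the originals so diagnostics can use them later
    original = {k: list(features[k]) for k in CHIRAL_KEYS if k in features}
    if disable_chiral_features:
        for k in CHIRAL_KEYS:
            if k in features:
                features[k] = [0.0] * len(features[k])
    return original
```

The key point is that the originals are captured before zeroing, so tracking can still report what the chiral gradient would have been.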

Copilot AI review requested due to automatic review settings March 20, 2026 03:29
coderabbitai (Bot, Contributor) commented Mar 20, 2026

📝 Walkthrough

Adds chiral-feature handling to the RF3 wrapper: new config flags to disable or track chiral inputs, featurization changes to capture/zero chiral tensors and persist model AtomArray, per-step chiral gradient diagnostics, CLI flags, guidance-script integration, and tests for chiral behavior.

Changes

  • Core RF3 Model Implementation (src/sampleworks/models/rf3/wrapper.py): Added model_atom_array to RF3Conditioning. Extended RF3Config with disable_chiral_features and track_chiral_features. featurize() resets chiral tracking state, captures original chiral tensors, optionally zeros them, and persists model_atom_array. step() optionally computes per-step chiral gradient diagnostics (EDM-only).
  • CLI Arguments (src/sampleworks/utils/guidance_script_arguments.py): Added --disable-chiral-features and --track-chiral-features boolean flags to the RF3-specific CLI args; minor help-text reflow.
  • Guidance Script Integration (src/sampleworks/utils/guidance_script_utils.py): When the wrapper is an RF3Wrapper, calls annotate_structure_for_rf3(...) with the chiral flags, resolves a step_size fallback, and writes chiral_grad_stats.json when _chiral_grad_stats exists.
  • Tests (tests/models/test_rf3_chiral_features.py): New GPU/slow test module validating chiral tracking, zeroing behavior, deterministic comparisons between configs, tracking disablement, feature presence, and reset of tracking state on re-featurize.
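The per-step tracking and JSON persistence described above can be sketched as follows — `ChiralTracker` is a hypothetical stand-in (the real state lives on the wrapper and the guidance script), but the `{t, l2_norm}` records and the `chiral_grad_stats.json` output mirror the walkthrough:

```python
import json
import math


class ChiralTracker:
    """Hypothetical stand-in for the per-step tracking state in RF3Wrapper."""

    def __init__(self) -> None:
        # featurize() resets this list on every run, per the walkthrough
        self._chiral_grad_stats: list[dict] = []

    def record(self, t: float, grads: list[float]) -> None:
        # store the timestep and the L2 norm of the flat chiral gradient
        l2 = math.sqrt(sum(g * g for g in grads))
        self._chiral_grad_stats.append({"t": t, "l2_norm": l2})

    def save(self, path: str) -> None:
        # mirrors the guidance script: write only when stats were collected
        if self._chiral_grad_stats:
            with open(path, "w") as f:
                json.dump(self._chiral_grad_stats, f, indent=2)
```

A rising l2_norm trajectory across steps would be one signal that guidance is pushing the chiral features out of domain, per the review discussion below.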

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant GuidanceScript as Guidance Script
    participant RF3Wrapper
    participant EDMModule as EDM Diffusion Module
    participant ChiralCalc as Chiral Gradient Calculator

    User->>GuidanceScript: Run with chiral flags
    GuidanceScript->>GuidanceScript: annotate_structure_for_rf3(disable/track)
    GuidanceScript->>RF3Wrapper: featurize(structure, config)
    activate RF3Wrapper
    RF3Wrapper->>RF3Wrapper: Reset _chiral_grad_stats and tracking state
    RF3Wrapper->>RF3Wrapper: Capture original chiral tensors
    alt disable_chiral_features enabled
        RF3Wrapper->>RF3Wrapper: Zero chiral_centers and chiral_center_dihedral_angles
    end
    RF3Wrapper->>RF3Wrapper: Persist model_atom_array in conditioning
    deactivate RF3Wrapper

    loop for each denoising step
        GuidanceScript->>RF3Wrapper: step(x_t, t, ...)
        activate RF3Wrapper
        alt track_chiral_features enabled and original chiral present
            RF3Wrapper->>EDMModule: Convert x_t to EDM scale
            RF3Wrapper->>ChiralCalc: calc_chiral_grads_flat_impl(coords)
            ChiralCalc-->>RF3Wrapper: per-batch L2 norms
            RF3Wrapper->>RF3Wrapper: Append {t, l2_norm} to _chiral_grad_stats
        end
        deactivate RF3Wrapper
    end

    GuidanceScript->>GuidanceScript: Save outputs
    alt _chiral_grad_stats populated
        GuidanceScript->>GuidanceScript: Write chiral_grad_stats.json
    end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • marcuscollins

Poem

🐰 I hopped into code with a twitch and a twitch,
Chiral hairs tidy or hidden at which switch.
I count gradient hops as they shimmer and spin,
Zeroing angles or tracking their din.
Hooray — RF3 dances, chirality grinned! 🧪✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 66.67%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title accurately summarizes the main changes: adding chiral gradient tracking and the ability to zero/disable chiral features in RF3 model handling.



@coderabbitai (Bot) left a comment:

🧹 Nitpick comments (1)
tests/models/test_rf3_chiral_features.py (1)

87-138: Consider adding strict=True to zip() calls.

While the lengths are asserted equal at line 122, adding strict=True would make the invariant explicit at the call sites and satisfy the static analysis warning. This is a minor improvement.

💡 Suggested fix
-        for e, d in zip(enabled_stats, disabled_stats):
+        for e, d in zip(enabled_stats, disabled_stats, strict=True):
             assert e["t"] == d["t"]
 
         assert enabled_stats[0]["l2_norm"] == pytest.approx(disabled_stats[0]["l2_norm"], rel=1e-4)
 
-        for i, (e, d) in enumerate(zip(enabled_stats, disabled_stats)):
+        for e, d in zip(enabled_stats, disabled_stats, strict=True):
             assert e["l2_norm"] > 0.0
             assert d["l2_norm"] > 0.0
 
-        for i, (e, d) in enumerate(zip(enabled_stats[1:], disabled_stats[1:]), start=1):
+        for i, (e, d) in enumerate(zip(enabled_stats[1:], disabled_stats[1:], strict=True), start=1):
             ratio = e["l2_norm"] / d["l2_norm"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/models/test_rf3_chiral_features.py` around lines 87 - 138, The test
function test_tracking_consistent_across_disabled_and_enabled uses three zip()
calls to iterate paired lists (zip(enabled_stats, disabled_stats) twice and
zip(enabled_stats[1:], disabled_stats[1:], start=1)); although lengths are
asserted equal earlier, update each zip call to include strict=True so the
iterator will raise if lengths differ at runtime and silence the static analysis
warning—i.e., change the three zip(...) invocations inside that function to
zip(..., strict=True).

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 00e543f3-6a55-4a79-a329-35e57c503cbd

📥 Commits

Reviewing files that changed from the base of the PR and between a3dbccb and be1c139.

📒 Files selected for processing (4)
  • src/sampleworks/models/rf3/wrapper.py
  • src/sampleworks/utils/guidance_script_arguments.py
  • src/sampleworks/utils/guidance_script_utils.py
  • tests/models/test_rf3_chiral_features.py

Copilot AI (Contributor) left a comment:

Pull request overview

This PR adds RF3-specific support for disabling chiral gradient input features during sampling/guidance and for tracking chiral gradient magnitudes over denoising steps, to help diagnose and mitigate geometry issues potentially driven by chiral gradients.

Changes:

  • Extend RF3 structure annotation/config to include disable_chiral_features and track_chiral_features.
  • Implement chiral feature zeroing in RF3Wrapper.featurize() and chiral gradient L2 tracking in RF3Wrapper.step().
  • Wire new RF3 guidance CLI flags into the guidance script and persist collected stats to chiral_grad_stats.json; add GPU/slow tests covering expected behavior.
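The new flags presumably follow the standard argparse boolean-flag pattern — a sketch under that assumption (the help text and parser wiring here are illustrative, not copied from the PR):

```python
import argparse

# Sketch of the RF3-specific CLI flags added in this PR; help text is assumed.
parser = argparse.ArgumentParser(description="RF3 guidance options (sketch)")
parser.add_argument(
    "--disable-chiral-features",
    action="store_true",
    help="Zero chiral inputs before featurization (assumed help text)",
)
parser.add_argument(
    "--track-chiral-features",
    action="store_true",
    help="Record per-step chiral gradient L2 norms (assumed help text)",
)

args = parser.parse_args(["--track-chiral-features"])
```

With `action="store_true"`, each flag defaults to False and flips to True only when passed, matching the opt-in diagnostics described in the PR.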

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

File Description
tests/models/test_rf3_chiral_features.py Adds integration-style GPU tests validating chiral feature zeroing and per-step tracking behavior.
src/sampleworks/utils/guidance_script_utils.py Annotates RF3 structures with new chiral config flags; unifies step-size selection; saves chiral_grad_stats.json when present.
src/sampleworks/utils/guidance_script_arguments.py Adds --disable-chiral-features and --track-chiral-features RF3 CLI flags; minor formatting tweaks.
src/sampleworks/models/rf3/wrapper.py Introduces RF3Config fields + annotation plumbing; zeros chiral features when requested; tracks and stores chiral gradient L2 stats per step.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread src/sampleworks/models/rf3/wrapper.py
chiral gradient (output of calc_chiral_grads_flat_impl on the
EDM scaled coordinates) is the input feature to the model's chiral
processing layer. Uses original features when disable_chiral_features is True to
determine if guidance would be breaking them.
Collaborator:

More of a scientific question: do we have a hypothesis what they would do if guidance is breaking them? Exploding maybe? If they're just "out of domain" that is harder to detect, unless we have a very good idea what "in domain" is.

Collaborator:

Broader point: we should document, when possible, what the signal looks like when it is broken.

Collaborator (Author):

I suppose we would need to compare guidance off to guidance on for this, but I would expect we would see the magnitudes increase, which would make the feature values go out of domain for the model.

# Track chiral gradient statistics using original features to report what the
# chiral gradient would have been for diagnostics
if (
    self._track_chiral_features
Collaborator:

Nit: if you've set self._track_chiral_features to True will these other statements be True in passing? I guess it is possible that you'd end up with no chiral centers in the molecule, but that wouldn't be a protein, would it? (Maybe poly-G?)

Collaborator:

The reason I say this is that testing more than you have to can leave the impression that the statements you're testing are at least partially independent, which I think these are not.

@coderabbitai (Bot) left a comment:

Actionable comments posted: 2

🧹 Nitpick comments (1)
tests/models/test_rf3_chiral_features.py (1)

13-31: Use NumPy-style docstrings and typed test helpers

_make_sampler and _run_steps should follow the repo-wide Python interface conventions (NumPy-style docstrings + explicit typed signatures; include jaxtyping shape annotations where tensors are part of the interface).

Proposed refactor
-def _make_sampler(device: torch.device) -> AF3EDMSampler:
+def _make_sampler(device: torch.device) -> AF3EDMSampler:
+    """Create a deterministic EDM sampler for RF3 tests.
+
+    Parameters
+    ----------
+    device : torch.device
+        Device used by the sampler.
+
+    Returns
+    -------
+    AF3EDMSampler
+        Configured sampler with augmentation and alignment disabled.
+    """
     return AF3EDMSampler(EDMSamplerConfig(device=device, augmentation=False, align_to_input=False))
 
 
-def _run_steps(wrapper, features, num_steps: int, seed: int = 42):
-    """Run num_steps of denoising from seeded prior noise. Returns the final state."""
+def _run_steps(
+    wrapper: "RF3WrapperProtocol",
+    features: "RF3FeaturesProtocol",
+    num_steps: int,
+    seed: int = 42,
+):
+    """Run RF3 denoising steps from seeded prior noise.
+
+    Parameters
+    ----------
+    wrapper : RF3WrapperProtocol
+        RF3 wrapper exposing initialization and step APIs.
+    features : RF3FeaturesProtocol
+        Precomputed RF3 conditioning/features bundle.
+    num_steps : int
+        Number of denoising steps to execute.
+    seed : int, optional
+        Seed for reproducible prior noise sampling.
+
+    Returns
+    -------
+    Any
+        Final wrapper state after `num_steps`.
+    """

As per coding guidelines, “Always include NumPy-style docstrings for every function and class.” and “Use jaxtyping for array shape annotations in function signatures (e.g., Float[Tensor, "batch atoms 3"]).”

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/models/test_rf3_chiral_features.py` around lines 13 - 31, Update the
helper functions _make_sampler and _run_steps to follow repo conventions by
adding NumPy-style docstrings and explicit typed signatures: annotate
_make_sampler(device: torch.device) -> AF3EDMSampler with a short description,
Args/Returns sections; annotate _run_steps with precise types for wrapper,
features (use jaxtyping for tensor shapes, e.g., Float[Tensor, "batch ..."]
where applicable), num_steps: int, seed: int = 42 and return type (tensor or
state type), and add a NumPy-style docstring describing parameters, shapes, and
return value; ensure any torch.device usage remains typed and include jaxtyping
imports if missing.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/models/test_rf3_chiral_features.py`:
- Line 136: The loop uses an unused enumerate index and non-strict zips; replace
the current "for i, (e, d) in enumerate(zip(enabled_stats, disabled_stats))"
pattern by removing enumerate and using strict zipping so that the loop becomes
a simple "for e, d in zip(enabled_stats, disabled_stats, strict=True):" to
eliminate the unused variable i and ensure enabled_stats and disabled_stats have
matching lengths.
- Line 131: Update the loop that iterates over enabled_stats and disabled_stats
in tests/models/test_rf3_chiral_features.py to use zip(..., strict=True) for
defensive iteration safety; specifically modify the for loop that reads "for e,
d in zip(enabled_stats, disabled_stats):" (which follows the len assertion
involving enabled_stats, disabled_stats, and num_steps) so it becomes "for e, d
in zip(enabled_stats, disabled_stats, strict=True):" to enforce equal-length
iteration and fail immediately on any mismatch.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f4cd4391-b9e3-4893-a04e-cd5a2ce7de9b

📥 Commits

Reviewing files that changed from the base of the PR and between be1c139 and c19d585.

📒 Files selected for processing (1)
  • tests/models/test_rf3_chiral_features.py

@k-chrispens k-chrispens merged commit 40c12b9 into main Mar 23, 2026
1 of 4 checks passed
@k-chrispens (Collaborator, Author):

Tests pass locally, something is up with the lockfile when I try to deploy to our testing server.
