Skip to content

feat: register pytrees for AnalysisInterferometer#376

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/analysis-interferometer-pytree
Apr 28, 2026
Merged

feat: register pytrees for AnalysisInterferometer#376
Jammy2211 merged 1 commit into
mainfrom
feature/analysis-interferometer-pytree

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Adds JAX pytree registration for AnalysisInterferometer so jax.jit(analysis.fit_from) can flatten its FitInterferometer return value — mirrors the existing AnalysisImaging pattern shipped in #364.

The Galaxies flatten/unflatten block (~12 lines, identical between imaging and interferometer) is lifted into autogalaxy.analysis.jax_pytrees.register_galaxies_pytree() so both analyses share the non-trivial logic without duplication. Imaging is refactored to call the shared helper; the per-analysis register_instance_pytree(Fit*, ...) and register_instance_pytree(DatasetModel) lines stay inline at each call site so it's still obvious from each method what is being registered.

AnalysisQuantity and AnalysisEllipse are out of scope — quantity has no autolens JAX-likelihood equivalent yet (verification path needs separate design) and ellipse is structurally different (returns List[FitEllipse] with no Galaxies aggregate, inherits af.Analysis directly). Both deferred to follow-up issues.

End-to-end JIT verification (jax.jit(analysis.fit_from) round-trip with NumPy parity) lands in the downstream autogalaxy_workspace_test_jax_likelihood_interferometer task, which is explicitly gated on this PR.

Closes #375

API Changes

  • New: autogalaxy.analysis.jax_pytrees.register_galaxies_pytree() — shared helper that registers Galaxies (a list subclass) as a JAX pytree with custom flatten/unflatten. Idempotent.
  • New: AnalysisInterferometer._register_fit_interferometer_pytrees() — static method registering FitInterferometer, DatasetModel, and Galaxies. Called from fit_from under the existing self._use_jax gate.
  • Refactored: AnalysisImaging._register_fit_imaging_pytrees — body collapsed from 41 to 11 lines by delegating the Galaxies registration to the shared helper. Behaviour unchanged.

See full details below.

Test Plan

  • pytest test_autogalaxy/imaging/model/test_analysis_imaging.py — passes.
  • pytest test_autogalaxy/interferometer/model/test_analysis_interferometer.py — passes.
  • Manual smoke check: _register_fit_interferometer_pytrees() runs without error, is idempotent on repeated calls, and ends with FitInterferometer, DatasetModel, and Galaxies in _pytree_registered_classes.
  • End-to-end JIT verification (deferred — lands in downstream autogalaxy_workspace_test_jax_likelihood_interferometer task, which is gated on this PR).
Full API Changes (for automation & release notes)

Added

  • autogalaxy.analysis.jax_pytrees (new module) exposing register_galaxies_pytree() -> None. Registers Galaxies as a JAX pytree with custom flatten/unflatten. Idempotent via _pytree_registered_classes.
  • autogalaxy.interferometer.model.analysis.AnalysisInterferometer._register_fit_interferometer_pytrees() — staticmethod registering FitInterferometer (no_flatten=("dataset", "adapt_images", "settings")), DatasetModel, and Galaxies (via the shared helper).
  • AnalysisInterferometer.fit_from now calls _register_fit_interferometer_pytrees() before constructing the fit, gated on self._use_jax (mirrors imaging line 146-147).

Changed Behaviour

  • AnalysisImaging._register_fit_imaging_pytrees — internal refactor only. Body collapsed from 41 to 11 lines; the Galaxies registration block now delegates to register_galaxies_pytree(). Net behaviour identical.

Migration

  • None. New registration is opt-in via use_jax=True on AnalysisInterferometer (default). Existing NumPy callers are unaffected.

🤖 Generated with Claude Code

Mirror AnalysisImaging's pytree registration on the interferometer side so
jax.jit(fit_from) can flatten its FitInterferometer return value. Extract
the Galaxies flatten/unflatten block (~12 lines, identical across analyses)
into autogalaxy.analysis.jax_pytrees.register_galaxies_pytree() so imaging
and interferometer share the non-trivial logic without duplication.

End-to-end JIT verification (jax.jit(analysis.fit_from) round-trip with
NumPy parity) will land in the downstream
autogalaxy_workspace_test_jax_likelihood_interferometer task, which is
explicitly gated on this PR.

Refs #375

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label Apr 28, 2026
@Jammy2211 Jammy2211 merged commit b9c09af into main Apr 28, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/analysis-interferometer-pytree branch April 28, 2026 16:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

feat: register pytrees for autogalaxy AnalysisInterferometer

1 participant