
feat: register pytrees for autogalaxy AnalysisInterferometer #375

@Jammy2211

Overview

AnalysisImaging (PyAutoGalaxy) shipped JAX pytree registration for FitImaging + DatasetModel + Galaxies in PR #364 (2026-04-22). The sibling AnalysisInterferometer was left without equivalent registration. Setting use_jax_for_visualization=True on an interferometer analysis (or wrapping analysis.fit_from in jax.jit directly) fails as soon as a FitInterferometer is returned from a JIT trace, because FitInterferometer is not a registered JAX pytree.

This issue lands the registration scaffold on the autogalaxy side. End-to-end verification under jax.jit will follow in the downstream autogalaxy_workspace_test_jax_likelihood_interferometer task, which ports the autolens JAX-likelihood interferometer scripts and is explicitly gated on this PR (see PyAutoPrompt/autogalaxy/autogalaxy_workspace_test_jax_likelihood_interferometer.md lines 21-28).

AnalysisQuantity and AnalysisEllipse are out of scope for this issue:

  • AnalysisQuantity will get its own follow-up (no autolens equivalent yet, so the verification path needs design).
  • AnalysisEllipse is structurally different: fit_list_from returns List[FitEllipse] with no Galaxies aggregate, and the class inherits af.Analysis directly rather than AnalysisDataset. It will get a separate per-class issue if it is ever needed under JAX.

Plan

  • Lift the Galaxies flatten/unflatten helper out of AnalysisImaging._register_fit_imaging_pytrees into a small shared module so imaging and interferometer register Galaxies identically.
  • Add _register_fit_interferometer_pytrees() to AnalysisInterferometer, registering FitInterferometer (no_flatten = dataset/adapt_images/settings), DatasetModel (registered defensively, mirroring PyAutoLens), and Galaxies (via the shared helper).
  • Call the new method from fit_from, gated on self._use_jax, mirroring the imaging pattern at autogalaxy/imaging/model/analysis.py:146-147.
  • Run pytest test_autogalaxy/ to confirm the NumPy default path is unaffected.
Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary, only repo touched)

Work Classification

Library only. The downstream JAX-likelihood interferometer port is a separate workspace task tracked under PyAutoPrompt/autogalaxy/autogalaxy_workspace_test_jax_likelihood_interferometer.md.

Branch Survey

Repository        Current Branch   Dirty?
./PyAutoGalaxy    main             clean

Suggested branch: feature/analysis-interferometer-pytree
Worktree root: ~/Code/PyAutoLabs-wt/analysis-interferometer-pytree/ (created by /start_library)

Implementation Steps

  1. Extract the shared Galaxies pytree helper. Add a new module autogalaxy/analysis/jax_pytrees.py exposing a single register_galaxies_pytree() function. Move the _flatten_galaxies / _unflatten_galaxies block currently inlined in autogalaxy/imaging/model/analysis.py::_register_fit_imaging_pytrees (lines 195-208) into the helper. Idempotency is preserved via the existing _pytree_registered_classes guard from autoarray.abstract_ndarray.

  2. Update AnalysisImaging to call the shared helper. Replace the inline Galaxies registration block in _register_fit_imaging_pytrees with a single register_galaxies_pytree() call. The FitImaging and DatasetModel registrations stay inline.

  3. Add registration to AnalysisInterferometer (autogalaxy/interferometer/model/analysis.py). Add a static method:

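    @staticmethod
    def _register_fit_interferometer_pytrees() -> None:
        # FitInterferometer is assumed to already be imported at module level in analysis.py.
        from autoarray.abstract_ndarray import register_instance_pytree
        from autoarray.dataset.dataset_model import DatasetModel
        from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree

        # dataset, adapt_images and settings are kept out of the flattened leaves.
        register_instance_pytree(
            FitInterferometer,
            no_flatten=("dataset", "adapt_images", "settings"),
        )
        # Registered defensively, mirroring PyAutoLens.
        register_instance_pytree(DatasetModel)
        # Shared helper, so imaging and interferometer register Galaxies identically.
        register_galaxies_pytree()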

    Then in fit_from, add the gate before any other work:

    if self._use_jax:
        self._register_fit_interferometer_pytrees()

    Mirrors autogalaxy/imaging/model/analysis.py:146-147.

  4. Verify pytest test_autogalaxy/ continues to pass. The default NumPy path is unaffected because the registrations are guarded behind self._use_jax. No unit tests should need changes.

  5. Manual JIT smoke check (optional, local). For a small synthetic interferometer fit, confirm that the FitInterferometer returned by jax.jit(analysis.fit_from)(instance) round-trips through the JIT boundary and that its figure_of_merit can be read. This is not required to ship, since the formal end-to-end test lives in the downstream JAX-likelihood port, but a quick local check is cheap insurance against typos.
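Step 1's shared helper needs explicit flatten/unflatten functions because, per the original prompt, register_instance_pytree alone drops list contents. The sketch below illustrates that failure mode and the fix under stated assumptions: Galaxies here is a hypothetical list subclass with one illustrative attribute, naive_instance_children only mimics an attribute-walking flattener (autoarray's real implementation may differ), and flatten_galaxies / unflatten_galaxies are hypothetical names modeled on _flatten_galaxies / _unflatten_galaxies. The functions follow the contract jax.tree_util.register_pytree_node expects: flatten returns (children, aux_data) and unflatten takes (aux_data, children). It is plain Python so it runs without JAX; the final comment shows where the real registration would go.

```python
from typing import Any, Iterable, Tuple


class Galaxies(list):
    """Hypothetical stand-in for autogalaxy's Galaxies: a list subclass
    with one illustrative non-array attribute."""

    def __init__(self, galaxies: Iterable[Any] = (), label: str = ""):
        super().__init__(galaxies)
        self.label = label


def naive_instance_children(obj: Any) -> list:
    """Mimic a flattener that only walks instance attributes.

    For a list subclass the items live in the list storage, not in
    vars(obj), so they never reach the children.
    """
    return list(vars(obj).values())


def flatten_galaxies(galaxies: Galaxies) -> Tuple[Tuple[Any, ...], Tuple[str]]:
    # Children: the galaxies themselves, so each one can be traced.
    children = tuple(galaxies)
    # Aux data: static, hashable metadata needed to rebuild the container.
    aux_data = (galaxies.label,)
    return children, aux_data


def unflatten_galaxies(aux_data: Tuple[str], children: Iterable[Any]) -> Galaxies:
    (label,) = aux_data
    return Galaxies(children, label=label)


gals = Galaxies(["galaxy_a", "galaxy_b"], label="demo")
print(naive_instance_children(gals))  # ['demo']: both galaxies are lost

children, aux = flatten_galaxies(gals)
rebuilt = unflatten_galaxies(aux, children)
print(list(rebuilt), rebuilt.label)  # ['galaxy_a', 'galaxy_b'] demo

# With JAX installed, registration would look like:
# jax.tree_util.register_pytree_node(Galaxies, flatten_galaxies, unflatten_galaxies)
```

Moving these two functions into one module means imaging and interferometer cannot drift apart in how they rebuild the container.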

Spawn-off awareness

Per the prompt's guidance: if any profile / dataset class reachable from FitInterferometer turns out not to be pytree-friendly (e.g. autoarray helpers used by the transformer), stop and open a per-class registration issue. Do not add ad-hoc register_pytree_node calls inside _register_fit_interferometer_pytrees.

Key Files

  • PyAutoGalaxy/autogalaxy/analysis/jax_pytrees.py — new shared helper module.
  • PyAutoGalaxy/autogalaxy/imaging/model/analysis.py — call shared helper in _register_fit_imaging_pytrees.
  • PyAutoGalaxy/autogalaxy/interferometer/model/analysis.py — add _register_fit_interferometer_pytrees and the _use_jax gate.

Reference: PyAutoLens equivalent

  • PyAutoLens/autolens/interferometer/model/analysis.py:177, 196-214: _register_fit_interferometer_pytrees registers FitInterferometer, Tracer (no_flatten=("cosmology",)), and DatasetModel. The autogalaxy version registers Galaxies instead of Tracer.
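The no_flatten argument (dataset/adapt_images/settings for FitInterferometer here, cosmology for Tracer in PyAutoLens) separates attributes that are flattened into traced leaves from attributes that ride along unchanged. The following is a hypothetical re-implementation of that idea, not autoarray's register_instance_pytree: ToyFit, flatten_instance and unflatten_instance are invented for illustration, and the sketch ignores the hashability requirements JAX places on auxiliary data. Attributes named in no_flatten go into the aux data; everything else becomes a child.

```python
from typing import Any, Dict, Iterable, Tuple


class ToyFit:
    """Hypothetical stand-in for FitInterferometer."""

    def __init__(self, dataset: Any, settings: Any, chi_squared: float, noise_term: float):
        self.dataset = dataset
        self.settings = settings
        self.chi_squared = chi_squared
        self.noise_term = noise_term


def flatten_instance(obj: Any, no_flatten: Iterable[str] = ()) -> Tuple[tuple, tuple]:
    skip = set(no_flatten)
    attrs: Dict[str, Any] = vars(obj)
    # Traced leaves: every attribute not listed in no_flatten, in a stable order.
    dynamic_names = tuple(sorted(k for k in attrs if k not in skip))
    children = tuple(attrs[k] for k in dynamic_names)
    # Static aux data: the class, the child names, and the untouched attributes.
    static_items = tuple((k, attrs[k]) for k in sorted(skip) if k in attrs)
    return children, (type(obj), dynamic_names, static_items)


def unflatten_instance(aux_data: tuple, children: Iterable[Any]) -> Any:
    cls, dynamic_names, static_items = aux_data
    obj = cls.__new__(cls)  # bypass __init__ so nothing is recomputed
    obj.__dict__.update(dict(static_items))
    obj.__dict__.update(dict(zip(dynamic_names, children)))
    return obj


fit = ToyFit(dataset="dataset-object", settings="settings-object", chi_squared=4.0, noise_term=1.5)
children, aux = flatten_instance(fit, no_flatten=("dataset", "settings"))
print(children)  # (4.0, 1.5)

rebuilt = unflatten_instance(aux, children)
print(rebuilt.dataset, rebuilt.chi_squared)  # dataset-object 4.0
```

Bypassing __init__ on unflatten matters for fit objects whose constructors compute derived quantities: the rebuilt instance reuses the (possibly traced) leaves instead of recomputing them.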

Original Prompt


After PyAutoGalaxy AnalysisImaging shipped JAX pytree registration for
FitImaging + DatasetModel + Galaxies (PR #364, 2026-04-22), three sibling Analysis classes
were left without equivalent registration:

  • @PyAutoGalaxy/autogalaxy/interferometer/model/analysis.py::AnalysisInterferometer
  • @PyAutoGalaxy/autogalaxy/ellipse/model/analysis.py::AnalysisEllipse
  • @PyAutoGalaxy/autogalaxy/quantity/model/analysis.py::AnalysisQuantity

Setting use_jax_for_visualization=True on any of these will fail the moment the visualizer
returns a FitInterferometer / FitEllipse / FitQuantity from a jitted fit_from, because
the return type isn't a registered JAX pytree.

Reference patterns

PyAutoLens has the equivalent registrations for the interferometer dataset type — mirror it on
the autogalaxy side:

  • @PyAutoLens/autolens/interferometer/model/analysis.py::AnalysisInterferometer._register_fit_interferometer_pytrees
    (lines 177, 196–214) registers FitInterferometer, DatasetModel, Tracer (no_flatten=("cosmology",)).
    The PyAutoGalaxy version registers Galaxies instead of Tracer.

PyAutoGalaxy's existing autogalaxy/imaging/model/analysis.py::_register_fit_imaging_pytrees
(lines 169–207) is the closest in-repo template — it registers the Galaxies list container
with explicit _flatten_galaxies / _unflatten_galaxies helpers because
register_instance_pytree alone drops list contents. Any of the three new analyses that hold
a Galaxies will need the same helper — re-export from a shared location or duplicate.

What to do

For each of AnalysisInterferometer, AnalysisEllipse, AnalysisQuantity:

  1. Add _register_fit_*_pytrees() to the class. Call it from __init__ under the existing
    use_jax gate (mirror autogalaxy AnalysisImaging line 147).
  2. Register the dataset's Fit* class via register_instance_pytree. For the inversion-bearing
    ones (Interferometer), check whether the inversion solver state needs no_flatten= — see
    the PyAutoLens reference.
  3. Register DatasetModel via register_instance_pytree(DatasetModel). Confirm the underlying
    helper is idempotent (PyAutoLens AnalysisImaging and AnalysisInterferometer both register it,
    so it is — but verify before assuming).
  4. Register the model container the analysis fits with — Galaxies for Interferometer/Quantity,
    the appropriate isophote container for Ellipse.
  5. Add a workspace_test pilot per dataset type:
    autogalaxy_workspace_test/scripts/{interferometer,ellipse,quantity}/visualization_jax.py
    matching the existing imaging/visualization_jax.py pattern. Each should pass
    use_jax_for_visualization=True and assert fit.png is produced.

Spawn-off awareness

If any profile / dataset class reachable from the new Fit* types isn't yet pytree-friendly,
follow the established spawn-off pattern: stop, open a per-class registration issue, ship that
PR first. Do not paper over with ad-hoc register_pytree_node calls inside the workspace
script. See issued/autogalaxy_workspace_test_jax_likelihood_imaging.md for the pattern.

AnalysisEllipse may surface deeper blockers — its FitEllipse doesn't carry the same
Galaxies structure as the imaging/interferometer analyses. If registration turns out
non-trivial, scope it out into its own issue rather than forcing it into this task.

Scope boundary

  • Eager-JAX path only (Path C in the original issued/fit_imaging_pytree.md framing). Do not
    attempt full jax.jit wrapping — that's gated on the Path A feasibility study.
  • No production workspace adoption — the test workspace pilots are sufficient verification.
  • This task assumes the visualizer dispatch fix from
    autogalaxy/visualizer_fit_for_visualization_dispatch.md has shipped first; otherwise the
    pilot scripts will pass without actually exercising the jit path.

Verification

  • JAX_ENABLE_X64=True python autogalaxy_workspace_test/scripts/{type}/visualization_jax.py
    for each new pilot.
  • Add the three new pilots to autogalaxy_workspace_test/smoke_tests.txt.

Background

Original feature: complete.md entries jax-visualization and mge-jit-visualization
(both 2026-04-19). Imaging pytree registration on autogalaxy: complete.md entry for the
autogalaxy_workspace_test imaging port (2026-04-22). The _register_fit_imaging_pytrees
spawn-off was originally flagged in issued/autogalaxy_workspace_test_jax_likelihood_imaging.md
under "Pytree prerequisite — likely blocker".


Note on scope deviation from the original prompt: the prompt requested registrations for all three sibling analyses (interferometer, ellipse, quantity) plus visualizer pilots. After triage we narrowed this issue to interferometer only:

  • The visualizer-pilot verification path was dropped because the autogalaxy visualizers still call analysis.fit_from(...) directly (the fit_for_visualization dispatch fix has not shipped), so the pilots would pass without exercising the JIT path. The downstream autogalaxy_workspace_test_jax_likelihood_interferometer task wraps analysis.fit_from in jax.jit directly and asserts scalar parity, which is a stronger end-to-end test that does not depend on the dispatch fix.
  • AnalysisQuantity deferred — no autolens JAX-likelihood quantity equivalent, so verification needs separate design.
  • AnalysisEllipse deferred per the prompt's own scope-out clause (different Fit* shape).
