Overview
AnalysisImaging (PyAutoGalaxy) shipped JAX pytree registration for FitImaging + DatasetModel + Galaxies in PR #364 (2026-04-22). The sibling AnalysisInterferometer was left without equivalent registration. Setting use_jax_for_visualization=True (or wrapping analysis.fit_from in jax.jit directly) on an interferometer fit fails the moment a FitInterferometer is returned from a JIT trace, because the return type is not a registered JAX pytree.
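The sketch below makes the failure concrete. It uses a hypothetical ToyFit class as a stand-in for FitInterferometer; it is not the autogalaxy class. jax.jit raises a TypeError when the traced function returns a type JAX cannot flatten. The same function works once the class is registered with jax.tree_util.register_pytree_node.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class ToyFit:
    """Hypothetical stand-in for a fit object such as FitInterferometer."""

    def __init__(self, log_likelihood):
        self.log_likelihood = log_likelihood


def fit_from(x):
    # Build a "fit" whose only dynamic content is a traced scalar.
    return ToyFit(-0.5 * jnp.sum(x**2))


x = jnp.array([1.0, 2.0])

# Before registration: ToyFit is not a valid JAX output type.
try:
    jax.jit(fit_from)(x)
    unregistered_error = None
except TypeError as err:
    unregistered_error = err

# Register ToyFit so log_likelihood becomes its single traced leaf.
tree_util.register_pytree_node(
    ToyFit,
    lambda fit: ((fit.log_likelihood,), None),
    lambda aux_data, children: ToyFit(*children),
)

fit = jax.jit(fit_from)(x)
print(type(fit).__name__, float(fit.log_likelihood))  # ToyFit -2.5
```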
This issue lands the registration scaffold on the autogalaxy side. End-to-end verification under jax.jit will follow in the downstream autogalaxy_workspace_test_jax_likelihood_interferometer task, which ports the autolens JAX-likelihood interferometer scripts and is explicitly gated on this PR (see PyAutoPrompt/autogalaxy/autogalaxy_workspace_test_jax_likelihood_interferometer.md lines 21-28).
AnalysisQuantity and AnalysisEllipse are out of scope for this issue:
- AnalysisQuantity will get its own follow-up (there is no autolens equivalent yet, so the verification path needs design).
- AnalysisEllipse is structurally different: fit_list_from returns List[FitEllipse] with no Galaxies aggregate, and the class inherits af.Analysis directly rather than AnalysisDataset. It would need a separate per-class issue if it is ever wanted under JAX.
Plan
- Lift the Galaxies flatten/unflatten helper out of AnalysisImaging._register_fit_imaging_pytrees into a small shared module so imaging and interferometer register Galaxies identically.
- Add _register_fit_interferometer_pytrees() to AnalysisInterferometer, registering FitInterferometer (no_flatten = dataset/adapt_images/settings), DatasetModel (defensive, mirroring PyAutoLens), and Galaxies (via the shared helper).
- Gate the call on self._use_jax inside fit_from, mirroring the imaging pattern at autogalaxy/imaging/model/analysis.py:146-147.
- Run pytest test_autogalaxy/ to confirm the NumPy default path is unaffected.
Detailed implementation plan
Affected Repositories
PyAutoGalaxy (primary, only repo touched)
Work Classification
Library only. The downstream JAX-likelihood interferometer port is a separate workspace task tracked under PyAutoPrompt/autogalaxy/autogalaxy_workspace_test_jax_likelihood_interferometer.md.
Branch Survey
| Repository | Current Branch | Dirty? |
| --- | --- | --- |
| ./PyAutoGalaxy | main | clean |
Suggested branch: feature/analysis-interferometer-pytree
Worktree root: ~/Code/PyAutoLabs-wt/analysis-interferometer-pytree/ (created by /start_library)
Implementation Steps
- Extract the shared Galaxies pytree helper. Add a new module autogalaxy/analysis/jax_pytrees.py exposing a single register_galaxies_pytree() function. Move the _flatten_galaxies / _unflatten_galaxies block currently inlined in autogalaxy/imaging/model/analysis.py::_register_fit_imaging_pytrees (lines 195-208) into the helper. Idempotency is preserved via the existing _pytree_registered_classes guard from autoarray.abstract_ndarray.
- Update AnalysisImaging to call the shared helper. Replace the inline Galaxies registration block in _register_fit_imaging_pytrees with a single register_galaxies_pytree() call. The FitImaging and DatasetModel registrations stay inline.
- Add registration to AnalysisInterferometer (autogalaxy/interferometer/model/analysis.py). Add a static method:
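```python
@staticmethod
def _register_fit_interferometer_pytrees() -> None:
    from autoarray.abstract_ndarray import register_instance_pytree
    from autoarray.dataset.dataset_model import DatasetModel

    from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree

    # Assumes FitInterferometer is already imported at module level.
    # dataset, adapt_images and settings are excluded from flattening (no_flatten).
    register_instance_pytree(
        FitInterferometer,
        no_flatten=("dataset", "adapt_images", "settings"),
    )
    # Defensive: PyAutoLens registers DatasetModel too.
    register_instance_pytree(DatasetModel)
    # Galaxies is a list container, so it goes through the shared flatten/unflatten helper.
    register_galaxies_pytree()
```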
Then in fit_from, add the gate before any other work:
```python
if self._use_jax:
    self._register_fit_interferometer_pytrees()
```
This mirrors autogalaxy/imaging/model/analysis.py:146-147.
- Verify pytest test_autogalaxy/ continues to pass. The default NumPy path is unaffected because the registrations are guarded behind self._use_jax. No unit tests should need changes.
- Manual JIT smoke check (optional, local). Confirm jax.jit(analysis.fit_from)(instance).figure_of_merit round-trips for a small synthetic interferometer fit. This is not required to ship, since the formal end-to-end test lives in the downstream JAX-likelihood port, but a quick local check is cheap insurance against typos.
Spawn-off awareness
Per the prompt's guidance: if any profile / dataset class reachable from FitInterferometer turns out not to be pytree-friendly (e.g. autoarray helpers used by transformer), stop and open a per-class registration issue. Do not add ad-hoc register_pytree_node calls inside _register_fit_interferometer_pytrees.
Key Files
- PyAutoGalaxy/autogalaxy/analysis/jax_pytrees.py: new shared helper module.
- PyAutoGalaxy/autogalaxy/imaging/model/analysis.py: call the shared helper in _register_fit_imaging_pytrees.
- PyAutoGalaxy/autogalaxy/interferometer/model/analysis.py: add _register_fit_interferometer_pytrees and the _use_jax gate.
Reference: PyAutoLens equivalent
PyAutoLens/autolens/interferometer/model/analysis.py:177, 196-214 — _register_fit_interferometer_pytrees registers FitInterferometer, Tracer (no_flatten=("cosmology",)), DatasetModel. The autogalaxy version registers Galaxies instead of Tracer.
Original Prompt
After PyAutoGalaxy AnalysisImaging shipped JAX pytree registration for
FitImaging + DatasetModel + Galaxies (PR #364, 2026-04-22), three sibling Analysis classes
were left without equivalent registration:
@PyAutoGalaxy/autogalaxy/interferometer/model/analysis.py::AnalysisInterferometer
@PyAutoGalaxy/autogalaxy/ellipse/model/analysis.py::AnalysisEllipse
@PyAutoGalaxy/autogalaxy/quantity/model/analysis.py::AnalysisQuantity
Setting use_jax_for_visualization=True on any of these will fail the moment the visualizer
returns a FitInterferometer / FitEllipse / FitQuantity from a jitted fit_from, because
the return type isn't a registered JAX pytree.
Reference patterns
PyAutoLens has the equivalent registrations for the interferometer dataset type — mirror it on
the autogalaxy side:
@PyAutoLens/autolens/interferometer/model/analysis.py::AnalysisInterferometer._register_fit_interferometer_pytrees
(lines 177, 196–214) registers FitInterferometer, DatasetModel, Tracer (no_flatten=("cosmology",)).
The PyAutoGalaxy version registers Galaxies instead of Tracer.
PyAutoGalaxy's existing autogalaxy/imaging/model/analysis.py::_register_fit_imaging_pytrees
(lines 169–207) is the closest in-repo template — it registers the Galaxies list container
with explicit _flatten_galaxies / _unflatten_galaxies helpers because
register_instance_pytree alone drops list contents. Any of the three new analyses that hold
a Galaxies will need the same helper — re-export from a shared location or duplicate.
What to do
For each of AnalysisInterferometer, AnalysisEllipse, AnalysisQuantity:
- Add _register_fit_*_pytrees() to the class. Call it from __init__ under the existing use_jax gate (mirror autogalaxy AnalysisImaging line 147).
- Register the dataset's Fit* class via register_instance_pytree. For the inversion-bearing ones (Interferometer), check whether the inversion solver state needs no_flatten= (see the PyAutoLens reference).
- Register DatasetModel via register_instance_pytree(DatasetModel). Confirm the underlying helper is idempotent (PyAutoLens AnalysisImaging and AnalysisInterferometer both register it, so it is, but verify before assuming).
- Register the model container the analysis fits with: Galaxies for Interferometer/Quantity, the appropriate isophote container for Ellipse.
- Add a workspace_test pilot per dataset type: autogalaxy_workspace_test/scripts/{interferometer,ellipse,quantity}/visualization_jax.py, matching the existing imaging/visualization_jax.py pattern. Each should pass use_jax_for_visualization=True and assert fit.png is produced.
Spawn-off awareness
If any profile / dataset class reachable from the new Fit* types isn't yet pytree-friendly,
follow the established spawn-off pattern: stop, open a per-class registration issue, ship that
PR first. Do not paper over with ad-hoc register_pytree_node calls inside the workspace
script. See issued/autogalaxy_workspace_test_jax_likelihood_imaging.md for the pattern.
AnalysisEllipse may surface deeper blockers — its FitEllipse doesn't carry the same
Galaxies structure as the imaging/interferometer analyses. If registration turns out
non-trivial, scope it out into its own issue rather than forcing it into this task.
Scope boundary
- Eager-JAX path only (Path C in the original issued/fit_imaging_pytree.md framing). Do not attempt full jax.jit wrapping; that is gated on the Path A feasibility study.
- No production workspace adoption: the test workspace pilots are sufficient verification.
- This task assumes the visualizer dispatch fix from autogalaxy/visualizer_fit_for_visualization_dispatch.md has shipped first; otherwise the pilot scripts will pass without actually exercising the jit path.
Verification
- Run JAX_ENABLE_X64=True python autogalaxy_workspace_test/scripts/{type}/visualization_jax.py for each new pilot.
- Add the three new pilots to autogalaxy_workspace_test/smoke_tests.txt.
Background
Original feature: complete.md entries jax-visualization and mge-jit-visualization
(both 2026-04-19). Imaging pytree registration on autogalaxy: complete.md entry for the
autogalaxy_workspace_test imaging port (2026-04-22). The _register_fit_imaging_pytrees
spawn-off was originally flagged in issued/autogalaxy_workspace_test_jax_likelihood_imaging.md
under "Pytree prerequisite — likely blocker".
Note on scope deviation from the original prompt: the prompt requested registrations for all three sibling analyses (interferometer, ellipse, quantity) plus visualizer pilots. After triage we narrowed this issue to interferometer only:
- The visualizer-pilot verification path was dropped because the autogalaxy visualizers still call analysis.fit_from(...) directly (the fit_for_visualization dispatch fix has not shipped), so the pilots would pass without exercising the JIT path. The downstream autogalaxy_workspace_test_jax_likelihood_interferometer task wraps jax.jit(analysis.fit_from) directly and asserts scalar parity, which is a stronger end-to-end test that does not depend on the dispatch fix.
- AnalysisQuantity deferred: there is no autolens JAX-likelihood quantity equivalent, so verification needs separate design.
- AnalysisEllipse deferred per the prompt's own scope-out clause (different Fit* shape).
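The scalar-parity idea behind the downstream test can be sketched with a toy Gaussian likelihood. The sketch evaluates the same function once through the NumPy path and once under jax.jit, then compares the two scalars. The likelihood, synthetic data and tolerance are illustrative assumptions, not values from the downstream scripts. Float64 is enabled to mirror the JAX_ENABLE_X64=True setting used in the prompt's verification command.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable float64 before creating any JAX arrays (mirrors JAX_ENABLE_X64=True).
jax.config.update("jax_enable_x64", True)


def log_likelihood(xp, model, data, noise):
    """Toy Gaussian log likelihood written against an array module xp (numpy or jax.numpy)."""
    chi_squared = xp.sum(((data - model) / noise) ** 2)
    return -0.5 * chi_squared


# Synthetic, illustrative data only.
rng = np.random.default_rng(seed=1)
data = rng.normal(size=64)
noise = np.full(64, 0.5)
model = data + rng.normal(scale=0.1, size=64)

numpy_value = log_likelihood(np, model, data, noise)
jax_value = jax.jit(lambda m: log_likelihood(jnp, m, data, noise))(jnp.asarray(model))

# Parity assertion: both paths must agree to floating-point precision.
np.testing.assert_allclose(float(jax_value), numpy_value, rtol=1e-10)
print(float(jax_value), numpy_value)
```

Writing the likelihood against an array-module argument keeps a single code path for both backends, which is what makes a tight relative tolerance meaningful.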