Skip to content

Commit ec9161a

Browse files
Jammy2211claude
authored andcommitted
feat(interferometer/jax): register pytrees for AnalysisInterferometer
Mirror AnalysisImaging's pytree registration on the interferometer side so jax.jit(fit_from) can flatten its FitInterferometer return value. Extract the Galaxies flatten/unflatten block (~12 lines, identical across analyses) into autogalaxy.analysis.jax_pytrees.register_galaxies_pytree() so imaging and interferometer share the non-trivial logic without duplication. End-to-end JIT verification (jax.jit(analysis.fit_from) round-trip with NumPy parity) will land in the downstream autogalaxy_workspace_test_jax_likelihood_interferometer task, which is explicitly gated on this PR. Refs #375 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 0dcea47 commit ec9161a

3 files changed

Lines changed: 69 additions & 24 deletions

File tree

autogalaxy/analysis/jax_pytrees.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Shared JAX pytree registrations for autogalaxy analysis classes.
2+
3+
Each ``Analysis*`` class registers its own ``Fit*`` and per-analysis
4+
constants inline (so the call site stays self-documenting), but the
5+
``Galaxies`` registration is shared from here because the custom
6+
flatten/unflatten logic is non-trivial and identical across all
7+
analyses that hold a ``Galaxies`` aggregate.
8+
"""
9+
10+
11+
def register_galaxies_pytree() -> None:
12+
"""Register ``Galaxies`` as a JAX pytree.
13+
14+
``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten
15+
in ``register_instance_pytree`` would drop the list contents. This
16+
registers a custom flatten that carries the list items as dynamic
17+
children and the ``__dict__`` entries as aux.
18+
19+
Idempotent — guarded by ``_pytree_registered_classes`` so repeated
20+
calls (e.g. from each ``Analysis*.fit_from``) are cheap.
21+
"""
22+
from autoarray.abstract_ndarray import _pytree_registered_classes
23+
from autoconf.jax_wrapper import register_pytree_node
24+
from autogalaxy.galaxy.galaxies import Galaxies
25+
26+
if Galaxies in _pytree_registered_classes:
27+
return
28+
29+
def _flatten_galaxies(galaxies):
30+
dict_items = tuple(sorted(galaxies.__dict__.items()))
31+
return tuple(galaxies), dict_items
32+
33+
def _unflatten_galaxies(aux, children):
34+
new = Galaxies.__new__(Galaxies)
35+
list.__init__(new, children)
36+
for key, value in aux:
37+
setattr(new, key, value)
38+
return new
39+
40+
register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies)
41+
_pytree_registered_classes.add(Galaxies)

autogalaxy/imaging/model/analysis.py

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -175,37 +175,16 @@ def _register_fit_imaging_pytrees() -> None:
175175
else (``galaxies``, ``dataset_model`` and the autoarray wrappers they
176176
carry) is dynamic per fit.
177177
"""
178-
from autoarray.abstract_ndarray import (
179-
_pytree_registered_classes,
180-
register_instance_pytree,
181-
)
178+
from autoarray.abstract_ndarray import register_instance_pytree
182179
from autoarray.dataset.dataset_model import DatasetModel
183-
from autoconf.jax_wrapper import register_pytree_node
184-
from autogalaxy.galaxy.galaxies import Galaxies
180+
from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree
185181

186182
register_instance_pytree(
187183
FitImaging,
188184
no_flatten=("dataset", "adapt_images", "settings"),
189185
)
190186
register_instance_pytree(DatasetModel)
191-
192-
# ``Galaxies`` is a ``list`` subclass — the generic ``__dict__`` flatten
193-
# in ``register_instance_pytree`` would drop the list contents. Register
194-
# a custom flatten that carries the list items as dynamic children.
195-
if Galaxies not in _pytree_registered_classes:
196-
def _flatten_galaxies(galaxies):
197-
dict_items = tuple(sorted(galaxies.__dict__.items()))
198-
return tuple(galaxies), dict_items
199-
200-
def _unflatten_galaxies(aux, children):
201-
new = Galaxies.__new__(Galaxies)
202-
list.__init__(new, children)
203-
for key, value in aux:
204-
setattr(new, key, value)
205-
return new
206-
207-
register_pytree_node(Galaxies, _flatten_galaxies, _unflatten_galaxies)
208-
_pytree_registered_classes.add(Galaxies)
187+
register_galaxies_pytree()
209188

210189
def save_attributes(self, paths: af.DirectoryPaths):
211190
"""

autogalaxy/interferometer/model/analysis.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
142142
FitInterferometer
143143
The fit of the galaxies to the interferometer dataset, which includes the log likelihood.
144144
"""
145+
146+
if self._use_jax:
147+
self._register_fit_interferometer_pytrees()
148+
145149
galaxies = self.galaxies_via_instance_from(
146150
instance=instance,
147151
)
@@ -158,6 +162,27 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
158162
xp=self._xp,
159163
)
160164

165+
@staticmethod
166+
def _register_fit_interferometer_pytrees() -> None:
167+
"""Register every type reachable from a ``FitInterferometer`` return
168+
value so ``jax.jit(fit_from)`` can flatten its output.
169+
170+
``dataset``, ``adapt_images`` and ``settings`` are constants per
171+
analysis — ride as aux so JAX does not recurse into them. Everything
172+
else (``galaxies`` and the autoarray wrappers it carries) is dynamic
173+
per fit.
174+
"""
175+
from autoarray.abstract_ndarray import register_instance_pytree
176+
from autoarray.dataset.dataset_model import DatasetModel
177+
from autogalaxy.analysis.jax_pytrees import register_galaxies_pytree
178+
179+
register_instance_pytree(
180+
FitInterferometer,
181+
no_flatten=("dataset", "adapt_images", "settings"),
182+
)
183+
register_instance_pytree(DatasetModel)
184+
register_galaxies_pytree()
185+
161186
def save_attributes(self, paths: af.DirectoryPaths):
162187
"""
163188
Before the model-fit begins, this routine saves attributes of the `Analysis` object to the `files` folder

0 commit comments

Comments
 (0)