Skip to content

Commit 4991033

Browse files
Jammy2211claude
authored andcommitted
feat: register FitImaging, DatasetModel, Galaxies as pytrees in AnalysisImaging
Adds `_register_fit_imaging_pytrees` staticmethod to `AnalysisImaging`, called from `fit_from` under the existing `use_jax` gate. Mirrors autolens's `AnalysisImaging._register_fit_imaging_pytrees`. Unblocks `jax.jit(analysis.fit_from)(instance)` returning a `FitImaging` with `jax.Array` leaves on the autogalaxy imaging path. Part of PyAutoLabs/autogalaxy_workspace_test#8 (epic #5, task 3/9). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
1 parent a46023f commit 4991033

2 files changed

Lines changed: 37 additions & 0 deletions

File tree

autogalaxy/imaging/model/analysis.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
141141
The fit of the galaxies to the imaging dataset, which includes the log likelihood.
142142
"""
143143

144+
if self._use_jax:
145+
self._register_fit_imaging_pytrees()
146+
144147
galaxies = self.galaxies_via_instance_from(
145148
instance=instance,
146149
)
@@ -158,6 +161,27 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
158161
xp=self._xp,
159162
)
160163

164+
@staticmethod
165+
def _register_fit_imaging_pytrees() -> None:
166+
"""Register every type reachable from a ``FitImaging`` return value
167+
so ``jax.jit(fit_from)`` can flatten its output.
168+
169+
``dataset``, ``adapt_images`` and ``settings`` are constants per
170+
analysis — ride as aux so JAX does not recurse into them. Everything
171+
else (``galaxies``, ``dataset_model`` and the autoarray wrappers they
172+
carry) is dynamic per fit.
173+
"""
174+
from autoarray.abstract_ndarray import register_instance_pytree
175+
from autoarray.dataset.dataset_model import DatasetModel
176+
from autogalaxy.galaxy.galaxies import Galaxies
177+
178+
register_instance_pytree(
179+
FitImaging,
180+
no_flatten=("dataset", "adapt_images", "settings"),
181+
)
182+
register_instance_pytree(DatasetModel)
183+
register_instance_pytree(Galaxies)
184+
161185
def save_attributes(self, paths: af.DirectoryPaths):
162186
"""
163187
Before the non-linear search begins, this routine saves attributes of the `Analysis` object to the `files`

test_autogalaxy/imaging/model/test_analysis_imaging.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,16 @@ def test__figure_of_merit__matches_correct_fit_given_galaxy_profiles(
3737
fit = ag.FitImaging(dataset=masked_imaging_7x7, galaxies=galaxies)
3838

3939
assert fit.log_likelihood == fit_figure_of_merit
40+
41+
42+
def test__register_fit_imaging_pytrees__registers_fit_galaxies_and_dataset_model():
43+
from autoarray.abstract_ndarray import _pytree_registered_classes
44+
from autoarray.dataset.dataset_model import DatasetModel
45+
from autogalaxy.galaxy.galaxies import Galaxies
46+
from autogalaxy.imaging.fit_imaging import FitImaging
47+
48+
ag.AnalysisImaging._register_fit_imaging_pytrees()
49+
50+
assert FitImaging in _pytree_registered_classes
51+
assert DatasetModel in _pytree_registered_classes
52+
assert Galaxies in _pytree_registered_classes

0 commit comments

Comments
 (0)