Skip to content

Commit 20699a2

Browse files
Jammy2211claude
authored andcommitted
fix: AdaptImages galaxy-identity mismatch across jax.jit boundary
Resolves dict-keyed-by-Galaxy-instance crash for adapt-image models after jax.jit unflatten produces fresh Galaxy objects with new .id values. Adds galaxy_path_list parallel to the analysis-time galaxies list and two helpers (image_for_galaxy, image_plane_mesh_grid_for_galaxy) that fall back to path-tuple keying via galaxy_name_image_dict. Drops the single-mesh-grid fallback at to_inversion.py:428-442 — replaced by the proper fix. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent af41a68 commit 20699a2

6 files changed

Lines changed: 207 additions & 30 deletions

File tree

autogalaxy/analysis/adapt_images/adapt_images.py

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22
import numpy as np
3-
from typing import TYPE_CHECKING, Dict, Optional, Tuple
3+
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
44

55
from autoconf import conf
66
from autoconf import cached_property
@@ -73,6 +73,7 @@ def __init__(
7373
galaxy_name_image_plane_mesh_grid_dict: Optional[
7474
Dict[Tuple[str, ...], aa.Grid2DIrregular]
7575
] = None,
76+
galaxy_path_list: Optional[List[str]] = None,
7677
):
7778
"""
7879
Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
@@ -116,6 +117,14 @@ def __init__(
116117
galaxy_name_image_plane_mesh_grid_dict
117118
)
118119

120+
# Parallel to the analysis-time galaxies list (as built by
121+
# ``Analysis.galaxies_via_instance_from``). Populated by
122+
# ``updated_via_instance_from`` and used by ``image_for_galaxy`` to
123+
# recover the galaxy's path-tuple key after a JAX unflatten has produced
124+
# fresh ``Galaxy`` objects whose hashes no longer match
125+
# ``galaxy_image_dict`` keys.
126+
self.galaxy_path_list = galaxy_path_list
127+
119128
@property
120129
def mask(self) -> aa.Mask2D:
121130
"""
@@ -147,7 +156,9 @@ def model_image(self) -> aa.Array2D:
147156

148157
return adapt_model_image
149158

150-
def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
159+
def updated_via_instance_from(
160+
self, instance, mask=None, galaxies: Optional[List["Galaxy"]] = None
161+
) -> "AdaptImages":
151162
"""
152163
Returns adapt-images which have been updated to map galaxy instances instead of galaxy names.
153164
@@ -170,13 +181,25 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
170181
mask
171182
A mask which can be applied to the adapt images, which is used when setting up the adaptive images
172183
via the aggregator and autofit database tools.
184+
galaxies
185+
Optional list of galaxies in the order used by the calling ``Analysis`` (i.e. the list passed to
186+
``FitImaging`` / ``Tracer``). When provided, a parallel ``galaxy_path_list`` is populated so that
187+
``image_for_galaxy`` can recover the path-tuple key for each galaxy after JAX has unflattened the
188+
galaxy instances into fresh objects. When ``None`` the path list is populated in ``path_instance_tuples_for_class``
189+
order, which matches ``Analysis.galaxies_via_instance_from`` for the common case (no
190+
``extra_galaxies`` / ``scaling_galaxies``).
173191
174192
Returns
175193
-------
176194
177195
"""
178196
from autogalaxy.galaxy.galaxy import Galaxy
179197

198+
path_by_id = {
199+
id(galaxy): str(galaxy_name)
200+
for galaxy_name, galaxy in instance.path_instance_tuples_for_class(Galaxy)
201+
}
202+
180203
galaxy_image_dict = None
181204

182205
if self.galaxy_name_image_dict is not None:
@@ -207,7 +230,73 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
207230
self.galaxy_name_image_plane_mesh_grid_dict[galaxy_name]
208231
)
209232

233+
if galaxies is not None:
234+
galaxy_path_list = [path_by_id.get(id(g)) for g in galaxies]
235+
else:
236+
galaxy_path_list = [
237+
str(galaxy_name)
238+
for galaxy_name, _ in instance.path_instance_tuples_for_class(Galaxy)
239+
]
240+
210241
return AdaptImages(
211242
galaxy_image_dict=galaxy_image_dict,
212243
galaxy_image_plane_mesh_grid_dict=galaxy_image_plane_mesh_grid_dict,
244+
galaxy_name_image_dict=self.galaxy_name_image_dict,
245+
galaxy_name_image_plane_mesh_grid_dict=self.galaxy_name_image_plane_mesh_grid_dict,
246+
galaxy_path_list=galaxy_path_list,
213247
)
248+
249+
def image_for_galaxy(
250+
self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] = None
251+
) -> Optional[aa.Array2D]:
252+
"""
253+
Return the adapt image for ``galaxy``, robust to JAX ``jit`` boundaries.
254+
255+
``galaxy_image_dict`` is keyed by the trace-time ``Galaxy`` instances. After ``jax.jit`` has flattened
256+
and unflattened a ``FitImaging``, the galaxies inside it are fresh Python objects whose ``__hash__``
257+
differs from the trace-time keys, so a direct lookup misses. This helper falls back to the path-tuple
258+
keyed ``galaxy_name_image_dict`` using ``galaxy_path_list`` to map the post-unflatten galaxy back to
259+
its trace-time path.
260+
261+
Returns ``None`` when no adapt image is associated with the galaxy.
262+
"""
263+
try:
264+
return self.galaxy_image_dict[galaxy]
265+
except (AttributeError, KeyError, TypeError):
266+
pass
267+
268+
path = self._path_for_galaxy(galaxy, galaxies)
269+
if path is None or self.galaxy_name_image_dict is None:
270+
return None
271+
return self.galaxy_name_image_dict.get(path)
272+
273+
def image_plane_mesh_grid_for_galaxy(
274+
self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] = None
275+
) -> Optional[aa.Grid2DIrregular]:
276+
"""
277+
Return the image-plane mesh grid for ``galaxy``, robust to JAX ``jit`` boundaries.
278+
279+
Companion to :meth:`image_for_galaxy` for ``galaxy_image_plane_mesh_grid_dict`` /
280+
``galaxy_name_image_plane_mesh_grid_dict``.
281+
"""
282+
try:
283+
return self.galaxy_image_plane_mesh_grid_dict[galaxy]
284+
except (AttributeError, KeyError, TypeError):
285+
pass
286+
287+
path = self._path_for_galaxy(galaxy, galaxies)
288+
if path is None or self.galaxy_name_image_plane_mesh_grid_dict is None:
289+
return None
290+
return self.galaxy_name_image_plane_mesh_grid_dict.get(path)
291+
292+
def _path_for_galaxy(
293+
self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]]
294+
) -> Optional[str]:
295+
if not self.galaxy_path_list or galaxies is None:
296+
return None
297+
for index, candidate in enumerate(galaxies):
298+
if candidate is galaxy:
299+
if index < len(self.galaxy_path_list):
300+
return self.galaxy_path_list[index]
301+
return None
302+
return None

autogalaxy/analysis/analysis/dataset.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,14 @@ def save_results(self, paths: af.DirectoryPaths, result: ResultDataset):
169169
except AttributeError:
170170
pass
171171

172-
def adapt_images_via_instance_from(self, instance: af.ModelInstance) -> AdaptImages:
172+
def adapt_images_via_instance_from(
173+
self,
174+
instance: af.ModelInstance,
175+
galaxies=None,
176+
) -> AdaptImages:
173177
try:
174-
return self.adapt_images.updated_via_instance_from(instance=instance)
178+
return self.adapt_images.updated_via_instance_from(
179+
instance=instance, galaxies=galaxies
180+
)
175181
except AttributeError:
176182
pass

autogalaxy/galaxy/to_inversion.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ def __init__(
213213
adapt_images: Optional[AdaptImages] = None,
214214
settings: aa.Settings = None,
215215
xp=np,
216+
path_galaxies: Optional[List[Galaxy]] = None,
216217
):
217218
"""
218219
Interfaces a dataset and input list of galaxies with the inversion module. to setup a
@@ -247,8 +248,13 @@ def __init__(
247248
the pixelization's pixels to the brightest regions of the image.
248249
settings
249250
The settings of the inversion, which controls how the linear algebra calculation is performed.
251+
path_galaxies
252+
The full ordered list of galaxies that ``adapt_images.galaxy_path_list`` was aligned with (typically
253+
``tracer.galaxies`` in autolens, where ``galaxies`` is a per-plane subset). When ``None`` defaults to
254+
``galaxies`` — correct for the single-plane autogalaxy case.
250255
"""
251256
self.galaxies = Galaxies(galaxies)
257+
self.path_galaxies = path_galaxies if path_galaxies is not None else self.galaxies
252258

253259
super().__init__(
254260
dataset=dataset,
@@ -418,28 +424,12 @@ def image_plane_mesh_grid_list(
418424
image_plane_mesh_grid_list = []
419425

420426
for galaxy in self.galaxies.galaxies_with_cls_list_from(cls=aa.Pixelization):
421-
try:
422-
image_plane_mesh_grid = (
423-
self.adapt_images.galaxy_image_plane_mesh_grid_dict[galaxy]
424-
)
425-
except (AttributeError, KeyError, TypeError):
427+
if self.adapt_images is None:
426428
image_plane_mesh_grid = None
427-
428-
if image_plane_mesh_grid is None:
429-
# Fallback for JAX JIT: after jax.jit unflatten the galaxy instances
430-
# stored as keys in ``adapt_images`` are stale (different Python objects
431-
# from those in the current tracer), so the dict key lookup above fails
432-
# and yields None. When the dict contains exactly one mesh-grid entry,
433-
# take that single value by insertion order — this is always correct in
434-
# the one-pixelised-source case (Delaunay/Hilbert image-mesh fits).
435-
# fit-imaging-pytree-delaunay
436-
try:
437-
dict_ = self.adapt_images.galaxy_image_plane_mesh_grid_dict
438-
vals = list(dict_.values()) if dict_ else []
439-
if len(vals) == 1:
440-
image_plane_mesh_grid = vals[0]
441-
except (AttributeError, TypeError, KeyError):
442-
pass
429+
else:
430+
image_plane_mesh_grid = self.adapt_images.image_plane_mesh_grid_for_galaxy(
431+
galaxy, self.path_galaxies
432+
)
443433

444434
image_plane_mesh_grid_list.append(image_plane_mesh_grid)
445435

@@ -551,10 +541,12 @@ def mapper_galaxy_dict(self) -> Dict[aa.Mapper, Galaxy]:
551541
for mapper_index in range(len(mesh_grid_list)):
552542
galaxy = galaxies_with_pixelization_list[mapper_index]
553543

554-
try:
555-
adapt_galaxy_image = self.adapt_images.galaxy_image_dict[galaxy]
556-
except (AttributeError, KeyError):
544+
if self.adapt_images is None:
557545
adapt_galaxy_image = None
546+
else:
547+
adapt_galaxy_image = self.adapt_images.image_for_galaxy(
548+
galaxy, self.path_galaxies
549+
)
558550

559551
mapper = self.mapper_from(
560552
mesh=pixelization_list[mapper_index].mesh,

autogalaxy/imaging/model/analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:
152152

153153
dataset_model = self.dataset_model_via_instance_from(instance=instance)
154154

155-
adapt_images = self.adapt_images_via_instance_from(instance=instance)
155+
adapt_images = self.adapt_images_via_instance_from(
156+
instance=instance, galaxies=galaxies
157+
)
156158

157159
return FitImaging(
158160
dataset=self.dataset,

autogalaxy/interferometer/model/analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
146146
instance=instance,
147147
)
148148

149-
adapt_images = self.adapt_images_via_instance_from(instance=instance)
149+
adapt_images = self.adapt_images_via_instance_from(
150+
instance=instance, galaxies=galaxies
151+
)
150152

151153
return FitInterferometer(
152154
dataset=self.dataset,

test_autogalaxy/analysis/test_adapt_images.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
import numpy as np
33

4+
import autofit as af
45
import autogalaxy as ag
56

67

@@ -20,3 +21,88 @@ def test__instance_with_associated_adapt_images_from(masked_imaging_7x7):
2021
assert adapt_images.model_image.native == pytest.approx(
2122
3.0 * np.ones((3, 3)), 1.0e-4
2223
)
24+
25+
26+
def test__image_for_galaxy__resolves_after_galaxy_identity_changes():
27+
"""
28+
Simulates the post-``jax.jit`` unflatten boundary: ``adapt_images.galaxy_image_dict`` is keyed by the
29+
trace-time ``Galaxy`` instances, but the lookup at ``GalaxiesToInversion.mapper_galaxy_dict`` is performed
30+
against fresh ``Galaxy`` objects whose ``__hash__`` differs. The path-tuple lookup via
31+
``galaxy_name_image_dict`` must still resolve to the right adapt image.
32+
"""
33+
galaxies = af.ModelInstance()
34+
galaxies.lens = ag.Galaxy(redshift=0.5)
35+
galaxies.source = ag.Galaxy(redshift=1.0)
36+
37+
instance = af.ModelInstance()
38+
instance.galaxies = galaxies
39+
40+
galaxy_name_image_dict = {
41+
str(("galaxies", "lens")): ag.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0),
42+
str(("galaxies", "source")): ag.Array2D.full(
43+
fill_value=2.0, shape_native=(3, 3), pixel_scales=1.0
44+
),
45+
}
46+
47+
trace_galaxies = [galaxies.lens, galaxies.source]
48+
49+
adapt_images = ag.AdaptImages(
50+
galaxy_name_image_dict=galaxy_name_image_dict,
51+
).updated_via_instance_from(instance=instance, galaxies=trace_galaxies)
52+
53+
assert adapt_images.galaxy_path_list == [
54+
str(("galaxies", "lens")),
55+
str(("galaxies", "source")),
56+
]
57+
58+
# Fast path: by-instance lookup still works for the trace-time galaxies.
59+
assert adapt_images.image_for_galaxy(
60+
trace_galaxies[0], trace_galaxies
61+
).native == pytest.approx(np.ones((3, 3)), 1.0e-4)
62+
63+
# Simulate post-unflatten: fresh ``Galaxy`` objects with new ``.id`` values
64+
# placed at the same positions as the trace-time list. ``galaxy_image_dict``
65+
# cannot resolve them (hash mismatch) so the helper must fall back to
66+
# ``galaxy_name_image_dict`` via ``galaxy_path_list``.
67+
fresh_galaxies = [ag.Galaxy(redshift=0.5), ag.Galaxy(redshift=1.0)]
68+
69+
assert adapt_images.galaxy_image_dict.get(fresh_galaxies[0]) is None
70+
assert adapt_images.image_for_galaxy(
71+
fresh_galaxies[0], fresh_galaxies
72+
).native == pytest.approx(np.ones((3, 3)), 1.0e-4)
73+
assert adapt_images.image_for_galaxy(
74+
fresh_galaxies[1], fresh_galaxies
75+
).native == pytest.approx(2.0 * np.ones((3, 3)), 1.0e-4)
76+
77+
78+
def test__image_plane_mesh_grid_for_galaxy__resolves_after_galaxy_identity_changes():
79+
"""
80+
Companion to :func:`test__image_for_galaxy__resolves_after_galaxy_identity_changes` for the mesh-grid
81+
lookup path used by ``GalaxiesToInversion.image_plane_mesh_grid_list``.
82+
"""
83+
galaxies = af.ModelInstance()
84+
galaxies.lens = ag.Galaxy(redshift=0.5)
85+
galaxies.source = ag.Galaxy(redshift=1.0)
86+
87+
instance = af.ModelInstance()
88+
instance.galaxies = galaxies
89+
90+
galaxy_name_image_plane_mesh_grid_dict = {
91+
str(("galaxies", "lens")): ag.Grid2DIrregular(values=[(3.0, 3.0), (3.0, 3.0)]),
92+
str(("galaxies", "source")): ag.Grid2DIrregular(values=[(4.0, 4.0), (4.0, 4.0)]),
93+
}
94+
95+
trace_galaxies = [galaxies.lens, galaxies.source]
96+
97+
adapt_images = ag.AdaptImages(
98+
galaxy_name_image_plane_mesh_grid_dict=galaxy_name_image_plane_mesh_grid_dict,
99+
).updated_via_instance_from(instance=instance, galaxies=trace_galaxies)
100+
101+
fresh_galaxies = [ag.Galaxy(redshift=0.5), ag.Galaxy(redshift=1.0)]
102+
103+
assert adapt_images.image_plane_mesh_grid_for_galaxy(
104+
fresh_galaxies[0], fresh_galaxies
105+
) == pytest.approx(3.0 * np.ones((2, 2)), 1.0e-4)
106+
assert adapt_images.image_plane_mesh_grid_for_galaxy(
107+
fresh_galaxies[1], fresh_galaxies
108+
) == pytest.approx(4.0 * np.ones((2, 2)), 1.0e-4)

0 commit comments

Comments
 (0)