11from __future__ import annotations
22import numpy as np
3- from typing import TYPE_CHECKING , Dict , Optional , Tuple
3+ from typing import TYPE_CHECKING , Dict , List , Optional , Tuple
44
55from autoconf import conf
66from autoconf import cached_property
@@ -73,6 +73,7 @@ def __init__(
7373 galaxy_name_image_plane_mesh_grid_dict : Optional [
7474 Dict [Tuple [str , ...], aa .Grid2DIrregular ]
7575 ] = None ,
76+ galaxy_path_list : Optional [List [str ]] = None ,
7677 ):
7778 """
7879 Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
@@ -116,6 +117,14 @@ def __init__(
116117 galaxy_name_image_plane_mesh_grid_dict
117118 )
118119
120+ # Parallel to the analysis-time galaxies list (as built by
121+ # ``Analysis.galaxies_via_instance_from``). Populated by
122+ # ``updated_via_instance_from`` and used by ``image_for_galaxy`` to
123+ # recover the galaxy's path-tuple key after a JAX unflatten has produced
124+ # fresh ``Galaxy`` objects whose hashes no longer match
125+ # ``galaxy_image_dict`` keys.
126+ self .galaxy_path_list = galaxy_path_list
127+
119128 @property
120129 def mask (self ) -> aa .Mask2D :
121130 """
@@ -147,7 +156,9 @@ def model_image(self) -> aa.Array2D:
147156
148157 return adapt_model_image
149158
150- def updated_via_instance_from (self , instance , mask = None ) -> "AdaptImages" :
159+ def updated_via_instance_from (
160+ self , instance , mask = None , galaxies : Optional [List ["Galaxy" ]] = None
161+ ) -> "AdaptImages" :
151162 """
152163 Returns adapt-images which have been updated to map galaxy instances instead of galaxy names.
153164
@@ -170,13 +181,25 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
170181 mask
171182 A mask which can be applied to the adapt images, which is used when setting up the adaptive images
172183 via the aggregator and autofit database tools.
184+ galaxies
185+ Optional list of galaxies in the order used by the calling ``Analysis`` (i.e. the list passed to
186+ ``FitImaging`` / ``Tracer``). When provided, a parallel ``galaxy_path_list`` is populated so that
187+ ``image_for_galaxy`` can recover the path-tuple key for each galaxy after JAX has unflattened the
188+ galaxy instances into fresh objects. When ``None`` the path list is populated in ``path_instance_tuples_for_class``
189+ order, which matches ``Analysis.galaxies_via_instance_from`` for the common case (no
190+ ``extra_galaxies`` / ``scaling_galaxies``).
173191
174192 Returns
175193 -------
176194
177195 """
178196 from autogalaxy .galaxy .galaxy import Galaxy
179197
198+ path_by_id = {
199+ id (galaxy ): str (galaxy_name )
200+ for galaxy_name , galaxy in instance .path_instance_tuples_for_class (Galaxy )
201+ }
202+
180203 galaxy_image_dict = None
181204
182205 if self .galaxy_name_image_dict is not None :
@@ -207,7 +230,73 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
207230 self .galaxy_name_image_plane_mesh_grid_dict [galaxy_name ]
208231 )
209232
233+ if galaxies is not None :
234+ galaxy_path_list = [path_by_id .get (id (g )) for g in galaxies ]
235+ else :
236+ galaxy_path_list = [
237+ str (galaxy_name )
238+ for galaxy_name , _ in instance .path_instance_tuples_for_class (Galaxy )
239+ ]
240+
210241 return AdaptImages (
211242 galaxy_image_dict = galaxy_image_dict ,
212243 galaxy_image_plane_mesh_grid_dict = galaxy_image_plane_mesh_grid_dict ,
244+ galaxy_name_image_dict = self .galaxy_name_image_dict ,
245+ galaxy_name_image_plane_mesh_grid_dict = self .galaxy_name_image_plane_mesh_grid_dict ,
246+ galaxy_path_list = galaxy_path_list ,
213247 )
248+
249+ def image_for_galaxy (
250+ self , galaxy : "Galaxy" , galaxies : Optional [List ["Galaxy" ]] = None
251+ ) -> Optional [aa .Array2D ]:
252+ """
253+ Return the adapt image for ``galaxy``, robust to JAX ``jit`` boundaries.
254+
255+ ``galaxy_image_dict`` is keyed by the trace-time ``Galaxy`` instances. After ``jax.jit`` has flattened
256+ and unflattened a ``FitImaging``, the galaxies inside it are fresh Python objects whose ``__hash__``
257+ differs from the trace-time keys, so a direct lookup misses. This helper falls back to the path-tuple
258+ keyed ``galaxy_name_image_dict`` using ``galaxy_path_list`` to map the post-unflatten galaxy back to
259+ its trace-time path.
260+
261+ Returns ``None`` when no adapt image is associated with the galaxy.
262+ """
263+ try :
264+ return self .galaxy_image_dict [galaxy ]
265+ except (AttributeError , KeyError , TypeError ):
266+ pass
267+
268+ path = self ._path_for_galaxy (galaxy , galaxies )
269+ if path is None or self .galaxy_name_image_dict is None :
270+ return None
271+ return self .galaxy_name_image_dict .get (path )
272+
273+ def image_plane_mesh_grid_for_galaxy (
274+ self , galaxy : "Galaxy" , galaxies : Optional [List ["Galaxy" ]] = None
275+ ) -> Optional [aa .Grid2DIrregular ]:
276+ """
277+ Return the image-plane mesh grid for ``galaxy``, robust to JAX ``jit`` boundaries.
278+
279+ Companion to :meth:`image_for_galaxy` for ``galaxy_image_plane_mesh_grid_dict`` /
280+ ``galaxy_name_image_plane_mesh_grid_dict``.
281+ """
282+ try :
283+ return self .galaxy_image_plane_mesh_grid_dict [galaxy ]
284+ except (AttributeError , KeyError , TypeError ):
285+ pass
286+
287+ path = self ._path_for_galaxy (galaxy , galaxies )
288+ if path is None or self .galaxy_name_image_plane_mesh_grid_dict is None :
289+ return None
290+ return self .galaxy_name_image_plane_mesh_grid_dict .get (path )
291+
292+ def _path_for_galaxy (
293+ self , galaxy : "Galaxy" , galaxies : Optional [List ["Galaxy" ]]
294+ ) -> Optional [str ]:
295+ if not self .galaxy_path_list or galaxies is None :
296+ return None
297+ for index , candidate in enumerate (galaxies ):
298+ if candidate is galaxy :
299+ if index < len (self .galaxy_path_list ):
300+ return self .galaxy_path_list [index ]
301+ return None
302+ return None
0 commit comments