diff --git a/autogalaxy/analysis/adapt_images/adapt_images.py b/autogalaxy/analysis/adapt_images/adapt_images.py index 19734d8bf..b6362eadf 100644 --- a/autogalaxy/analysis/adapt_images/adapt_images.py +++ b/autogalaxy/analysis/adapt_images/adapt_images.py @@ -1,6 +1,6 @@ from __future__ import annotations import numpy as np -from typing import TYPE_CHECKING, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from autoconf import conf from autoconf import cached_property @@ -73,6 +73,7 @@ def __init__( galaxy_name_image_plane_mesh_grid_dict: Optional[ Dict[Tuple[str, ...], aa.Grid2DIrregular] ] = None, + galaxy_path_list: Optional[List[str]] = None, ): """ Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the @@ -116,6 +117,14 @@ def __init__( galaxy_name_image_plane_mesh_grid_dict ) + # Parallel to the analysis-time galaxies list (as built by + # ``Analysis.galaxies_via_instance_from``). Populated by + # ``updated_via_instance_from`` and used by ``image_for_galaxy`` to + # recover the galaxy's path-tuple key after a JAX unflatten has produced + # fresh ``Galaxy`` objects whose hashes no longer match + # ``galaxy_image_dict`` keys. + self.galaxy_path_list = galaxy_path_list + @property def mask(self) -> aa.Mask2D: """ @@ -147,7 +156,9 @@ def model_image(self) -> aa.Array2D: return adapt_model_image - def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages": + def updated_via_instance_from( + self, instance, mask=None, galaxies: Optional[List["Galaxy"]] = None + ) -> "AdaptImages": """ Returns adapt-images which have been updated to map galaxy instances instead of galaxy names. @@ -170,6 +181,13 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages": mask A mask which can be applied to the adapt images, which is used when setting up the adaptive images via the aggregator and autofit database tools. + galaxies + Optional list of galaxies in the order used by the calling ``Analysis`` (i.e. the list passed to + ``FitImaging`` / ``Tracer``). When provided, a parallel ``galaxy_path_list`` is populated so that + ``image_for_galaxy`` can recover the path-tuple key for each galaxy after JAX has unflattened the + galaxy instances into fresh objects. When ``None`` the path list is populated in ``path_instance_tuples_for_class`` + order, which matches ``Analysis.galaxies_via_instance_from`` for the common case (no + ``extra_galaxies`` / ``scaling_galaxies``). Returns ------- @@ -177,6 +195,11 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages": """ from autogalaxy.galaxy.galaxy import Galaxy + path_by_id = { + id(galaxy): str(galaxy_name) + for galaxy_name, galaxy in instance.path_instance_tuples_for_class(Galaxy) + } + galaxy_image_dict = None if self.galaxy_name_image_dict is not None: @@ -207,7 +230,73 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages": self.galaxy_name_image_plane_mesh_grid_dict[galaxy_name] ) + if galaxies is not None: + galaxy_path_list = [path_by_id.get(id(g)) for g in galaxies] + else: + galaxy_path_list = [ + str(galaxy_name) + for galaxy_name, _ in instance.path_instance_tuples_for_class(Galaxy) + ] + return AdaptImages( galaxy_image_dict=galaxy_image_dict, galaxy_image_plane_mesh_grid_dict=galaxy_image_plane_mesh_grid_dict, + galaxy_name_image_dict=self.galaxy_name_image_dict, + galaxy_name_image_plane_mesh_grid_dict=self.galaxy_name_image_plane_mesh_grid_dict, + galaxy_path_list=galaxy_path_list, ) + + def image_for_galaxy( + self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] = None + ) -> Optional[aa.Array2D]: + """ + Return the adapt image for ``galaxy``, robust to JAX ``jit`` boundaries. + + ``galaxy_image_dict`` is keyed by the trace-time ``Galaxy`` instances. After ``jax.jit`` has flattened + and unflattened a ``FitImaging``, the galaxies inside it are fresh Python objects whose ``__hash__`` + differs from the trace-time keys, so a direct lookup misses. This helper falls back to the path-tuple + keyed ``galaxy_name_image_dict`` using ``galaxy_path_list`` to map the post-unflatten galaxy back to + its trace-time path. + + Returns ``None`` when no adapt image is associated with the galaxy. + """ + try: + return self.galaxy_image_dict[galaxy] + except (AttributeError, KeyError, TypeError): + pass + + path = self._path_for_galaxy(galaxy, galaxies) + if path is None or self.galaxy_name_image_dict is None: + return None + return self.galaxy_name_image_dict.get(path) + + def image_plane_mesh_grid_for_galaxy( + self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] = None + ) -> Optional[aa.Grid2DIrregular]: + """ + Return the image-plane mesh grid for ``galaxy``, robust to JAX ``jit`` boundaries. + + Companion to :meth:`image_for_galaxy` for ``galaxy_image_plane_mesh_grid_dict`` / + ``galaxy_name_image_plane_mesh_grid_dict``. + """ + try: + return self.galaxy_image_plane_mesh_grid_dict[galaxy] + except (AttributeError, KeyError, TypeError): + pass + + path = self._path_for_galaxy(galaxy, galaxies) + if path is None or self.galaxy_name_image_plane_mesh_grid_dict is None: + return None + return self.galaxy_name_image_plane_mesh_grid_dict.get(path) + + def _path_for_galaxy( + self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] + ) -> Optional[str]: + if not self.galaxy_path_list or galaxies is None: + return None + for index, candidate in enumerate(galaxies): + if candidate is galaxy: + if index < len(self.galaxy_path_list): + return self.galaxy_path_list[index] + return None + return None diff --git a/autogalaxy/analysis/analysis/dataset.py b/autogalaxy/analysis/analysis/dataset.py index 5d000cd3d..d4a932646 100644 --- a/autogalaxy/analysis/analysis/dataset.py +++ b/autogalaxy/analysis/analysis/dataset.py @@ -169,8 +169,14 @@ def save_results(self, paths: af.DirectoryPaths, result: ResultDataset): except AttributeError: pass - def adapt_images_via_instance_from(self, instance: af.ModelInstance) -> AdaptImages: + def adapt_images_via_instance_from( + self, + instance: af.ModelInstance, + galaxies=None, + ) -> AdaptImages: try: - return self.adapt_images.updated_via_instance_from(instance=instance) + return self.adapt_images.updated_via_instance_from( + instance=instance, galaxies=galaxies + ) except AttributeError: pass diff --git a/autogalaxy/galaxy/to_inversion.py b/autogalaxy/galaxy/to_inversion.py index 6055aca31..190fcd951 100644 --- a/autogalaxy/galaxy/to_inversion.py +++ b/autogalaxy/galaxy/to_inversion.py @@ -213,6 +213,7 @@ def __init__( adapt_images: Optional[AdaptImages] = None, settings: aa.Settings = None, xp=np, + path_galaxies: Optional[List[Galaxy]] = None, ): """ Interfaces a dataset and input list of galaxies with the inversion module. to setup a @@ -247,8 +248,13 @@ def __init__( the pixelization's pixels to the brightest regions of the image. settings The settings of the inversion, which controls how the linear algebra calculation is performed. + path_galaxies + The full ordered list of galaxies that ``adapt_images.galaxy_path_list`` was aligned with (typically + ``tracer.galaxies`` in autolens, where ``galaxies`` is a per-plane subset). When ``None`` defaults to + ``galaxies`` — correct for the single-plane autogalaxy case. """ self.galaxies = Galaxies(galaxies) + self.path_galaxies = path_galaxies if path_galaxies is not None else self.galaxies super().__init__( dataset=dataset, @@ -418,28 +424,12 @@ def image_plane_mesh_grid_list( image_plane_mesh_grid_list = [] for galaxy in self.galaxies.galaxies_with_cls_list_from(cls=aa.Pixelization): - try: - image_plane_mesh_grid = ( - self.adapt_images.galaxy_image_plane_mesh_grid_dict[galaxy] - ) - except (AttributeError, KeyError, TypeError): + if self.adapt_images is None: image_plane_mesh_grid = None - - if image_plane_mesh_grid is None: - # Fallback for JAX JIT: after jax.jit unflatten the galaxy instances - # stored as keys in ``adapt_images`` are stale (different Python objects - # from those in the current tracer), so the dict key lookup above fails - # and yields None. When the dict contains exactly one mesh-grid entry, - # take that single value by insertion order — this is always correct in - # the one-pixelised-source case (Delaunay/Hilbert image-mesh fits). - # fit-imaging-pytree-delaunay - try: - dict_ = self.adapt_images.galaxy_image_plane_mesh_grid_dict - vals = list(dict_.values()) if dict_ else [] - if len(vals) == 1: - image_plane_mesh_grid = vals[0] - except (AttributeError, TypeError, KeyError): - pass + else: + image_plane_mesh_grid = self.adapt_images.image_plane_mesh_grid_for_galaxy( + galaxy, self.path_galaxies + ) image_plane_mesh_grid_list.append(image_plane_mesh_grid) @@ -551,10 +541,12 @@ def mapper_galaxy_dict(self) -> Dict[aa.Mapper, Galaxy]: for mapper_index in range(len(mesh_grid_list)): galaxy = galaxies_with_pixelization_list[mapper_index] - try: - adapt_galaxy_image = self.adapt_images.galaxy_image_dict[galaxy] - except (AttributeError, KeyError): + if self.adapt_images is None: adapt_galaxy_image = None + else: + adapt_galaxy_image = self.adapt_images.image_for_galaxy( + galaxy, self.path_galaxies + ) mapper = self.mapper_from( mesh=pixelization_list[mapper_index].mesh, diff --git a/autogalaxy/imaging/model/analysis.py b/autogalaxy/imaging/model/analysis.py index be584b2bf..d6610a32f 100644 --- a/autogalaxy/imaging/model/analysis.py +++ b/autogalaxy/imaging/model/analysis.py @@ -152,7 +152,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging: dataset_model = self.dataset_model_via_instance_from(instance=instance) - adapt_images = self.adapt_images_via_instance_from(instance=instance) + adapt_images = self.adapt_images_via_instance_from( + instance=instance, galaxies=galaxies + ) return FitImaging( dataset=self.dataset, diff --git a/autogalaxy/interferometer/model/analysis.py b/autogalaxy/interferometer/model/analysis.py index 15db5f6d1..d45e6da6a 100644 --- a/autogalaxy/interferometer/model/analysis.py +++ b/autogalaxy/interferometer/model/analysis.py @@ -146,7 +146,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer: instance=instance, ) - adapt_images = self.adapt_images_via_instance_from(instance=instance) + adapt_images = self.adapt_images_via_instance_from( + instance=instance, galaxies=galaxies + ) return FitInterferometer( dataset=self.dataset, diff --git a/test_autogalaxy/analysis/test_adapt_images.py b/test_autogalaxy/analysis/test_adapt_images.py index 12add0ea5..ab9b0aad6 100644 --- a/test_autogalaxy/analysis/test_adapt_images.py +++ b/test_autogalaxy/analysis/test_adapt_images.py @@ -1,6 +1,7 @@ import pytest import numpy as np +import autofit as af import autogalaxy as ag @@ -20,3 +21,88 @@ def test__instance_with_associated_adapt_images_from(masked_imaging_7x7): assert adapt_images.model_image.native == pytest.approx( 3.0 * np.ones((3, 3)), 1.0e-4 ) + + +def test__image_for_galaxy__resolves_after_galaxy_identity_changes(): + """ + Simulates the post-``jax.jit`` unflatten boundary: ``adapt_images.galaxy_image_dict`` is keyed by the + trace-time ``Galaxy`` instances, but the lookup at ``GalaxiesToInversion.mapper_galaxy_dict`` is performed + against fresh ``Galaxy`` objects whose ``__hash__`` differs. The path-tuple lookup via + ``galaxy_name_image_dict`` must still resolve to the right adapt image. + """ + galaxies = af.ModelInstance() + galaxies.lens = ag.Galaxy(redshift=0.5) + galaxies.source = ag.Galaxy(redshift=1.0) + + instance = af.ModelInstance() + instance.galaxies = galaxies + + galaxy_name_image_dict = { + str(("galaxies", "lens")): ag.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0), + str(("galaxies", "source")): ag.Array2D.full( + fill_value=2.0, shape_native=(3, 3), pixel_scales=1.0 + ), + } + + trace_galaxies = [galaxies.lens, galaxies.source] + + adapt_images = ag.AdaptImages( + galaxy_name_image_dict=galaxy_name_image_dict, + ).updated_via_instance_from(instance=instance, galaxies=trace_galaxies) + + assert adapt_images.galaxy_path_list == [ + str(("galaxies", "lens")), + str(("galaxies", "source")), + ] + + # Fast path: by-instance lookup still works for the trace-time galaxies. + assert adapt_images.image_for_galaxy( + trace_galaxies[0], trace_galaxies + ).native == pytest.approx(np.ones((3, 3)), 1.0e-4) + + # Simulate post-unflatten: fresh ``Galaxy`` objects with new ``.id`` values + # placed at the same positions as the trace-time list. ``galaxy_image_dict`` + # cannot resolve them (hash mismatch) so the helper must fall back to + # ``galaxy_name_image_dict`` via ``galaxy_path_list``. + fresh_galaxies = [ag.Galaxy(redshift=0.5), ag.Galaxy(redshift=1.0)] + + assert adapt_images.galaxy_image_dict.get(fresh_galaxies[0]) is None + assert adapt_images.image_for_galaxy( + fresh_galaxies[0], fresh_galaxies + ).native == pytest.approx(np.ones((3, 3)), 1.0e-4) + assert adapt_images.image_for_galaxy( + fresh_galaxies[1], fresh_galaxies + ).native == pytest.approx(2.0 * np.ones((3, 3)), 1.0e-4) + + +def test__image_plane_mesh_grid_for_galaxy__resolves_after_galaxy_identity_changes(): + """ + Companion to :func:`test__image_for_galaxy__resolves_after_galaxy_identity_changes` for the mesh-grid + lookup path used by ``GalaxiesToInversion.image_plane_mesh_grid_list``. + """ + galaxies = af.ModelInstance() + galaxies.lens = ag.Galaxy(redshift=0.5) + galaxies.source = ag.Galaxy(redshift=1.0) + + instance = af.ModelInstance() + instance.galaxies = galaxies + + galaxy_name_image_plane_mesh_grid_dict = { + str(("galaxies", "lens")): ag.Grid2DIrregular(values=[(3.0, 3.0), (3.0, 3.0)]), + str(("galaxies", "source")): ag.Grid2DIrregular(values=[(4.0, 4.0), (4.0, 4.0)]), + } + + trace_galaxies = [galaxies.lens, galaxies.source] + + adapt_images = ag.AdaptImages( + galaxy_name_image_plane_mesh_grid_dict=galaxy_name_image_plane_mesh_grid_dict, + ).updated_via_instance_from(instance=instance, galaxies=trace_galaxies) + + fresh_galaxies = [ag.Galaxy(redshift=0.5), ag.Galaxy(redshift=1.0)] + + assert adapt_images.image_plane_mesh_grid_for_galaxy( + fresh_galaxies[0], fresh_galaxies + ) == pytest.approx(3.0 * np.ones((2, 2)), 1.0e-4) + assert adapt_images.image_plane_mesh_grid_for_galaxy( + fresh_galaxies[1], fresh_galaxies + ) == pytest.approx(4.0 * np.ones((2, 2)), 1.0e-4)