Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions autogalaxy/analysis/adapt_images/adapt_images.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations
import numpy as np
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from autoconf import conf
from autoconf import cached_property
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
galaxy_name_image_plane_mesh_grid_dict: Optional[
Dict[Tuple[str, ...], aa.Grid2DIrregular]
] = None,
galaxy_path_list: Optional[List[str]] = None,
):
"""
Contains the adapt-images which are used to make a pixelization's mesh and regularization adapt to the
Expand Down Expand Up @@ -116,6 +117,14 @@ def __init__(
galaxy_name_image_plane_mesh_grid_dict
)

# Parallel to the analysis-time galaxies list (as built by
# ``Analysis.galaxies_via_instance_from``). Populated by
# ``updated_via_instance_from`` and used by ``image_for_galaxy`` to
# recover the galaxy's path-tuple key after a JAX unflatten has produced
# fresh ``Galaxy`` objects whose hashes no longer match
# ``galaxy_image_dict`` keys.
self.galaxy_path_list = galaxy_path_list

@property
def mask(self) -> aa.Mask2D:
"""
Expand Down Expand Up @@ -147,7 +156,9 @@ def model_image(self) -> aa.Array2D:

return adapt_model_image

def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
def updated_via_instance_from(
self, instance, mask=None, galaxies: Optional[List["Galaxy"]] = None
) -> "AdaptImages":
"""
Returns adapt-images which have been updated to map galaxy instances instead of galaxy names.

Expand All @@ -170,13 +181,25 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
mask
A mask which can be applied to the adapt images, which is used when setting up the adaptive images
via the aggregator and autofit database tools.
galaxies
Optional list of galaxies in the order used by the calling ``Analysis`` (i.e. the list passed to
``FitImaging`` / ``Tracer``). When provided, a parallel ``galaxy_path_list`` is populated so that
``image_for_galaxy`` can recover the path-tuple key for each galaxy after JAX has unflattened the
galaxy instances into fresh objects. When ``None`` the path list is populated in ``path_instance_tuples_for_class``
order, which matches ``Analysis.galaxies_via_instance_from`` for the common case (no
``extra_galaxies`` / ``scaling_galaxies``).

Returns
-------

"""
from autogalaxy.galaxy.galaxy import Galaxy

path_by_id = {
id(galaxy): str(galaxy_name)
for galaxy_name, galaxy in instance.path_instance_tuples_for_class(Galaxy)
}

galaxy_image_dict = None

if self.galaxy_name_image_dict is not None:
Expand Down Expand Up @@ -207,7 +230,73 @@ def updated_via_instance_from(self, instance, mask=None) -> "AdaptImages":
self.galaxy_name_image_plane_mesh_grid_dict[galaxy_name]
)

if galaxies is not None:
galaxy_path_list = [path_by_id.get(id(g)) for g in galaxies]
else:
galaxy_path_list = [
str(galaxy_name)
for galaxy_name, _ in instance.path_instance_tuples_for_class(Galaxy)
]

return AdaptImages(
galaxy_image_dict=galaxy_image_dict,
galaxy_image_plane_mesh_grid_dict=galaxy_image_plane_mesh_grid_dict,
galaxy_name_image_dict=self.galaxy_name_image_dict,
galaxy_name_image_plane_mesh_grid_dict=self.galaxy_name_image_plane_mesh_grid_dict,
galaxy_path_list=galaxy_path_list,
)

def image_for_galaxy(
self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] = None
) -> Optional[aa.Array2D]:
"""
Return the adapt image for ``galaxy``, robust to JAX ``jit`` boundaries.

``galaxy_image_dict`` is keyed by the trace-time ``Galaxy`` instances. After ``jax.jit`` has flattened
and unflattened a ``FitImaging``, the galaxies inside it are fresh Python objects whose ``__hash__``
differs from the trace-time keys, so a direct lookup misses. This helper falls back to the path-tuple
keyed ``galaxy_name_image_dict`` using ``galaxy_path_list`` to map the post-unflatten galaxy back to
its trace-time path.

Returns ``None`` when no adapt image is associated with the galaxy.
"""
try:
return self.galaxy_image_dict[galaxy]
except (AttributeError, KeyError, TypeError):
pass

path = self._path_for_galaxy(galaxy, galaxies)
if path is None or self.galaxy_name_image_dict is None:
return None
return self.galaxy_name_image_dict.get(path)

def image_plane_mesh_grid_for_galaxy(
self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]] = None
) -> Optional[aa.Grid2DIrregular]:
"""
Return the image-plane mesh grid for ``galaxy``, robust to JAX ``jit`` boundaries.

Companion to :meth:`image_for_galaxy` for ``galaxy_image_plane_mesh_grid_dict`` /
``galaxy_name_image_plane_mesh_grid_dict``.
"""
try:
return self.galaxy_image_plane_mesh_grid_dict[galaxy]
except (AttributeError, KeyError, TypeError):
pass

path = self._path_for_galaxy(galaxy, galaxies)
if path is None or self.galaxy_name_image_plane_mesh_grid_dict is None:
return None
return self.galaxy_name_image_plane_mesh_grid_dict.get(path)

def _path_for_galaxy(
self, galaxy: "Galaxy", galaxies: Optional[List["Galaxy"]]
) -> Optional[str]:
if not self.galaxy_path_list or galaxies is None:
return None
for index, candidate in enumerate(galaxies):
if candidate is galaxy:
if index < len(self.galaxy_path_list):
return self.galaxy_path_list[index]
return None
return None
10 changes: 8 additions & 2 deletions autogalaxy/analysis/analysis/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,14 @@ def save_results(self, paths: af.DirectoryPaths, result: ResultDataset):
except AttributeError:
pass

def adapt_images_via_instance_from(self, instance: af.ModelInstance) -> AdaptImages:
def adapt_images_via_instance_from(
self,
instance: af.ModelInstance,
galaxies=None,
) -> AdaptImages:
try:
return self.adapt_images.updated_via_instance_from(instance=instance)
return self.adapt_images.updated_via_instance_from(
instance=instance, galaxies=galaxies
)
except AttributeError:
pass
40 changes: 16 additions & 24 deletions autogalaxy/galaxy/to_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(
adapt_images: Optional[AdaptImages] = None,
settings: aa.Settings = None,
xp=np,
path_galaxies: Optional[List[Galaxy]] = None,
):
"""
Interfaces a dataset and input list of galaxies with the inversion module. to setup a
Expand Down Expand Up @@ -247,8 +248,13 @@ def __init__(
the pixelization's pixels to the brightest regions of the image.
settings
The settings of the inversion, which controls how the linear algebra calculation is performed.
path_galaxies
The full ordered list of galaxies that ``adapt_images.galaxy_path_list`` was aligned with (typically
``tracer.galaxies`` in autolens, where ``galaxies`` is a per-plane subset). When ``None`` defaults to
``galaxies`` — correct for the single-plane autogalaxy case.
"""
self.galaxies = Galaxies(galaxies)
self.path_galaxies = path_galaxies if path_galaxies is not None else self.galaxies

super().__init__(
dataset=dataset,
Expand Down Expand Up @@ -418,28 +424,12 @@ def image_plane_mesh_grid_list(
image_plane_mesh_grid_list = []

for galaxy in self.galaxies.galaxies_with_cls_list_from(cls=aa.Pixelization):
try:
image_plane_mesh_grid = (
self.adapt_images.galaxy_image_plane_mesh_grid_dict[galaxy]
)
except (AttributeError, KeyError, TypeError):
if self.adapt_images is None:
image_plane_mesh_grid = None

if image_plane_mesh_grid is None:
# Fallback for JAX JIT: after jax.jit unflatten the galaxy instances
# stored as keys in ``adapt_images`` are stale (different Python objects
# from those in the current tracer), so the dict key lookup above fails
# and yields None. When the dict contains exactly one mesh-grid entry,
# take that single value by insertion order — this is always correct in
# the one-pixelised-source case (Delaunay/Hilbert image-mesh fits).
# fit-imaging-pytree-delaunay
try:
dict_ = self.adapt_images.galaxy_image_plane_mesh_grid_dict
vals = list(dict_.values()) if dict_ else []
if len(vals) == 1:
image_plane_mesh_grid = vals[0]
except (AttributeError, TypeError, KeyError):
pass
else:
image_plane_mesh_grid = self.adapt_images.image_plane_mesh_grid_for_galaxy(
galaxy, self.path_galaxies
)

image_plane_mesh_grid_list.append(image_plane_mesh_grid)

Expand Down Expand Up @@ -551,10 +541,12 @@ def mapper_galaxy_dict(self) -> Dict[aa.Mapper, Galaxy]:
for mapper_index in range(len(mesh_grid_list)):
galaxy = galaxies_with_pixelization_list[mapper_index]

try:
adapt_galaxy_image = self.adapt_images.galaxy_image_dict[galaxy]
except (AttributeError, KeyError):
if self.adapt_images is None:
adapt_galaxy_image = None
else:
adapt_galaxy_image = self.adapt_images.image_for_galaxy(
galaxy, self.path_galaxies
)

mapper = self.mapper_from(
mesh=pixelization_list[mapper_index].mesh,
Expand Down
4 changes: 3 additions & 1 deletion autogalaxy/imaging/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitImaging:

dataset_model = self.dataset_model_via_instance_from(instance=instance)

adapt_images = self.adapt_images_via_instance_from(instance=instance)
adapt_images = self.adapt_images_via_instance_from(
instance=instance, galaxies=galaxies
)

return FitImaging(
dataset=self.dataset,
Expand Down
4 changes: 3 additions & 1 deletion autogalaxy/interferometer/model/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def fit_from(self, instance: af.ModelInstance) -> FitInterferometer:
instance=instance,
)

adapt_images = self.adapt_images_via_instance_from(instance=instance)
adapt_images = self.adapt_images_via_instance_from(
instance=instance, galaxies=galaxies
)

return FitInterferometer(
dataset=self.dataset,
Expand Down
86 changes: 86 additions & 0 deletions test_autogalaxy/analysis/test_adapt_images.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import numpy as np

import autofit as af
import autogalaxy as ag


Expand All @@ -20,3 +21,88 @@ def test__instance_with_associated_adapt_images_from(masked_imaging_7x7):
assert adapt_images.model_image.native == pytest.approx(
3.0 * np.ones((3, 3)), 1.0e-4
)


def test__image_for_galaxy__resolves_after_galaxy_identity_changes():
"""
Simulates the post-``jax.jit`` unflatten boundary: ``adapt_images.galaxy_image_dict`` is keyed by the
trace-time ``Galaxy`` instances, but the lookup at ``GalaxiesToInversion.mapper_galaxy_dict`` is performed
against fresh ``Galaxy`` objects whose ``__hash__`` differs. The path-tuple lookup via
``galaxy_name_image_dict`` must still resolve to the right adapt image.
"""
galaxies = af.ModelInstance()
galaxies.lens = ag.Galaxy(redshift=0.5)
galaxies.source = ag.Galaxy(redshift=1.0)

instance = af.ModelInstance()
instance.galaxies = galaxies

galaxy_name_image_dict = {
str(("galaxies", "lens")): ag.Array2D.ones(shape_native=(3, 3), pixel_scales=1.0),
str(("galaxies", "source")): ag.Array2D.full(
fill_value=2.0, shape_native=(3, 3), pixel_scales=1.0
),
}

trace_galaxies = [galaxies.lens, galaxies.source]

adapt_images = ag.AdaptImages(
galaxy_name_image_dict=galaxy_name_image_dict,
).updated_via_instance_from(instance=instance, galaxies=trace_galaxies)

assert adapt_images.galaxy_path_list == [
str(("galaxies", "lens")),
str(("galaxies", "source")),
]

# Fast path: by-instance lookup still works for the trace-time galaxies.
assert adapt_images.image_for_galaxy(
trace_galaxies[0], trace_galaxies
).native == pytest.approx(np.ones((3, 3)), 1.0e-4)

# Simulate post-unflatten: fresh ``Galaxy`` objects with new ``.id`` values
# placed at the same positions as the trace-time list. ``galaxy_image_dict``
# cannot resolve them (hash mismatch) so the helper must fall back to
# ``galaxy_name_image_dict`` via ``galaxy_path_list``.
fresh_galaxies = [ag.Galaxy(redshift=0.5), ag.Galaxy(redshift=1.0)]

assert adapt_images.galaxy_image_dict.get(fresh_galaxies[0]) is None
assert adapt_images.image_for_galaxy(
fresh_galaxies[0], fresh_galaxies
).native == pytest.approx(np.ones((3, 3)), 1.0e-4)
assert adapt_images.image_for_galaxy(
fresh_galaxies[1], fresh_galaxies
).native == pytest.approx(2.0 * np.ones((3, 3)), 1.0e-4)


def test__image_plane_mesh_grid_for_galaxy__resolves_after_galaxy_identity_changes():
"""
Companion to :func:`test__image_for_galaxy__resolves_after_galaxy_identity_changes` for the mesh-grid
lookup path used by ``GalaxiesToInversion.image_plane_mesh_grid_list``.
"""
galaxies = af.ModelInstance()
galaxies.lens = ag.Galaxy(redshift=0.5)
galaxies.source = ag.Galaxy(redshift=1.0)

instance = af.ModelInstance()
instance.galaxies = galaxies

galaxy_name_image_plane_mesh_grid_dict = {
str(("galaxies", "lens")): ag.Grid2DIrregular(values=[(3.0, 3.0), (3.0, 3.0)]),
str(("galaxies", "source")): ag.Grid2DIrregular(values=[(4.0, 4.0), (4.0, 4.0)]),
}

trace_galaxies = [galaxies.lens, galaxies.source]

adapt_images = ag.AdaptImages(
galaxy_name_image_plane_mesh_grid_dict=galaxy_name_image_plane_mesh_grid_dict,
).updated_via_instance_from(instance=instance, galaxies=trace_galaxies)

fresh_galaxies = [ag.Galaxy(redshift=0.5), ag.Galaxy(redshift=1.0)]

assert adapt_images.image_plane_mesh_grid_for_galaxy(
fresh_galaxies[0], fresh_galaxies
) == pytest.approx(3.0 * np.ones((2, 2)), 1.0e-4)
assert adapt_images.image_plane_mesh_grid_for_galaxy(
fresh_galaxies[1], fresh_galaxies
) == pytest.approx(4.0 * np.ones((2, 2)), 1.0e-4)
Loading