diff --git a/autolens/lens/tracer.py b/autolens/lens/tracer.py index 5f5bfffa4..8eedb22b1 100644 --- a/autolens/lens/tracer.py +++ b/autolens/lens/tracer.py @@ -86,11 +86,23 @@ def galaxies_ascending_redshift(self) -> List[ag.Galaxy]: planes of increasing redshift. Thus, the galaxies are sorted by redshift in ascending order to aid this calculation. + When any galaxy has a JAX-traced ``redshift`` (e.g. a free-parameter + subhalo redshift under ``jax.jit``), Python's ``sorted`` cannot order + the list because the key comparator would coerce a traced boolean. + In that case, we trust the input order — the caller (typically + ``af.Collection(galaxies=...)``) is expected to declare galaxies in + ascending-redshift order, with each traced-redshift galaxy placed at + its intended plane position. See ``tracer_util.plane_redshifts_from`` + for the matching rule. + Returns ------- The galaxies in the tracer in ascending redshift order. """ - return sorted(self.galaxies, key=lambda galaxy: galaxy.redshift) + if not tracer_util._any_traced(self.galaxies): + return sorted(self.galaxies, key=lambda galaxy: galaxy.redshift) + + return list(self.galaxies) @property def plane_redshifts(self) -> List[float]: diff --git a/autolens/lens/tracer_util.py b/autolens/lens/tracer_util.py index 46f32066d..541495d6a 100644 --- a/autolens/lens/tracer_util.py +++ b/autolens/lens/tracer_util.py @@ -21,6 +21,33 @@ from autolens import exc +def _redshift_is_traced(redshift) -> bool: + """ + Return True if ``redshift`` is a JAX traced scalar that cannot be coerced to a + Python float without raising under ``jax.jit``. + + Galaxy redshifts are normally Python ``float`` / ``int`` values, but when a + ``af.UniformPrior`` is bound to a ``Galaxy.redshift`` field (e.g. for a free- + parameter subhalo redshift; see PyAutoLens issue #498), the value passed in at + likelihood-evaluation time becomes a traced scalar under ``jax.jit``. Most + sort-and-compare helpers in this module need to fall back to a JAX-aware path + in that case rather than calling ``sorted`` / ``float()`` / ``<=`` on the value. + """ + if isinstance(redshift, (int, float)): + return False + if isinstance(redshift, np.ndarray) and redshift.shape == (): + return False + try: + float(redshift) + except Exception: + return True + return False + + +def _any_traced(galaxies: List[ag.Galaxy]) -> bool: + return any(_redshift_is_traced(g.redshift) for g in galaxies) + + def plane_redshifts_from(galaxies: List[ag.Galaxy]) -> List[float]: """ Returns a list of plane redshifts from a list of galaxies, using the redshifts of the galaxies to determine the @@ -33,6 +60,14 @@ def plane_redshifts_from(galaxies: List[ag.Galaxy]) -> List[float]: For example, if the input is three galaxies, two at redshift 1.0 and one at redshift 2.0, the returned list of redshifts would be [1.0, 2.0]. + When one or more galaxies have a JAX-traced redshift (e.g. a free-parameter + subhalo redshift under ``jax.jit``), the function cannot Python-sort or + ``float()``-coerce the values. It instead walks the input list in order, + deduplicating *concrete* redshifts only and treating each traced redshift as a + unique plane at its input position. The caller must pass galaxies in + ascending-redshift order in this case (which ``af.Collection(galaxies=...)`` + naturally does when the user declares them as ``lens, subhalo, source``). + Parameters ---------- galaxies @@ -43,12 +78,27 @@ def plane_redshifts_from(galaxies: List[ag.Galaxy]) -> List[float]: The list of unique redshifts of the planes. """ - galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift) + if not _any_traced(galaxies): + galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift) + + # Coerce to float to avoid issues with other float types not being hashable. + plane_redshifts = [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift] + + return list(dict.fromkeys(plane_redshifts)) - # Coerce to float to avoid issues with other float types not being hashable. - plane_redshifts = [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift] + plane_redshifts: List = [] + seen_concrete: set = set() + for galaxy in galaxies: + z = galaxy.redshift + if _redshift_is_traced(z): + plane_redshifts.append(z) + else: + zf = float(z) + if zf not in seen_concrete: + seen_concrete.add(zf) + plane_redshifts.append(zf) - return list(dict.fromkeys(plane_redshifts)) + return plane_redshifts def planes_from( @@ -68,6 +118,12 @@ def planes_from( For example, if the input is three galaxies, two at redshift 1.0 and one at redshift 2.0, the returned list of list of galaxies would be [[g1, g2], g3]]. + When any galaxy has a JAX-traced redshift, planes are built by walking the + input galaxies in order and grouping by *concrete* redshift equality only; + each traced-redshift galaxy gets its own dedicated plane in input position. + See ``plane_redshifts_from`` for the matching rule and assumption on input + ordering. + Parameters ---------- galaxies @@ -81,21 +137,38 @@ def planes_from( The list of list of galaxies grouped into their planes. """ - galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift) + if not _any_traced(galaxies): + galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift) + + if plane_redshifts is None: + plane_redshifts = plane_redshifts_from(galaxies=galaxies_ascending_redshift) + + planes = [[] for i in range(len(plane_redshifts))] - if plane_redshifts is None: - plane_redshifts = plane_redshifts_from(galaxies=galaxies_ascending_redshift) + for galaxy in galaxies_ascending_redshift: + index = (np.abs(np.asarray(plane_redshifts) - galaxy.redshift)).argmin() + planes[index].append(galaxy) - planes = [[] for i in range(len(plane_redshifts))] + for index in range(len(planes)): + planes[index] = ag.Galaxies(galaxies=planes[index]) - for galaxy in galaxies_ascending_redshift: - index = (np.abs(np.asarray(plane_redshifts) - galaxy.redshift)).argmin() - planes[index].append(galaxy) + return planes - for index in range(len(planes)): - planes[index] = ag.Galaxies(galaxies=planes[index]) + plane_groups: List = [] # list of (key, [galaxies]) + for galaxy in galaxies: + z = galaxy.redshift + if _redshift_is_traced(z): + plane_groups.append((z, [galaxy])) + else: + zf = float(z) + for i, (key, _) in enumerate(plane_groups): + if not _redshift_is_traced(key) and float(key) == zf: + plane_groups[i][1].append(galaxy) + break + else: + plane_groups.append((zf, [galaxy])) - return planes + return [ag.Galaxies(galaxies=group) for _, group in plane_groups] def traced_grid_2d_list_from( @@ -244,6 +317,40 @@ def grid_2d_at_redshift_from( """ cosmology = cosmology or ag.cosmo.Planck15() + if _redshift_is_traced(redshift) or _any_traced(galaxies): + # JAX path: the requested redshift always matches the redshift of one of + # the input galaxies (this is how AnalysisLens.tracer_via_instance_from + # invokes the function — it passes ``redshift=instance.galaxies.subhalo. + # redshift`` and the subhalo galaxy is in ``galaxies`` too). So we just + # need to identify which plane that galaxy lives in (via Python identity, + # not value comparison) and return the traced grid at that plane. + planes = planes_from(galaxies=galaxies) + + plane_index_match = None + for plane_index, plane_galaxies in enumerate(planes): + for plane_galaxy in plane_galaxies: + if plane_galaxy.redshift is redshift: + plane_index_match = plane_index + break + if plane_index_match is not None: + break + + if plane_index_match is None: + raise exc.RayTracingException( + "grid_2d_at_redshift_from was called under JAX with a traced " + "redshift that does not match any galaxy in the input list by " + "Python identity. The current implementation only supports the " + "case where the requested redshift is the same object as one of " + "the galaxy redshifts (e.g. instance.galaxies.subhalo.redshift). " + "Insertion at an arbitrary traced redshift is not yet supported." + ) + + traced_grid_list = traced_grid_2d_list_from( + planes=planes, grid=grid, cosmology=cosmology, xp=xp + ) + + return traced_grid_list[plane_index_match] + plane_redshifts = plane_redshifts_from(galaxies=galaxies) if redshift <= plane_redshifts[0]: diff --git a/test_autolens/lens/test_tracer_util.py b/test_autolens/lens/test_tracer_util.py index 023f7cf98..285e3784a 100644 --- a/test_autolens/lens/test_tracer_util.py +++ b/test_autolens/lens/test_tracer_util.py @@ -175,6 +175,83 @@ def test__grid_2d_at_redshift_from__redshift_between_planes(grid_2d_7x7): assert (grid_at_redshift == grid_2d_7x7.mask.derive_grid.all_false).all() +class _FakeTracedRedshift: + """A redshift-like object that mimics a JAX traced scalar — calling ``float()`` + raises, so ``tracer_util._redshift_is_traced`` should return True. Used to + exercise the JAX partition-and-splice path without importing ``jax`` (library + unit tests stay numpy-only — see ``feedback_no_jax_in_unit_tests``).""" + + def __init__(self, name: str): + self.name = name + + def __float__(self): + raise TypeError("traced redshift cannot be coerced to float") + + def __repr__(self): + return f"" + + +def test__redshift_is_traced__detects_traced_and_concrete(): + from autolens.lens import tracer_util + + assert tracer_util._redshift_is_traced(_FakeTracedRedshift("subhalo")) is True + + assert tracer_util._redshift_is_traced(0.5) is False + assert tracer_util._redshift_is_traced(1) is False + assert tracer_util._redshift_is_traced(np.float64(1.5)) is False + assert tracer_util._redshift_is_traced(np.array(0.7)) is False + + +def test__plane_redshifts_from__partition_path__preserves_input_order(): + from autolens.lens import tracer_util + + lens = al.Galaxy(redshift=0.5) + subhalo = al.Galaxy(redshift=_FakeTracedRedshift("subhalo")) + source = al.Galaxy(redshift=1.0) + + plane_redshifts = tracer_util.plane_redshifts_from( + galaxies=[lens, subhalo, source] + ) + + assert plane_redshifts[0] == 0.5 + assert isinstance(plane_redshifts[1], _FakeTracedRedshift) + assert plane_redshifts[2] == 1.0 + + +def test__plane_redshifts_from__partition_path__dedupes_concrete_only(): + from autolens.lens import tracer_util + + g0 = al.Galaxy(redshift=0.5) + g1 = al.Galaxy(redshift=0.5) # duplicate concrete redshift — should collapse + subhalo = al.Galaxy(redshift=_FakeTracedRedshift("subhalo")) + source = al.Galaxy(redshift=1.0) + + plane_redshifts = tracer_util.plane_redshifts_from( + galaxies=[g0, g1, subhalo, source] + ) + + assert plane_redshifts[0] == 0.5 + assert isinstance(plane_redshifts[1], _FakeTracedRedshift) + assert plane_redshifts[2] == 1.0 + assert len(plane_redshifts) == 3 + + +def test__planes_from__partition_path__traced_galaxy_gets_dedicated_plane(): + from autolens.lens import tracer_util + + lens_a = al.Galaxy(redshift=0.5) + lens_b = al.Galaxy(redshift=0.5) # same plane as lens_a + subhalo = al.Galaxy(redshift=_FakeTracedRedshift("subhalo")) + source = al.Galaxy(redshift=1.0) + + planes = tracer_util.planes_from(galaxies=[lens_a, lens_b, subhalo, source]) + + assert len(planes) == 3 + assert list(planes[0]) == [lens_a, lens_b] + assert list(planes[1]) == [subhalo] + assert list(planes[2]) == [source] + + def test__time_delays_from(): grid = al.Grid2DIrregular(values=[(0.7, 0.5), (1.0, 1.0)])