Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion autolens/lens/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,23 @@ def galaxies_ascending_redshift(self) -> List[ag.Galaxy]:
planes of increasing redshift. Thus, the galaxies are sorted by redshift in ascending order to aid this
calculation.

When any galaxy has a JAX-traced ``redshift`` (e.g. a free-parameter
subhalo redshift under ``jax.jit``), Python's ``sorted`` cannot order
the list because the key comparator would coerce a traced boolean.
In that case, we trust the input order — the caller (typically
``af.Collection(galaxies=...)``) is expected to declare galaxies in
ascending-redshift order, with each traced-redshift galaxy placed at
its intended plane position. See ``tracer_util.plane_redshifts_from``
for the matching rule.

Returns
-------
The galaxies in the tracer in ascending redshift order.
"""
return sorted(self.galaxies, key=lambda galaxy: galaxy.redshift)
if not tracer_util._any_traced(self.galaxies):
return sorted(self.galaxies, key=lambda galaxy: galaxy.redshift)

return list(self.galaxies)

@property
def plane_redshifts(self) -> List[float]:
Expand Down
135 changes: 121 additions & 14 deletions autolens/lens/tracer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,33 @@
from autolens import exc


def _redshift_is_traced(redshift) -> bool:
"""
Return True if ``redshift`` is a JAX traced scalar that cannot be coerced to a
Python float without raising under ``jax.jit``.

Galaxy redshifts are normally Python ``float`` / ``int`` values, but when a
``af.UniformPrior`` is bound to a ``Galaxy.redshift`` field (e.g. for a free-
parameter subhalo redshift; see PyAutoLens issue #498), the value passed in at
likelihood-evaluation time becomes a traced scalar under ``jax.jit``. Most
sort-and-compare helpers in this module need to fall back to a JAX-aware path
in that case rather than calling ``sorted`` / ``float()`` / ``<=`` on the value.
"""
if isinstance(redshift, (int, float)):
return False
if isinstance(redshift, np.ndarray) and redshift.shape == ():
return False
try:
float(redshift)
except Exception:
return True
return False


def _any_traced(galaxies: List[ag.Galaxy]) -> bool:
return any(_redshift_is_traced(g.redshift) for g in galaxies)


def plane_redshifts_from(galaxies: List[ag.Galaxy]) -> List[float]:
"""
Returns a list of plane redshifts from a list of galaxies, using the redshifts of the galaxies to determine the
Expand All @@ -33,6 +60,14 @@ def plane_redshifts_from(galaxies: List[ag.Galaxy]) -> List[float]:
For example, if the input is three galaxies, two at redshift 1.0 and one at redshift 2.0, the returned list of
redshifts would be [1.0, 2.0].

When one or more galaxies have a JAX-traced redshift (e.g. a free-parameter
subhalo redshift under ``jax.jit``), the function cannot Python-sort or
``float()``-coerce the values. It instead walks the input list in order,
deduplicating *concrete* redshifts only and treating each traced redshift as a
unique plane at its input position. The caller must pass galaxies in
ascending-redshift order in this case (which ``af.Collection(galaxies=...)``
naturally does when the user declares them as ``lens, subhalo, source``).

Parameters
----------
galaxies
Expand All @@ -43,12 +78,27 @@ def plane_redshifts_from(galaxies: List[ag.Galaxy]) -> List[float]:
The list of unique redshifts of the planes.
"""

galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)
if not _any_traced(galaxies):
galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)

# Coerce to float to avoid issues with other float types not being hashable.
plane_redshifts = [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift]

return list(dict.fromkeys(plane_redshifts))

# Coerce to float to avoid issues with other float types not being hashable.
plane_redshifts = [float(galaxy.redshift) for galaxy in galaxies_ascending_redshift]
plane_redshifts: List = []
seen_concrete: set = set()
for galaxy in galaxies:
z = galaxy.redshift
if _redshift_is_traced(z):
plane_redshifts.append(z)
else:
zf = float(z)
if zf not in seen_concrete:
seen_concrete.add(zf)
plane_redshifts.append(zf)

return list(dict.fromkeys(plane_redshifts))
return plane_redshifts


def planes_from(
Expand All @@ -68,6 +118,12 @@ def planes_from(
For example, if the input is three galaxies, two at redshift 1.0 and one at redshift 2.0, the returned list of
list of galaxies would be [[g1, g2], g3]].

When any galaxy has a JAX-traced redshift, planes are built by walking the
input galaxies in order and grouping by *concrete* redshift equality only;
each traced-redshift galaxy gets its own dedicated plane in input position.
See ``plane_redshifts_from`` for the matching rule and assumption on input
ordering.

Parameters
----------
galaxies
Expand All @@ -81,21 +137,38 @@ def planes_from(
The list of list of galaxies grouped into their planes.
"""

galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)
if not _any_traced(galaxies):
galaxies_ascending_redshift = sorted(galaxies, key=lambda galaxy: galaxy.redshift)

if plane_redshifts is None:
plane_redshifts = plane_redshifts_from(galaxies=galaxies_ascending_redshift)

planes = [[] for i in range(len(plane_redshifts))]

if plane_redshifts is None:
plane_redshifts = plane_redshifts_from(galaxies=galaxies_ascending_redshift)
for galaxy in galaxies_ascending_redshift:
index = (np.abs(np.asarray(plane_redshifts) - galaxy.redshift)).argmin()
planes[index].append(galaxy)

planes = [[] for i in range(len(plane_redshifts))]
for index in range(len(planes)):
planes[index] = ag.Galaxies(galaxies=planes[index])

for galaxy in galaxies_ascending_redshift:
index = (np.abs(np.asarray(plane_redshifts) - galaxy.redshift)).argmin()
planes[index].append(galaxy)
return planes

for index in range(len(planes)):
planes[index] = ag.Galaxies(galaxies=planes[index])
plane_groups: List = [] # list of (key, [galaxies])
for galaxy in galaxies:
z = galaxy.redshift
if _redshift_is_traced(z):
plane_groups.append((z, [galaxy]))
else:
zf = float(z)
for i, (key, _) in enumerate(plane_groups):
if not _redshift_is_traced(key) and float(key) == zf:
plane_groups[i][1].append(galaxy)
break
else:
plane_groups.append((zf, [galaxy]))

return planes
return [ag.Galaxies(galaxies=group) for _, group in plane_groups]


def traced_grid_2d_list_from(
Expand Down Expand Up @@ -244,6 +317,40 @@ def grid_2d_at_redshift_from(
"""
cosmology = cosmology or ag.cosmo.Planck15()

if _redshift_is_traced(redshift) or _any_traced(galaxies):
# JAX path: the requested redshift always matches the redshift of one of
# the input galaxies (this is how AnalysisLens.tracer_via_instance_from
# invokes the function — it passes ``redshift=instance.galaxies.subhalo.
# redshift`` and the subhalo galaxy is in ``galaxies`` too). So we just
# need to identify which plane that galaxy lives in (via Python identity,
# not value comparison) and return the traced grid at that plane.
planes = planes_from(galaxies=galaxies)

plane_index_match = None
for plane_index, plane_galaxies in enumerate(planes):
for plane_galaxy in plane_galaxies:
if plane_galaxy.redshift is redshift:
plane_index_match = plane_index
break
if plane_index_match is not None:
break

if plane_index_match is None:
raise exc.RayTracingException(
"grid_2d_at_redshift_from was called under JAX with a traced "
"redshift that does not match any galaxy in the input list by "
"Python identity. The current implementation only supports the "
"case where the requested redshift is the same object as one of "
"the galaxy redshifts (e.g. instance.galaxies.subhalo.redshift). "
"Insertion at an arbitrary traced redshift is not yet supported."
)

traced_grid_list = traced_grid_2d_list_from(
planes=planes, grid=grid, cosmology=cosmology, xp=xp
)

return traced_grid_list[plane_index_match]

plane_redshifts = plane_redshifts_from(galaxies=galaxies)

if redshift <= plane_redshifts[0]:
Expand Down
77 changes: 77 additions & 0 deletions test_autolens/lens/test_tracer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,83 @@ def test__grid_2d_at_redshift_from__redshift_between_planes(grid_2d_7x7):
assert (grid_at_redshift == grid_2d_7x7.mask.derive_grid.all_false).all()


class _FakeTracedRedshift:
"""A redshift-like object that mimics a JAX traced scalar — calling ``float()``
raises, so ``tracer_util._redshift_is_traced`` should return True. Used to
exercise the JAX partition-and-splice path without importing ``jax`` (library
unit tests stay numpy-only — see ``feedback_no_jax_in_unit_tests``)."""

def __init__(self, name: str):
self.name = name

def __float__(self):
raise TypeError("traced redshift cannot be coerced to float")

def __repr__(self):
return f"<TracedRedshift {self.name}>"


def test__redshift_is_traced__detects_traced_and_concrete():
from autolens.lens import tracer_util

assert tracer_util._redshift_is_traced(_FakeTracedRedshift("subhalo")) is True

assert tracer_util._redshift_is_traced(0.5) is False
assert tracer_util._redshift_is_traced(1) is False
assert tracer_util._redshift_is_traced(np.float64(1.5)) is False
assert tracer_util._redshift_is_traced(np.array(0.7)) is False


def test__plane_redshifts_from__partition_path__preserves_input_order():
from autolens.lens import tracer_util

lens = al.Galaxy(redshift=0.5)
subhalo = al.Galaxy(redshift=_FakeTracedRedshift("subhalo"))
source = al.Galaxy(redshift=1.0)

plane_redshifts = tracer_util.plane_redshifts_from(
galaxies=[lens, subhalo, source]
)

assert plane_redshifts[0] == 0.5
assert isinstance(plane_redshifts[1], _FakeTracedRedshift)
assert plane_redshifts[2] == 1.0


def test__plane_redshifts_from__partition_path__dedupes_concrete_only():
from autolens.lens import tracer_util

g0 = al.Galaxy(redshift=0.5)
g1 = al.Galaxy(redshift=0.5) # duplicate concrete redshift — should collapse
subhalo = al.Galaxy(redshift=_FakeTracedRedshift("subhalo"))
source = al.Galaxy(redshift=1.0)

plane_redshifts = tracer_util.plane_redshifts_from(
galaxies=[g0, g1, subhalo, source]
)

assert plane_redshifts[0] == 0.5
assert isinstance(plane_redshifts[1], _FakeTracedRedshift)
assert plane_redshifts[2] == 1.0
assert len(plane_redshifts) == 3


def test__planes_from__partition_path__traced_galaxy_gets_dedicated_plane():
from autolens.lens import tracer_util

lens_a = al.Galaxy(redshift=0.5)
lens_b = al.Galaxy(redshift=0.5) # same plane as lens_a
subhalo = al.Galaxy(redshift=_FakeTracedRedshift("subhalo"))
source = al.Galaxy(redshift=1.0)

planes = tracer_util.planes_from(galaxies=[lens_a, lens_b, subhalo, source])

assert len(planes) == 3
assert list(planes[0]) == [lens_a, lens_b]
assert list(planes[1]) == [subhalo]
assert list(planes[2]) == [source]


def test__time_delays_from():

grid = al.Grid2DIrregular(values=[(0.7, 0.5), (1.0, 1.0)])
Expand Down
Loading