Merged
147 changes: 62 additions & 85 deletions scripts/jax_likelihood_functions/imaging/subhalo.py
@@ -1,33 +1,35 @@
"""
JAX Reproducer: Free Subhalo Redshift Triggers TracerBoolConversionError
========================================================================

This script is the integration-test reproducer for issue
https://github.com/PyAutoLabs/PyAutoLens/issues/498, reported by @qiuhan96.

When a galaxy named ``subhalo`` is added to the model and its ``redshift`` is
made a free parameter (an ``af.UniformPrior``), the JAX path of
``AnalysisImaging`` raises ``jax.errors.TracerBoolConversionError``. The error
originates in ``autolens/lens/tracer_util.py`` (``plane_redshifts_from`` and
``grid_2d_at_redshift_from``), which perform Python-level ``<``, ``<=``, ``==``
and ``float()`` on values that under JIT are traced scalars.

This script runs two scenarios so the failure mode is unambiguous:

- **Scenario A** — subhalo redshift is a fixed Python float (z=0.55). The JAX
path should succeed (vmap + jit-wrapped ``analysis.fit_from``).
- **Scenario B** — subhalo redshift is ``af.UniformPrior(0.2, 0.9)``. The JAX
path is expected to raise ``TracerBoolConversionError``. A trimmed traceback
is printed and the script reports ``REPRODUCED: issue #498``.

Both scenarios use the same ``dataset/imaging/jax_test`` dataset that
``lp.py`` uses, the same lens (PowerLaw + ExternalShear + linear Sersic
bulge) and the same source (linear Sersic). Only the subhalo galaxy and the
type of its ``redshift`` differ.

Once issue #498 is fixed, Scenario B will stop raising. The script will then
print ``UNEXPECTED: bug appears fixed`` and exit nonzero, which surfaces in
``run_all_scripts.sh`` so we know to flip the assertion polarity.
JAX Regression: Free Subhalo Redshift Under jax.jit
===================================================

This script is the JAX regression check for
https://github.com/PyAutoLabs/PyAutoLens/issues/498 (originally reported by
@qiuhan96 and fixed in PyAutoLens PR #499). When the bug was present, setting
``subhalo.redshift = af.UniformPrior(...)`` raised
``jax.errors.TracerBoolConversionError`` because Python ``sorted`` /
``float()`` / ``<=`` / ``==`` were called on what became a traced scalar
under ``jax.jit``. The fix in ``autolens/lens/tracer_util.py`` and
``Tracer.galaxies_ascending_redshift`` adds a JAX-aware fast-path guard that
trusts input galaxy order when any galaxy redshift is traced.

This script runs two scenarios:

- **Scenario A** — subhalo redshift fixed at z=0.55 (Python float). Exercises
the unchanged numpy fast-path through ``tracer_util``.
- **Scenario B** — subhalo redshift is ``af.UniformPrior(0.2, 0.9)`` (becomes
a traced scalar under ``jax.jit``). Exercises the JAX partition-and-trust-
input-order path that fixes #498.

Both scenarios call ``fitness._vmap`` over a small batch of prior-median
parameter vectors and ``jax.jit``-wrap ``analysis.fit_from`` on a single
instance. Both must produce a finite, NumPy-matching ``log_likelihood`` and
both must produce vmap results matching the regression literals below.
A regression of #498 will surface in Scenario B's vmap step: either
``fitness._vmap`` raises the original ``TracerBoolConversionError``, or the
``assert_allclose`` fails because the value drifted from the reference.

Same ``dataset/imaging/jax_test`` dataset, lens, and source as ``lp.py`` —
only the ``subhalo`` galaxy is added.
"""

# %matplotlib inline
@@ -37,7 +39,6 @@
# print(f"Working Directory has been set to `{workspace_path}`")

import sys
import traceback

import numpy as np
import jax.numpy as jnp
@@ -134,11 +135,12 @@ def build_model(redshift_subhalo):
__Scenario Runner__

Builds an ``AnalysisImaging`` (with positions likelihood, mirroring ``lp.py``),
runs ``fitness._vmap`` over a small batch of prior-median parameter vectors,
then jit-wraps ``analysis.fit_from`` and runs it on a single instance.

Returns ``(ok, exc)`` so the caller can decide whether the result is what was
expected for that scenario.
runs ``fitness._vmap`` over a small batch of prior-median parameter vectors
and asserts the result matches the regression literal, then jit-wraps
``analysis.fit_from`` on a single instance and asserts the jit log-likelihood
matches the NumPy-path log-likelihood within ``rtol=1e-4``. Any failure
raises (``AssertionError`` or ``TracerBoolConversionError``) and the script
exits non-zero.
"""

from autofit.non_linear.fitness import Fitness
@@ -178,13 +180,15 @@ def run_scenario(label, redshift_subhalo, batch_size=4):
parameters[i, :] = model.physical_values_from_prior_medians
parameters = jnp.array(parameters)

try:
result = fitness._vmap(parameters)
print(f" [vmap] result shape={np.shape(np.array(result))}, "
f"first={float(np.array(result)[0]):.6e}")
except Exception as exc: # noqa: BLE001 — diagnostic, want to keep going
traceback.print_exc()
return False, exc
result = fitness._vmap(parameters)
print(f" [vmap] result shape={np.shape(np.array(result))}, "
f"first={float(np.array(result)[0]):.6e}")
np.testing.assert_allclose(
np.array(result),
-1.412105e+09,
rtol=1e-4,
err_msg="subhalo: JAX vmap likelihood mismatch (issue #498 regression?)",
)

# --- Path 2: jit-wrapped analysis.fit_from on a single instance ---
instance = model.instance_from_prior_medians()
@@ -206,67 +210,40 @@ def run_scenario(label, redshift_subhalo, batch_size=4):
],
use_jax=True,
)
try:
fit = jax.jit(analysis_jit.fit_from)(instance)
print(f" [jit] fit.log_likelihood = {float(fit.log_likelihood):.6e}")
np.testing.assert_allclose(
float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4
)
print(f" [jit] matches numpy path within rtol=1e-4")
except Exception as exc: # noqa: BLE001
traceback.print_exc()
return False, exc

return True, None
fit = jax.jit(analysis_jit.fit_from)(instance)
print(f" [jit] fit.log_likelihood = {float(fit.log_likelihood):.6e}")
np.testing.assert_allclose(
float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4
)
print(f" [jit] matches numpy path within rtol=1e-4")


"""
__Scenario A — Fixed Subhalo Redshift__

The subhalo redshift is a Python float between the lens (z=0.5) and source
(z=1.0). This exercises the same multi-plane code path as Scenario B but
without any traced-scalar comparisons, so it should succeed.
(z=1.0). Exercises the unchanged numpy fast-path through ``tracer_util``.
"""
ok_a, exc_a = run_scenario("A (fixed redshift z=0.55)", redshift_subhalo=0.55)
run_scenario("A (fixed redshift z=0.55)", redshift_subhalo=0.55)


"""
__Scenario B — Free Subhalo Redshift__

The subhalo redshift is an ``af.UniformPrior(0.2, 0.9)``. Under JAX this
becomes a traced scalar, and ``tracer_util`` then performs Python boolean
comparisons that JAX cannot lift to traced ops.
The subhalo redshift is an ``af.UniformPrior(0.2, 0.9)`` — a traced scalar
under ``jax.jit``. Exercises the JAX partition-and-trust-input-order path
introduced by PyAutoLens PR #499 to fix #498. The vmap regression literal
``-1.412105e+09`` is identical to Scenario A's because both evaluate at the
prior median ``z_subhalo = 0.55``; what differs is the code path inside
``tracer_util`` (numpy sort vs JAX-aware partition).
"""
ok_b, exc_b = run_scenario(
run_scenario(
"B (free redshift UniformPrior(0.2, 0.9))",
redshift_subhalo=af.UniformPrior(lower_limit=0.2, upper_limit=0.9),
)


"""
__Verdict__
"""
print()
print("=" * 72)
print("Summary")
print("PASS: issue #498 regression check passing")
print("=" * 72)
print(f" Scenario A (fixed) : {'PASS' if ok_a else 'FAIL'}")
print(f" Scenario B (free) : {'PASS' if ok_b else 'FAIL (raised)'}")

if not ok_a:
print()
print(f"FAIL: Scenario A unexpectedly raised: {type(exc_a).__name__}: {exc_a}")
sys.exit(1)

if ok_b:
print()
print("UNEXPECTED: Scenario B did not raise — bug from issue #498 may be fixed.")
print("Update the issue and flip this script's expected outcome.")
sys.exit(1)

# Scenario B raised, which is what we expect today.
exc_name = type(exc_b).__name__
print()
print(f"REPRODUCED: issue #498")
print(f" Scenario B raised {exc_name}: {exc_b}")
print()
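
For readers unfamiliar with the failure mode this script guards against, the sketch below reproduces it outside PyAutoLens. It is an illustration under stated assumptions, not the code from PR #499: the function names ``ordered_python`` and ``ordered_traced`` are hypothetical, and the traced variant uses ``jnp.minimum`` / ``jnp.maximum`` where the actual fix trusts the input galaxy order. It shows a Python-level ``<=`` on a traced scalar raising ``TracerBoolConversionError`` under ``jax.jit``, followed by the same ``assert_allclose`` style of check the script uses.

```python
import jax
import jax.numpy as jnp
import numpy as np


def ordered_python(z_lens, z_sub):
    # Hypothetical helper: the Python `if` calls bool() on the comparison,
    # which fails when the operands are traced scalars.
    if z_sub <= z_lens:
        return jnp.array([z_sub, z_lens])
    return jnp.array([z_lens, z_sub])


def ordered_traced(z_lens, z_sub):
    # Hypothetical helper: the same ordering written with traced ops only.
    # (Assumption: min/max stands in for the real fix, which trusts input order.)
    return jnp.array([jnp.minimum(z_lens, z_sub), jnp.maximum(z_lens, z_sub)])


# Eager call: both redshifts are concrete Python floats, so the branch works.
eager = ordered_python(0.5, 0.55)

# Under jit both redshifts are traced, so the Python branch raises.
try:
    jax.jit(ordered_python)(0.5, 0.55)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

# The traced variant runs under jit and under vmap over a batch of redshifts.
jitted = jax.jit(ordered_traced)(0.5, 0.55)
batched = jax.vmap(ordered_traced, in_axes=(None, 0))(
    0.5, jnp.array([0.2, 0.55, 0.9])
)

# Same style of check as the script: jit result must match the eager result.
np.testing.assert_allclose(np.array(jitted), np.array(eager), rtol=1e-4)
```

Scenario A corresponds to the eager-style case, where the redshift stays a Python float, and Scenario B to the jitted case, where it becomes a traced scalar.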