diff --git a/scripts/jax_likelihood_functions/imaging/subhalo.py b/scripts/jax_likelihood_functions/imaging/subhalo.py index 4f521c5..fc4603e 100644 --- a/scripts/jax_likelihood_functions/imaging/subhalo.py +++ b/scripts/jax_likelihood_functions/imaging/subhalo.py @@ -1,33 +1,35 @@ """ -JAX Reproducer: Free Subhalo Redshift Triggers TracerBoolConversionError -======================================================================== - -This script is the integration-test reproducer for issue -https://github.com/PyAutoLabs/PyAutoLens/issues/498, reported by @qiuhan96. - -When a galaxy named ``subhalo`` is added to the model and its ``redshift`` is -made a free parameter (an ``af.UniformPrior``), the JAX path of -``AnalysisImaging`` raises ``jax.errors.TracerBoolConversionError``. The error -originates in ``autolens/lens/tracer_util.py`` (``plane_redshifts_from`` and -``grid_2d_at_redshift_from``), which perform Python-level ``<``, ``<=``, ``==`` -and ``float()`` on values that under JIT are traced scalars. - -This script runs two scenarios so the failure mode is unambiguous: - -- **Scenario A** — subhalo redshift is a fixed Python float (z=0.55). The JAX - path should succeed (vmap + jit-wrapped ``analysis.fit_from``). -- **Scenario B** — subhalo redshift is ``af.UniformPrior(0.2, 0.9)``. The JAX - path is expected to raise ``TracerBoolConversionError``. A trimmed traceback - is printed and the script reports ``REPRODUCED: issue #498``. - -Both scenarios use the same ``dataset/imaging/jax_test`` dataset that -``lp.py`` uses, the same lens (PowerLaw + ExternalShear + linear Sersic -bulge) and the same source (linear Sersic). Only the subhalo galaxy and the -type of its ``redshift`` differ. - -Once issue #498 is fixed, Scenario B will stop raising. The script will then -print ``UNEXPECTED: bug appears fixed`` and exit nonzero, which surfaces in -``run_all_scripts.sh`` so we know to flip the assertion polarity. +JAX Regression: Free Subhalo Redshift Under jax.jit +=================================================== + +This script is the JAX regression check for +https://github.com/PyAutoLabs/PyAutoLens/issues/498 (originally reported by +@qiuhan96 and fixed in PyAutoLens PR #499). When the bug was present, setting +``subhalo.redshift = af.UniformPrior(...)`` raised +``jax.errors.TracerBoolConversionError`` because Python ``sorted`` / +``float()`` / ``<=`` / ``==`` were called on what became a traced scalar +under ``jax.jit``. The fix in ``autolens/lens/tracer_util.py`` and +``Tracer.galaxies_ascending_redshift`` adds a JAX-aware fast-path guard that +trusts input galaxy order when any galaxy redshift is traced. + +This script runs two scenarios: + +- **Scenario A** — subhalo redshift fixed at z=0.55 (Python float). Exercises + the unchanged numpy fast-path through ``tracer_util``. +- **Scenario B** — subhalo redshift is ``af.UniformPrior(0.2, 0.9)`` (becomes + a traced scalar under ``jax.jit``). Exercises the JAX partition-and-trust- + input-order path that fixes #498. + +Both scenarios call ``fitness._vmap`` over a small batch of prior-median +parameter vectors and ``jax.jit``-wrap ``analysis.fit_from`` on a single +instance. Both must produce a finite, NumPy-matching ``log_likelihood`` and +both must produce vmap results matching the regression literals below. +A regression of #498 will trip the ``assert_allclose`` on Scenario B's vmap +output (which would either raise the original ``TracerBoolConversionError`` +or silently drift away from the reference value). + +Same ``dataset/imaging/jax_test`` dataset, lens, and source as ``lp.py`` — +only the ``subhalo`` galaxy is added. """ # %matplotlib inline @@ -37,7 +39,6 @@ # print(f"Working Directory has been set to `{workspace_path}`") import sys -import traceback import numpy as np import jax.numpy as jnp @@ -134,11 +135,12 @@ def build_model(redshift_subhalo): __Scenario Runner__ Builds an ``AnalysisImaging`` (with positions likelihood, mirroring ``lp.py``), -runs ``fitness._vmap`` over a small batch of prior-median parameter vectors, -then jit-wraps ``analysis.fit_from`` and runs it on a single instance. - -Returns ``(ok, exc)`` so the caller can decide whether the result is what was -expected for that scenario. +runs ``fitness._vmap`` over a small batch of prior-median parameter vectors +and asserts the result matches the regression literal, then jit-wraps +``analysis.fit_from`` on a single instance and asserts the jit log-likelihood +matches the NumPy-path log-likelihood within ``rtol=1e-4``. Any failure +raises (``AssertionError`` or ``TracerBoolConversionError``) and the script +exits non-zero. """ from autofit.non_linear.fitness import Fitness @@ -178,13 +180,15 @@ def run_scenario(label, redshift_subhalo, batch_size=4): parameters[i, :] = model.physical_values_from_prior_medians parameters = jnp.array(parameters) - try: - result = fitness._vmap(parameters) - print(f" [vmap] result shape={np.shape(np.array(result))}, " - f"first={float(np.array(result)[0]):.6e}") - except Exception as exc: # noqa: BLE001 — diagnostic, want to keep going - traceback.print_exc() - return False, exc + result = fitness._vmap(parameters) + print(f" [vmap] result shape={np.shape(np.array(result))}, " + f"first={float(np.array(result)[0]):.6e}") + np.testing.assert_allclose( + np.array(result), + -1.412105e+09, + rtol=1e-4, + err_msg="subhalo: JAX vmap likelihood mismatch (issue #498 regression?)", + ) # --- Path 2: jit-wrapped analysis.fit_from on a single instance --- instance = model.instance_from_prior_medians() @@ -206,67 +210,40 @@ def run_scenario(label, redshift_subhalo, batch_size=4): ], use_jax=True, ) - try: - fit = jax.jit(analysis_jit.fit_from)(instance) - print(f" [jit] fit.log_likelihood = {float(fit.log_likelihood):.6e}") - np.testing.assert_allclose( - float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 - ) - print(f" [jit] matches numpy path within rtol=1e-4") - except Exception as exc: # noqa: BLE001 - traceback.print_exc() - return False, exc - - return True, None + fit = jax.jit(analysis_jit.fit_from)(instance) + print(f" [jit] fit.log_likelihood = {float(fit.log_likelihood):.6e}") + np.testing.assert_allclose( + float(fit.log_likelihood), float(fit_np.log_likelihood), rtol=1e-4 + ) + print(f" [jit] matches numpy path within rtol=1e-4") """ __Scenario A — Fixed Subhalo Redshift__ The subhalo redshift is a Python float between the lens (z=0.5) and source -(z=1.0). This exercises the same multi-plane code path as Scenario B but -without any traced-scalar comparisons, so it should succeed. +(z=1.0). Exercises the unchanged numpy fast-path through ``tracer_util``. """ -ok_a, exc_a = run_scenario("A (fixed redshift z=0.55)", redshift_subhalo=0.55) +run_scenario("A (fixed redshift z=0.55)", redshift_subhalo=0.55) """ __Scenario B — Free Subhalo Redshift__ -The subhalo redshift is an ``af.UniformPrior(0.2, 0.9)``. Under JAX this -becomes a traced scalar, and ``tracer_util`` then performs Python boolean -comparisons that JAX cannot lift to traced ops. +The subhalo redshift is an ``af.UniformPrior(0.2, 0.9)`` — a traced scalar +under ``jax.jit``. Exercises the JAX partition-and-trust-input-order path +introduced by PyAutoLens PR #499 to fix #498. The vmap regression literal +``-1.412105e+09`` is identical to Scenario A's because both evaluate at the +prior median ``z_subhalo = 0.55``; what differs is the code path inside +``tracer_util`` (numpy sort vs JAX-aware partition). """ -ok_b, exc_b = run_scenario( +run_scenario( "B (free redshift UniformPrior(0.2, 0.9))", redshift_subhalo=af.UniformPrior(lower_limit=0.2, upper_limit=0.9), ) -""" -__Verdict__ -""" print() print("=" * 72) -print("Summary") +print("PASS: issue #498 regression check passing") print("=" * 72) -print(f" Scenario A (fixed) : {'PASS' if ok_a else 'FAIL'}") -print(f" Scenario B (free) : {'PASS' if ok_b else 'FAIL (raised)'}") - -if not ok_a: - print() - print(f"FAIL: Scenario A unexpectedly raised: {type(exc_a).__name__}: {exc_a}") - sys.exit(1) - -if ok_b: - print() - print("UNEXPECTED: Scenario B did not raise — bug from issue #498 may be fixed.") - print("Update the issue and flip this script's expected outcome.") - sys.exit(1) - -# Scenario B raised, which is what we expect today. -exc_name = type(exc_b).__name__ -print() -print(f"REPRODUCED: issue #498") -print(f" Scenario B raised {exc_name}: {exc_b}") -print()