Skip to content

test: add JAX subhalo-redshift reproducer for issue #498#79

Merged
Jammy2211 merged 1 commit intomainfrom
feature/subhalo-redshift-jax-repro
May 8, 2026
Merged

test: add JAX subhalo-redshift reproducer for issue #498#79
Jammy2211 merged 1 commit intomainfrom
feature/subhalo-redshift-jax-repro

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Adds an integration-test reproducer for PyAutoLens issue #498, reported by @qiuhan96.

The bug: setting subhalo.redshift = af.UniformPrior(...) raises jax.errors.TracerBoolConversionError under jax.jit, because autolens/lens/tracer_util.plane_redshifts_from calls Python's sorted(galaxies, key=lambda g: g.redshift) on a list containing one traced redshift scalar.

This script reproduces the failure cleanly and exits 0 while the bug is open. Once the fix lands in PyAutoLens, Scenario B will start passing and the script will exit 1, surfacing in run_all_scripts.sh so the assertion polarity is flipped and the same file becomes the regression test.

Scripts Changed

  • scripts/jax_likelihood_functions/imaging/subhalo.py — new. Reproduces issue #498. Two scenarios:
    • A (fixed redshift_subhalo=0.55): vmap + jit-wrapped analysis.fit_from, expected PASS, asserts JIT path matches NumPy path within rtol=1e-4.
    • B (free redshift_subhalo=af.UniformPrior(0.2, 0.9)): same calls, expected to raise TracerBoolConversionError from tracer_util.py:46. Prints REPRODUCED: issue #498 on the expected raise, UNEXPECTED: bug appears fixed (and exits 1) once the upstream fix lands.

Test Plan

  • Scenario A passes (vmap finite, jit log_likelihood matches NumPy at -3.52e+05 within rtol=1e-4)
  • Scenario B raises TracerBoolConversionError at tracer_util.py:46 (sorted(galaxies, ...)) as expected
  • Existing smoke tests still pass (run via /smoke_test)

🤖 Generated with Claude Code

Adds scripts/jax_likelihood_functions/imaging/subhalo.py, which mirrors
lp.py but adds a `subhalo` galaxy to the model. Runs two scenarios:

- Scenario A: fixed subhalo redshift (z=0.55) -> PASS under jax.jit + vmap
- Scenario B: free `af.UniformPrior(0.2, 0.9)` redshift -> raises
  TracerBoolConversionError, as reported by @qiuhan96

The script exits 0 today (Scenario B failing is the expected outcome).
Once the underlying tracer_util bug is fixed in PyAutoLens it will
exit 1 on the same path, prompting the assertion polarity flip and
turning the script into a regression test.
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 8, 2026
@Jammy2211 Jammy2211 merged commit d827d1c into main May 8, 2026
2 of 4 checks passed
@Jammy2211 Jammy2211 deleted the feature/subhalo-redshift-jax-repro branch May 8, 2026 08:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant