Skip to content

fix: tracer_util JAX-safe for traced subhalo redshifts (#498)#499

Merged
Jammy2211 merged 1 commit intomainfrom
feature/subhalo-redshift-jax-fix
May 8, 2026
Merged

fix: tracer_util JAX-safe for traced subhalo redshifts (#498)#499
Jammy2211 merged 1 commit intomainfrom
feature/subhalo-redshift-jax-fix

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Fix issue #498, reported by @qiuhan96: setting subhalo.redshift = af.UniformPrior(...) raised jax.errors.TracerBoolConversionError under jax.jit.

The bug: tracer_util.plane_redshifts_from, tracer_util.planes_from, tracer_util.grid_2d_at_redshift_from, and Tracer.galaxies_ascending_redshift all called Python sorted(galaxies, key=lambda g: g.redshift) on a galaxies list. When one of those redshifts was a JAX traced scalar (e.g. a free-parameter subhalo redshift), sorted performed pairwise < comparisons that JAX cannot lift into traced ops, so the very first call into the multi-plane code path raised.

The fix adds a JAX-aware guard. Each function keeps its existing numpy fast-path for the case where every redshift is concrete (byte-for-byte identical behavior — no production fits change). When any redshift is traced, the function falls through to a partition-and-input-order path: concrete redshifts are deduped by exact float equality, traced redshifts each get their own dedicated plane in input position, and grid_2d_at_redshift_from matches the requested redshift to a galaxy by Python identity (the only call site, AnalysisLens.tracer_via_instance_from, always passes the subhalo's own redshift object).

The integration-test reproducer in autolens_workspace_test PR #79 (Scenario B) now passes — log_likelihood = -3.523166e+05, JIT path matches NumPy within rtol=1e-4. lp.py and the rest of the existing JAX likelihood-function tests are unaffected (regression literals match to within rtol=1e-4).

API Changes

None — internal changes only. All public function signatures are unchanged. Behaviour is unchanged for any model where every galaxy redshift is a concrete Python number; that's the entire test_autolens corpus and every workspace example today. Only previously-failing fits with a traced galaxy redshift see a difference: they now succeed.

See full details below.

Test Plan

  • pytest test_autolens/ — 269/269 passing on the patched library
  • lp.py JAX likelihood-function script — log_likelihood ≈ -1.34797842e+09, matches the existing regression literal at lp.py:201 within rtol=1e-4
  • autolens_workspace_test/scripts/jax_likelihood_functions/imaging/subhalo.py — Scenario A (fixed z=0.55) PASS, Scenario B (free UniformPrior(0.2, 0.9)) now PASS, both produce log_likelihood = -3.523166e+05 with JIT matching NumPy at rtol=1e-4
  • 4 new numpy-only unit tests in test_tracer_util.py covering the partition-and-input-order path
  • Reviewer to spot-check that any model with all concrete redshifts hits the unchanged fast-path (_any_traced guard at the top of each function — when all redshifts return False from _redshift_is_traced, the function body is identical to the prior implementation)
Full API Changes (for automation & release notes)

Removed

None.

Added

  • autolens/lens/tracer_util.py:_redshift_is_traced(redshift) -> bool — module-private helper. Returns True for JAX traced scalars (anything that raises on float(value) and isn't a concrete int / float / 0-d np.ndarray).
  • autolens/lens/tracer_util.py:_any_traced(galaxies) -> bool — module-private helper. Returns True if any galaxy in the list has a traced redshift.

Renamed / Moved

None.

Changed Signature

None.

Changed Behaviour

  • autolens/lens/tracer_util.plane_redshifts_from(galaxies) — when no redshift is traced, behaviour is identical to before. When any redshift is traced, the function returns plane redshifts in input-list order, deduped only by concrete-equality between concrete redshifts. Each traced-redshift galaxy gets its own plane.
  • autolens/lens/tracer_util.planes_from(galaxies, plane_redshifts=None) — same dual-path treatment. The traced-input branch builds plane groupings from input order rather than sorted order.
  • autolens/lens/tracer_util.grid_2d_at_redshift_from(redshift, galaxies, grid, cosmology, xp) — when redshift is traced (or any galaxy has a traced redshift), the function matches the requested redshift to a galaxy by Python is identity (rather than value equality) and returns the multi-plane traced grid at that galaxy's plane. Raises RayTracingException if the traced redshift doesn't belong to any input galaxy by identity. The existing concrete-only path is unchanged.
  • autolens/lens/tracer.Tracer.galaxies_ascending_redshift — when no redshift is traced, returns the same sorted(...) as before. When any redshift is traced, returns list(self.galaxies) (input order trusted).

Migration

No user-facing migration required. Existing scripts hit the unchanged fast-path. Users wishing to make a Galaxy.redshift a free parameter via af.UniformPrior(...) must declare their galaxies in ascending-redshift order in af.Collection(galaxies=...) (which the natural declaration order already produces).

Notes for reviewers

  • The JAX path trusts input galaxy order. Reasonable for the typical bug-report case (Collection(galaxies=Collection(lens=..., subhalo=..., source=...))). If a user puts the subhalo before the lens in declaration order with a redshift between them, the multi-plane scaling factors will be computed in the wrong order — but that's a model-construction error even on the numpy path.
  • Open question (not in scope here): whether the cosmology distance functions in autogalaxy.cosmology accept traced redshifts under JAX. The reproducer reaches them via traced_grid_2d_list_from and produces a finite log-likelihood matching NumPy, so they appear to work for Planck15. If a different cosmology needs a JAX-friendly shim, that's a follow-up issue.
  • analysis/lens.py:116 (instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0])) was a suspected breakage site but turns out to work fine — tuple((traced_y, traced_x)) survives JIT tracing intact. No change needed.

🤖 Generated with Claude Code

Free-parameter subhalo redshift (af.UniformPrior on Galaxy.redshift)
raised TracerBoolConversionError under jax.jit because Python sorted()
and float() on a galaxies list containing one traced redshift trip
pairwise boolean comparisons that JAX cannot lift into traced ops.

This patch adds a JAX-aware fast-path guard to:
- tracer_util.plane_redshifts_from
- tracer_util.planes_from
- tracer_util.grid_2d_at_redshift_from
- Tracer.galaxies_ascending_redshift

When every galaxy redshift is a concrete Python number, the existing
sort-based code runs unchanged (numpy fast-path, byte-for-byte
identical to before). When any redshift is traced, the function
instead trusts the input order and grid_2d_at_redshift_from matches
the requested redshift to a galaxy by Python identity (the only call
site, AnalysisLens.tracer_via_instance_from, always passes the
subhalo's own redshift object).

The integration-test reproducer in autolens_workspace_test PR #79
(Scenario B) now passes: log_likelihood = -3.523166e+05, JIT matches
NumPy within rtol=1e-4. The existing lp.py JAX likelihood-function
script still produces -1.34797842e+09, matching its regression
literal. Full test_autolens suite green (269/269).

Adds 4 numpy-only unit tests in test_tracer_util.py covering
_redshift_is_traced detection, partition path input-order
preservation, concrete-redshift dedup, and planes_from grouping.
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 8, 2026
@Jammy2211 Jammy2211 merged commit b790632 into main May 8, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/subhalo-redshift-jax-fix branch May 8, 2026 09:59
@Jammy2211
Copy link
Copy Markdown
Collaborator Author

Workspace PR: PyAutoLabs/autolens_workspace_test#81

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fix: subhalo redshift as free parameter raises TracerBoolConversionError under JAX

1 participant