fix: tracer_util JAX-safe for traced subhalo redshifts (#498)#499
Merged
fix: tracer_util JAX-safe for traced subhalo redshifts (#498)#499
Conversation
Free-parameter subhalo redshift (af.UniformPrior on Galaxy.redshift) raised TracerBoolConversionError under jax.jit because Python sorted() and float() on a galaxies list containing one traced redshift trip pairwise boolean comparisons that JAX cannot lift into traced ops. This patch adds a JAX-aware fast-path guard to: - tracer_util.plane_redshifts_from - tracer_util.planes_from - tracer_util.grid_2d_at_redshift_from - Tracer.galaxies_ascending_redshift When every galaxy redshift is a concrete Python number, the existing sort-based code runs unchanged (numpy fast-path, byte-for-byte identical to before). When any redshift is traced, the function instead trusts the input order and grid_2d_at_redshift_from matches the requested redshift to a galaxy by Python identity (the only call site, AnalysisLens.tracer_via_instance_from, always passes the subhalo's own redshift object). The integration-test reproducer in autolens_workspace_test PR #79 (Scenario B) now passes: log_likelihood = -3.523166e+05, JIT matches NumPy within rtol=1e-4. The existing lp.py JAX likelihood-function script still produces -1.34797842e+09, matching its regression literal. Full test_autolens suite green (269/269). Adds 4 numpy-only unit tests in test_tracer_util.py covering _redshift_is_traced detection, partition path input-order preservation, concrete-redshift dedup, and planes_from grouping.
Collaborator
Author
|
Workspace PR: PyAutoLabs/autolens_workspace_test#81 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Fix issue #498, reported by @qiuhan96: setting
subhalo.redshift = af.UniformPrior(...)raisedjax.errors.TracerBoolConversionErrorunderjax.jit.The bug:
tracer_util.plane_redshifts_from,tracer_util.planes_from,tracer_util.grid_2d_at_redshift_from, andTracer.galaxies_ascending_redshiftall called Pythonsorted(galaxies, key=lambda g: g.redshift)on a galaxies list. When one of those redshifts was a JAX traced scalar (e.g. a free-parameter subhalo redshift),sortedperformed pairwise<comparisons that JAX cannot lift into traced ops, so the very first call into the multi-plane code path raised.The fix adds a JAX-aware guard. Each function keeps its existing numpy fast-path for the case where every redshift is concrete (byte-for-byte identical behavior — no production fits change). When any redshift is traced, the function falls through to a partition-and-input-order path: concrete redshifts are deduped by exact float equality, traced redshifts each get their own dedicated plane in input position, and
grid_2d_at_redshift_frommatches the requested redshift to a galaxy by Python identity (the only call site,AnalysisLens.tracer_via_instance_from, always passes the subhalo's own redshift object).The integration-test reproducer in autolens_workspace_test PR #79 (Scenario B) now passes —
log_likelihood = -3.523166e+05, JIT path matches NumPy withinrtol=1e-4.lp.pyand the rest of the existing JAX likelihood-function tests are unaffected (regression literals match to withinrtol=1e-4).API Changes
None — internal changes only. All public function signatures are unchanged. Behaviour is unchanged for any model where every galaxy redshift is a concrete Python number; that's the entire
test_autolenscorpus and every workspace example today. Only previously-failing fits with a traced galaxy redshift see a difference: they now succeed.See full details below.
Test Plan
pytest test_autolens/— 269/269 passing on the patched librarylp.pyJAX likelihood-function script —log_likelihood ≈ -1.34797842e+09, matches the existing regression literal atlp.py:201withinrtol=1e-4autolens_workspace_test/scripts/jax_likelihood_functions/imaging/subhalo.py— Scenario A (fixed z=0.55) PASS, Scenario B (freeUniformPrior(0.2, 0.9)) now PASS, both producelog_likelihood = -3.523166e+05with JIT matching NumPy atrtol=1e-4test_tracer_util.pycovering the partition-and-input-order path_any_tracedguard at the top of each function — when all redshifts returnFalsefrom_redshift_is_traced, the function body is identical to the prior implementation)Full API Changes (for automation & release notes)
Removed
None.
Added
autolens/lens/tracer_util.py:_redshift_is_traced(redshift) -> bool— module-private helper. ReturnsTruefor JAX traced scalars (anything that raises onfloat(value)and isn't a concreteint/float/ 0-dnp.ndarray).autolens/lens/tracer_util.py:_any_traced(galaxies) -> bool— module-private helper. ReturnsTrueif any galaxy in the list has a traced redshift.Renamed / Moved
None.
Changed Signature
None.
Changed Behaviour
autolens/lens/tracer_util.plane_redshifts_from(galaxies)— when no redshift is traced, behaviour is identical to before. When any redshift is traced, the function returns plane redshifts in input-list order, deduped only by concrete-equality between concrete redshifts. Each traced-redshift galaxy gets its own plane.autolens/lens/tracer_util.planes_from(galaxies, plane_redshifts=None)— same dual-path treatment. The traced-input branch builds plane groupings from input order rather than sorted order.autolens/lens/tracer_util.grid_2d_at_redshift_from(redshift, galaxies, grid, cosmology, xp)— whenredshiftis traced (or any galaxy has a traced redshift), the function matches the requested redshift to a galaxy by Pythonisidentity (rather than value equality) and returns the multi-plane traced grid at that galaxy's plane. RaisesRayTracingExceptionif the traced redshift doesn't belong to any input galaxy by identity. The existing concrete-only path is unchanged.autolens/lens/tracer.Tracer.galaxies_ascending_redshift— when no redshift is traced, returns the samesorted(...)as before. When any redshift is traced, returnslist(self.galaxies)(input order trusted).Migration
No user-facing migration required. Existing scripts hit the unchanged fast-path. Users wishing to make a
Galaxy.redshifta free parameter viaaf.UniformPrior(...)must declare their galaxies in ascending-redshift order inaf.Collection(galaxies=...)(which the natural declaration order already produces).Notes for reviewers
Collection(galaxies=Collection(lens=..., subhalo=..., source=...))). If a user puts the subhalo before the lens in declaration order with a redshift between them, the multi-plane scaling factors will be computed in the wrong order — but that's a model-construction error even on the numpy path.autogalaxy.cosmologyaccept traced redshifts under JAX. The reproducer reaches them viatraced_grid_2d_list_fromand produces a finite log-likelihood matching NumPy, so they appear to work for Planck15. If a different cosmology needs a JAX-friendly shim, that's a follow-up issue.analysis/lens.py:116(instance.galaxies.subhalo.mass.centre = tuple(subhalo_centre.in_list[0])) was a suspected breakage site but turns out to work fine —tuple((traced_y, traced_x))survives JIT tracing intact. No change needed.🤖 Generated with Claude Code