refactor: port CSE module to support JAX via xp parameter#447
Merged
Conversation
Thread xp=np through all MassProfileCSE methods (convergence_cse_1d_from, deflections_via_cse_from, _convergence_2d_via_cse_from, _deflections_2d_via_cse_from) and replace np.sqrt/np.vstack with xp equivalents. Thread xp=xp through NFW CSE callers. The decomposition solver (_decompose_convergence_via_cse_from) stays NumPy-only as it is a one-time setup computation, not part of the JIT-traced forward pass. Phase 2 of mass profiles refactoring epic (#445). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Port the CSE (Cored Steep Ellipsoid) module to support JAX by threading
xp=npthrough all forward-path methods, mirroring the already JAX-ready MGE module. This unblocks JAX JIT compilation for any profile that uses CSE decomposition (primarily NFW). Phase 2 of the mass profiles refactoring epic (#445).API Changes
Added
xp=npkeyword argument to fourMassProfileCSEmethods:convergence_cse_1d_from,deflections_via_cse_from,_convergence_2d_via_cse_from,_deflections_2d_via_cse_from. Existing callers passing noxpargument are unaffected (default isnp). The decomposition solver (_decompose_convergence_via_cse_from) stays NumPy-only — it runs before JIT tracing. See full details below.Test Plan
pytest test_autogalaxy/profiles/mass/— 406 passed, 0 failedprofiles_jit.pyJAX three-step test for CSE-based profiles (follow-up on autolens_workspace_test)Full API Changes (for automation & release notes)
Changed Signature
MassProfileCSE.convergence_cse_1d_from(grid_radii, core_radius)→convergence_cse_1d_from(grid_radii, core_radius, xp=np)MassProfileCSE.deflections_via_cse_from(term1, term2, term3, term4, axis_ratio_squared, core_radius)→ addsxp=npMassProfileCSE._convergence_2d_via_cse_from(grid_radii, **kwargs)→ addsxp=npMassProfileCSE._deflections_2d_via_cse_from(grid, **kwargs)→ addsxp=npChanged Behaviour
deflections_via_cse_fromnow usesxp.sqrtandxp.vstackinstead of hardcodednp.sqrt/np.vstackNFW.deflections_2d_via_cse_fromandNFW.convergence_2d_via_cse_fromnow threadxp=xpto CSE base methodsMigration
No migration needed — all new
xpparameters default tonp, preserving existing behaviour.🤖 Generated with Claude Code