
refactor: port CSE module to support JAX via xp parameter #446

@Jammy2211

Description

Overview

Port autogalaxy/profiles/mass/abstract/cse.py to support JAX by threading the xp=np parameter through all methods, mirroring the MGE module (mge.py), which is already JAX-ready. Currently, the CSE module uses pure NumPy and scipy.linalg.lstsq, which blocks JAX JIT compilation for any profile that uses CSE decomposition. This is Phase 2 of the mass profiles refactoring epic (#445).
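As a minimal sketch of the xp convention (the helper name radii_from is hypothetical and not part of cse.py), a function that calls only xp.* runs unchanged on NumPy by default and on jax.numpy when the caller passes xp=jnp. JAX is imported optionally so the sketch also runs on a NumPy-only install.

```python
import numpy as np


def radii_from(grid, xp=np):
    """Hypothetical helper: radial distance of each (y, x) coordinate.

    Only xp.* calls are used, so the same code runs on NumPy arrays
    (xp=np, the default) or on jax.numpy arrays (xp=jnp).
    """
    return xp.sqrt(grid[:, 0] ** 2 + grid[:, 1] ** 2)


grid = np.array([[3.0, 4.0], [0.0, 1.0]])
radii_np = radii_from(grid)  # NumPy path, unchanged behaviour

try:
    import jax.numpy as jnp

    radii_jax = radii_from(jnp.asarray(grid), xp=jnp)  # JAX path
except ImportError:  # JAX is optional for this sketch
    radii_jax = None
```

Defaulting to xp=np keeps every existing NumPy call site working without modification; only JAX callers need to opt in.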

Plan

  • Thread xp=np parameter through all CSE static and instance methods
  • Replace hardcoded np.sqrt, np.vstack, np.logspace, np.zeros, np.log10 with xp.*
  • Add an xp is not np branch to _decompose_convergence_via_cse_from that uses jnp.linalg.lstsq (keeping scipy.linalg.lstsq for the NumPy path)
  • Thread xp=xp through all caller profiles that inherit MassProfileCSE (NFW and other dark matter profiles)
  • Run existing unit tests to verify no regressions on NumPy path
  • Add CSE-based profiles to autolens_workspace_test/scripts/profiles_jit.py JAX three-step tests
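A minimal sketch of the solver branch, written as a standalone helper for clarity (lstsq_from is a hypothetical name; in the plan the branch lives inside _decompose_convergence_via_cse_from). Both scipy.linalg.lstsq and jax.numpy.linalg.lstsq return a tuple whose first element is the solution, so the two paths can share the same unpacking.

```python
import numpy as np
import scipy.linalg


def lstsq_from(matrix, vector, xp=np):
    """Hypothetical helper: least-squares solve that dispatches on xp.

    The NumPy path keeps scipy.linalg.lstsq; any other backend
    (e.g. xp=jax.numpy) uses its own xp.linalg.lstsq.
    """
    if xp is np:
        solution, *_ = scipy.linalg.lstsq(matrix, vector)
    else:
        solution, *_ = xp.linalg.lstsq(matrix, vector)
    return solution


# Consistent overdetermined system whose exact solution is (3, -1).
matrix = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
vector = matrix @ np.array([3.0, -1.0])
solution = lstsq_from(matrix, vector)
```

Keeping scipy on the NumPy path means the existing NumPy results are produced by the same solver as before, which is what the regression run in the plan is meant to confirm.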

Detailed implementation plan

Affected Repositories

  • PyAutoGalaxy (primary)
  • autolens_workspace_test (test additions — follow-up)

Work Classification

Library (then workspace test follow-up)

Branch Survey

Repository     Current Branch   Dirty?
PyAutoGalaxy   main             modified (CLAUDE.md only)

Suggested branch: feature/cse-jax-port
Worktree root: ~/Code/PyAutoLabs-wt/cse-jax-port/

Implementation Steps

  1. autogalaxy/profiles/mass/abstract/cse.py — the core change:

    • convergence_cse_1d_from(grid_radii, core_radius) → add xp=np (pure arithmetic, no np calls to replace)
    • deflections_via_cse_from(...) → add xp=np, replace np.sqrt → xp.sqrt, np.vstack → xp.vstack
    • _convergence_2d_via_cse_from(grid_radii, **kwargs) → thread xp to convergence_cse_1d_from
    • _deflections_2d_via_cse_from(grid, **kwargs) → thread xp to deflections_via_cse_from, replace np.* with xp.* for grid operations
    • _decompose_convergence_via_cse_from(func, ...) → add an xp is not np branch with jnp.linalg.lstsq, keep scipy.linalg.lstsq for the NumPy path, and replace np.logspace/np.zeros/np.log10 with xp.*
  2. Callers in autogalaxy/profiles/mass/dark/:

    • nfw.py — NFW uses CSE for deflections via _deflections_2d_via_cse_from; thread xp=xp
    • Any other dark profiles mixing in MassProfileCSE — thread xp=xp
  3. Unit tests — run pytest test_autogalaxy/profiles/mass/ to verify NumPy path unchanged

  4. Workspace test additions (follow-up PR on autolens_workspace_test):

    • Add NFW CSE path to scripts/profiles_jit.py JAX three-step pattern
    • Run Phase 1 self-consistency tests to confirm no regressions
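Steps 1 and 2 can be sketched end to end as follows. This is a simplified, hypothetical stand-in, not the autogalaxy implementation: the class names MassProfileCSESketch and ToyCSEProfile, the total_samples parameter, the method signatures, the unweighted least-squares fit and the CSE convergence form kappa(R; s) = 1 / (2 (s^2 + R^2)^(3/2)) are all assumptions. It shows the mixin accepting xp=np, the decomposition building its grids with xp.logspace/xp.log10 and branching on xp is not np for the solver, and a caller profile threading xp=xp.

```python
import numpy as np
import scipy.linalg


class MassProfileCSESketch:
    """Hypothetical, simplified stand-in for the CSE mixin in cse.py.

    Assumes kappa(R; s) = 1 / (2 * (s**2 + R**2) ** 1.5) and a plain
    (unweighted) least-squares fit; the real decomposition may differ.
    """

    @staticmethod
    def convergence_cse_1d_from(grid_radii, core_radius, xp=np):
        # Pure arithmetic: xp is accepted so callers can thread it through.
        return 1.0 / (2.0 * (core_radius**2 + grid_radii**2) ** 1.5)

    def _decompose_convergence_via_cse_from(
        self, func, radii_min, radii_max, total_cses, total_samples=100, xp=np
    ):
        # Setup step: core radii and sample radii evenly spaced in log10.
        log_min, log_max = xp.log10(radii_min), xp.log10(radii_max)
        core_radii = xp.logspace(log_min, log_max, total_cses)
        sample_radii = xp.logspace(log_min, log_max, total_samples)
        # Design matrix: column j is a unit-amplitude CSE with core_radii[j].
        matrix = self.convergence_cse_1d_from(
            sample_radii[:, None], core_radii[None, :], xp=xp
        )
        target = func(sample_radii)
        if xp is np:
            amplitudes, *_ = scipy.linalg.lstsq(matrix, target)
        else:  # e.g. xp=jax.numpy
            amplitudes, *_ = xp.linalg.lstsq(matrix, target)
        return amplitudes, core_radii

    def _convergence_2d_via_cse_from(self, grid_radii, amplitudes, core_radii, xp=np):
        # Forward pass: amplitude-weighted sum of CSEs, pure xp code.
        cse = self.convergence_cse_1d_from(
            grid_radii[:, None], core_radii[None, :], xp=xp
        )
        return xp.sum(amplitudes * cse, axis=1)


class ToyCSEProfile(MassProfileCSESketch):
    """Hypothetical caller profile whose true convergence is a sum of CSEs."""

    def __init__(self, amplitudes, radii_min=0.1, radii_max=10.0):
        self.true_amplitudes = np.asarray(amplitudes, dtype=float)
        self.radii_min, self.radii_max = radii_min, radii_max
        self.true_core_radii = np.logspace(
            np.log10(radii_min), np.log10(radii_max), len(self.true_amplitudes)
        )

    def convergence_1d_func(self, radii):
        return self._convergence_2d_via_cse_from(
            radii, self.true_amplitudes, self.true_core_radii
        )

    def convergence_2d_from(self, grid, xp=np):
        # The decomposition is setup work, so it stays on the NumPy path here.
        amplitudes, core_radii = self._decompose_convergence_via_cse_from(
            self.convergence_1d_func,
            self.radii_min,
            self.radii_max,
            total_cses=len(self.true_amplitudes),
        )
        # The caller threads xp=xp into every array operation and mixin call.
        radii = xp.sqrt(grid[:, 0] ** 2 + grid[:, 1] ** 2)
        return self._convergence_2d_via_cse_from(radii, amplitudes, core_radii, xp=xp)


profile = ToyCSEProfile([1.0, 2.0, 3.0])
grid = np.array([[0.0, 1.0], [3.0, 4.0]])  # (y, x) pairs at radii 1 and 5
convergence = profile.convergence_2d_from(grid)
```

Because the toy target is itself an exact sum of CSEs on the same core radii, the fit recovers the input amplitudes, which makes the NumPy path easy to check before swapping in xp=jnp.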

Key Files

  • autogalaxy/profiles/mass/abstract/cse.py — CSE mixin (6 methods to port)
  • autogalaxy/profiles/mass/dark/nfw.py — primary CSE caller
  • autogalaxy/profiles/mass/dark/abstract.py — DarkProfile base
  • test_autogalaxy/profiles/mass/dark/test_nfw.py — existing NFW tests

Key Constraint

The CSE decomposition (_decompose_convergence_via_cse_from) is a one-time setup computation, not part of the JIT-traced forward pass. It must NOT be called inside jax.jit. The forward methods that consume cached decomposition results ARE traced and must be pure xp code.
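A minimal sketch of that split, assuming the decomposition has already been run once and cached: AMPLITUDES and CORE_RADII below are placeholder values standing in for its output, and convergence_via_cached_cse_from is a hypothetical name. Only the forward pass is wrapped in jax.jit; JAX is imported optionally so the sketch still runs on a NumPy-only install.

```python
import numpy as np

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # JAX is optional; the sketch falls back to NumPy
    jax = None
    jnp = None

# Placeholder "cached decomposition" results (hypothetical values). In the
# real module these would come from the one-time setup, run outside jax.jit.
AMPLITUDES = np.array([1.0, 2.0, 3.0])
CORE_RADII = np.array([0.1, 1.0, 10.0])


def convergence_via_cached_cse_from(grid_radii, amplitudes, core_radii, xp=np):
    """Hypothetical forward pass: pure xp code, so jax.jit can trace it."""
    cse = 1.0 / (2.0 * (core_radii[None, :] ** 2 + grid_radii[:, None] ** 2) ** 1.5)
    return xp.sum(amplitudes * cse, axis=1)


grid_radii = np.array([0.5, 1.0, 2.0])
reference = convergence_via_cached_cse_from(grid_radii, AMPLITUDES, CORE_RADII)

if jax is not None:
    # Only the forward pass is jitted; the cached arrays enter as constants.
    @jax.jit
    def convergence_jit(radii):
        return convergence_via_cached_cse_from(radii, AMPLITUDES, CORE_RADII, xp=jnp)

    result = np.asarray(convergence_jit(jnp.asarray(grid_radii)))
else:
    result = reference
```

JAX defaults to float32 unless 64-bit mode is enabled, so comparisons against the NumPy reference should use a tolerance rather than exact equality.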

Original Prompt

Port the CSE (Cored Steep Ellipsoid) module in PyAutoGalaxy to support JAX.

Make autogalaxy/profiles/mass/abstract/cse.py JAX-compatible by threading the xp=np parameter through all methods, mirroring how the MGE module (mge.py) already supports both NumPy and JAX backends. Replace np.* calls with xp.*, add jnp.linalg.lstsq branch for the decomposition solver, and thread xp=xp through all callers (NFW and other dark matter profiles).

Metadata

Assignees: No one assigned
Labels: No labels
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
