Skip to content

feat: PointSolver(use_jax=True) + autolens.jax.register_tracer_classes#538

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/simulator-use-jax
May 24, 2026
Merged

feat: PointSolver(use_jax=True) + autolens.jax.register_tracer_classes#538
Jammy2211 merged 1 commit into
mainfrom
feature/simulator-use-jax

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

First PR of Phase 2 of z_features/jax_user_intro.md. Adds use_jax=True constructor flag to PointSolver, which makes solver.solve(tracer, source_plane_coordinate) JAX-native by default — xp falls back to jnp and remove_infinities defaults to False for static-shape JIT compatibility.

Also adds a new autolens.jax module exposing register_tracer_classes(tracer), the one-line setup users call once before wrapping the solver in their own @jax.jit. The walker mirrors the auto-registration that AnalysisImaging._register_fit_imaging_pytrees() already performs on the modeling path.

Authoritative design doc: admin_jammy/notes/jax_interface.md (admin_jammy main f381393).

API Changes

  • Added: PointSolver(use_jax=False) constructor flag, forwarded by for_grid() and for_limits_and_scale().
  • Added: PointSolver._xp property (returns jnp if use_jax, else np).
  • Changed signature: PointSolver.solve(xp=None, remove_infinities=None) — both default to None, meaning "auto from self.use_jax". Existing explicit values still honoured.
  • Added: autolens.jax module + register_tracer_classes(tracer) public function. Users call once before @jax.jit to register Tracer + Galaxy + concrete profile classes as JAX pytrees.

See full details below.

Test Plan

  • 5 new unit tests in test_autolens/point/triangles/test_use_jax.py — constructor wiring, defaults, tree_flatten/tree_unflatten roundtrip with use_jax, default remove_infinities on NumPy path.
  • Full PyAutoLens test suite passes (317/317, no regressions).
  • Cross-xp numerical parity script at autolens_workspace_test/scripts/point_source/solver_use_jax_parity.py confirms NumPy and JAX positions agree to atol=1e-3, including inside @jax.jit. Script ships in a follow-up workspace_test PR since it depends on this library API landing first.
Full API Changes (for automation & release notes)

Added

  • autolens.jax (new module).
  • autolens.jax.register_tracer_classes(tracer) — register every concrete class reachable from tracer (Galaxy + light/mass/point profiles + Tracer itself) as a JAX pytree. Idempotent. No-op if JAX not installed.
  • PointSolver(use_jax=False) constructor parameter (defined on AbstractSolver).
  • PointSolver._xp property — returns jax.numpy if self.use_jax, else numpy.
  • PointSolver.for_grid(use_jax=False) and PointSolver.for_limits_and_scale(use_jax=False) — forwarding.

Changed signature

  • PointSolver.solve(tracer, source_plane_coordinate, xp=None, plane_redshift=None, remove_infinities=None)xp and remove_infinities now default to None. Behaviour:
    • xp=None → uses self._xp (i.e., jnp if constructed with use_jax=True, else np).
    • remove_infinities=None → defaults to not self.use_jax (strips inf rows on NumPy path; keeps them on JAX path for static-shape JIT compatibility).
    • Explicit xp=... and remove_infinities=... are still honoured.

Changed behaviour

  • AbstractSolver.tree_flatten / tree_unflatten now include use_jax in aux_data so the flag round-trips across JAX pytree boundaries.

Migration

  • No breaking changes. Existing call sites (solver.solve(tracer, coord), for_grid(...)) continue to work unchanged — use_jax defaults to False.
  • New canonical pattern for JIT'd point-source solves:
    import jax
    import jax.numpy as jnp
    from autolens.jax import register_tracer_classes
    
    register_tracer_classes(tracer)  # ONE-TIME: register Tracer + Galaxy + profiles
    
    solver = al.PointSolver.for_grid(
        grid=grid, pixel_scale_precision=0.001, use_jax=True
    )
    
    @jax.jit
    def jitted_solve(tracer, coord):
        return solver.solve(tracer=tracer, source_plane_coordinate=coord).array
    
    positions = jitted_solve(tracer, jnp.asarray(source_coord))
    Cluster-style manual ceremony (the existing cluster/simulator.py pattern of register_instance_pytree(Tracer, ...) + register_model(af.Collection(...))) becomes unnecessary once Phase 2 PRs 2-4 land.

Out of scope (separate PRs)

  • SimulatorImaging.use_jax=True + xp-conversion of simulator internals (next PR).
  • SimulatorInterferometer.use_jax=True + via_image_from signature fix.
  • xp=np + jnp-backed-grid mismatch ValueError in AbstractMaker.__init__ (PyAutoArray-only).

🤖 Generated with Claude Code

Adds use_jax constructor flag to PointSolver (via AbstractSolver). When True,
solve() defaults xp=jnp and remove_infinities=False, honouring the JAX
static-shape contract for jit-compatible point-source solves.

Also adds autolens.jax module with register_tracer_classes(tracer), the
explicit pytree-registration entry point users call once before wrapping
solver.solve in their own @jax.jit. The walker mirrors the existing
AnalysisImaging._register_fit_imaging_pytrees pattern for the Tracer +
Galaxy + profile classes a concrete tracer carries.

First PR of Phase 2 in z_features/jax_user_intro.md. Follow-ups:
SimulatorImaging.use_jax=True, SimulatorInterferometer.use_jax=True, and the
xp=np + jnp-grid mismatch ValueError land in subsequent PRs.

Design doc: admin_jammy/notes/jax_interface.md
Issue: PyAutoArray#334

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@Jammy2211 Jammy2211 added the pending-release PR queued for the next release build label May 24, 2026
@Jammy2211 Jammy2211 merged commit a90f888 into main May 24, 2026
6 checks passed
@Jammy2211 Jammy2211 deleted the feature/simulator-use-jax branch May 24, 2026 15:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant