Skip to content

refactor(point): make PointSolver stateless w.r.t. xp#469

Merged
Jammy2211 merged 1 commit into
mainfrom
feature/point-solver-auto-jax
Apr 21, 2026
Merged

refactor(point): make PointSolver stateless w.r.t. xp#469
Jammy2211 merged 1 commit into
mainfrom
feature/point-solver-auto-jax

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

Removes the long-standing double-knob on the point-source JAX path: previously a user had to both pass xp=jnp to PointSolver.for_grid and use_jax=True to AnalysisPoint. Now the solver is stateless w.r.t. the array module — one solver construction works for both NumPy and JAX, and AnalysisPoint(use_jax=True) alone suffices.

Closes #466.

API Changes

PointSolver.for_grid and PointSolver.for_limits_and_scale no longer accept xp=. The array module is now passed per-call to solver.solve(xp=...) / solver.solve_triangles(xp=...). Most user code never calls .solve() directly — AnalysisPoint threads xp through internally based on use_jax=.

No backwards-compatibility shim. Workspace scripts that still pass xp=jnp to PointSolver.for_grid will raise TypeError; see the companion workspace PR for the full migration.

See full details below.

Test Plan

  • pytest test_autolens/point/ passes (48 tests, NumPy paths only — per the JAX-free unit test policy)
  • autolens_workspace_test/scripts/jax_likelihood_functions/point_source/{image_plane,point,source_plane}.py run end-to-end and NumPy log-likelihood matches the JIT-compiled log-likelihood (already verified by upstream session)
Full API Changes (for automation & release notes)

Removed

  • PointSolver.for_grid(..., xp=...)xp keyword removed
  • PointSolver.for_limits_and_scale(..., xp=...)xp keyword removed
  • AbstractSolver(initial_triangles=...) — constructor no longer takes pre-built triangles
  • AbstractSolver.use_jax attribute
  • AbstractSolver._xp property

Added

  • AbstractSolver(y_min, y_max, x_min, x_max, ...) — constructor now takes geometry primitives
  • AbstractSolver._initial_triangles(xp) — builds triangles for the requested array module on demand
  • PointSolver.solve(..., xp=np, ...) — new keyword, defaults to NumPy
  • ShapeSolver.find_magnification(..., xp=np, ...) — new keyword, defaults to NumPy
  • MockPointSolver.solve(..., xp=np, ...) — mirrors real solver signature

Changed Signature

  • AbstractSolver.solve_triangles, AbstractSolver.steps, AbstractSolver._plane_grid, AbstractSolver._plane_triangles, AbstractSolver._filter_low_magnification — now take a required xp argument
  • PointSolver / AbstractSolver pytree: tree_flatten / tree_unflatten now serialize 8 geometry primitives instead of a pre-built initial_triangles object

Migration

  • Before:
    solver = al.PointSolver.for_grid(grid=grid, pixel_scale_precision=0.001, xp=jnp)
    analysis = al.AnalysisPoint(dataset=dataset, solver=solver, use_jax=True)
  • After:
    solver = al.PointSolver.for_grid(grid=grid, pixel_scale_precision=0.001)
    analysis = al.AnalysisPoint(dataset=dataset, solver=solver, use_jax=True)
  • The direct solver.solve(tracer, source_plane_coordinate) call is unchanged for NumPy usage. For JAX, pass xp=jnp explicitly: solver.solve(tracer, coord, xp=jnp).

🤖 Generated with Claude Code

PointSolver no longer locks in numpy or JAX at construction time. The array
module is chosen per-call by passing xp= to solve() / solve_triangles(), so
a single solver works with both AnalysisPoint(use_jax=False) and
AnalysisPoint(use_jax=True) without needing a second JAX-configured solver.

- Drop xp= from PointSolver.for_grid and for_limits_and_scale.
- Drop self.initial_triangles / self.use_jax / self._xp from AbstractSolver;
  store the image-plane extent primitives instead and rebuild triangles on
  each solve via _initial_triangles(xp).
- Thread xp through solve / solve_triangles / steps / _plane_grid /
  _plane_triangles / _filter_low_magnification.
- Update pytree flatten/unflatten to serialize the geometry primitives.
- Update internal call sites in autolens.analysis.result and
  autolens.point.fit.positions.image.abstract to pass xp at solve-time.
- MockPointSolver.solve() gains the xp= keyword to match the new signature.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit 6c3951d into main Apr 21, 2026
5 checks passed
@Jammy2211 Jammy2211 deleted the feature/point-solver-auto-jax branch April 21, 2026 13:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending-release PR queued for the next release build

Projects

None yet

Development

Successfully merging this pull request may close these issues.

refactor: AnalysisPoint auto-configures PointSolver when use_jax=True

1 participant