refactor(point): make PointSolver stateless w.r.t. xp#469
Merged
Conversation
PointSolver no longer locks in numpy or JAX at construction time. The array module is chosen per-call by passing xp= to solve() / solve_triangles(), so a single solver works with both AnalysisPoint(use_jax=False) and AnalysisPoint(use_jax=True) without needing a second JAX-configured solver. - Drop xp= from PointSolver.for_grid and for_limits_and_scale. - Drop self.initial_triangles / self.use_jax / self._xp from AbstractSolver; store the image-plane extent primitives instead and rebuild triangles on each solve via _initial_triangles(xp). - Thread xp through solve / solve_triangles / steps / _plane_grid / _plane_triangles / _filter_low_magnification. - Update pytree flatten/unflatten to serialize the geometry primitives. - Update internal call sites in autolens.analysis.result and autolens.point.fit.positions.image.abstract to pass xp at solve-time. - MockPointSolver.solve() gains the xp= keyword to match the new signature. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This was referenced Apr 21, 2026
Collaborator
Author
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Removes the long-standing double-knob on the point-source JAX path: previously a user had to both pass
xp=jnptoPointSolver.for_gridanduse_jax=TruetoAnalysisPoint. Now the solver is stateless w.r.t. the array module — one solver construction works for both NumPy and JAX, andAnalysisPoint(use_jax=True)alone suffices.Closes #466.
API Changes
PointSolver.for_gridandPointSolver.for_limits_and_scaleno longer acceptxp=. The array module is now passed per-call tosolver.solve(xp=...)/solver.solve_triangles(xp=...). Most user code never calls.solve()directly —AnalysisPointthreadsxpthrough internally based onuse_jax=.No backwards-compatibility shim. Workspace scripts that still pass
xp=jnptoPointSolver.for_gridwill raiseTypeError; see the companion workspace PR for the full migration.See full details below.
Test Plan
pytest test_autolens/point/passes (48 tests, NumPy paths only — per the JAX-free unit test policy)autolens_workspace_test/scripts/jax_likelihood_functions/point_source/{image_plane,point,source_plane}.pyrun end-to-end and NumPy log-likelihood matches the JIT-compiled log-likelihood (already verified by upstream session)Full API Changes (for automation & release notes)
Removed
PointSolver.for_grid(..., xp=...)—xpkeyword removedPointSolver.for_limits_and_scale(..., xp=...)—xpkeyword removedAbstractSolver(initial_triangles=...)— constructor no longer takes pre-built trianglesAbstractSolver.use_jaxattributeAbstractSolver._xppropertyAdded
AbstractSolver(y_min, y_max, x_min, x_max, ...)— constructor now takes geometry primitivesAbstractSolver._initial_triangles(xp)— builds triangles for the requested array module on demandPointSolver.solve(..., xp=np, ...)— new keyword, defaults to NumPyShapeSolver.find_magnification(..., xp=np, ...)— new keyword, defaults to NumPyMockPointSolver.solve(..., xp=np, ...)— mirrors real solver signatureChanged Signature
AbstractSolver.solve_triangles,AbstractSolver.steps,AbstractSolver._plane_grid,AbstractSolver._plane_triangles,AbstractSolver._filter_low_magnification— now take a requiredxpargumentPointSolver/AbstractSolverpytree:tree_flatten/tree_unflattennow serialize 8 geometry primitives instead of a pre-builtinitial_trianglesobjectMigration
solver.solve(tracer, source_plane_coordinate)call is unchanged for NumPy usage. For JAX, passxp=jnpexplicitly:solver.solve(tracer, coord, xp=jnp).🤖 Generated with Claude Code