fix(inversion): make AbstractMeshGeometry picklable (xp module → _use_jax bool + property)#321
Merged
Merged
Conversation
Carve-out from PyAutoFit #1279 Q2 (Phase 4 of the JAX visualization roadmap). A picklability spike found FitImaging cannot be pickled today because AbstractMeshGeometry.__init__ stores `self._xp = xp` — the literal numpy or jax.numpy module — and Python's pickle cannot serialise module references. This blocked sending a FitImaging over an mp.Process+Queue or ProcessPoolExecutor boundary, which is the production target for Phase 4 subprocess visualization. Replace the module-attribute pattern with `self._use_jax: bool` + `_xp` as a property derived from that flag. Same pattern already used in Analysis._xp (PyAutoFit) and AbstractMaker._xp (PyAutoArray decorators per CLAUDE.md). All existing `self._xp` reads continue to work transparently via the property. End-to-end verified: a populated FitImaging round-trips through pickle.dumps/loads with log_likelihood delta=0.00e+00 on both numpy and JAX backends. Pickle size ~4.6 MB for a Rectangular-adaptive- density pixelization fit. Closes #320. Carve-out from #1279.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Phase 4 subprocess visualization (PyAutoFit #1279) requires sending
FitImaginginstances over an IPC boundary. A spike against the currentFitImagingshowed it's unpicklable:AbstractMeshGeometry.__init__storesself._xp = xp(the literalnumpyorjax.numpymodule) and Python can't pickle modules. This was the only pickle barrier — the rest of the fit graph (mappers, linear operators, computed matrices, tracer, dataset, mask) serialises fine.Fix: replace the module-attribute pattern with the
_use_jax: bool+_xpproperty pattern already used inAnalysis._xp(PyAutoFit) andAbstractMaker._xp(PyAutoArray decorators, perCLAUDE.md). Eliminates the pickle barrier at the API level instead of working around it with__getstate__/__setstate__hooks.API Changes
AbstractMeshGeometry.__init__(xp=np)continues to accept the samexp=kwarg — no caller changes needed._xpas a module attribute; it holds_use_jax: booland exposes_xpas a@propertythat returnsnumpyorjax.numpyon demand.self._xpreads continue to work transparently (property has the same interface as attribute access).See full details below.
Test Plan
test_autoarray/inversion/pixelization/mesh_geometry/test_picklability.py— 5 new tests covering numpy + JAX backends acrossMeshGeometryRectangularandMeshGeometryDelaunay, plus a__init__-state invariant test.test_autoarray/inversion/pixelization/mesh_geometry/tests still pass.test_autoarray/inversion/suite passes (171/171).FitImaging(Sersic lens + Rectangular-adaptive-density pixelization source) round-trips throughpickle.dumps/loadswithlog_likelihoodΔ=0.00e+00 on both numpy and JAX backends. Pickle size 4637.7 KB.Full API Changes (for automation & release notes)
Changed Behaviour
autoarray.inversion.mesh.mesh_geometry.abstract.AbstractMeshGeometry:_xpremoved; replaced by_use_jax: bool(set from thexp=kwarg asxp is not np)._xpis now a@propertythat returnsnumpyif_use_jax is Falseelsejax.numpy. JAX is imported lazily on access.Migration
xp=kwarg on__init__is unchanged. Any reader ofinstance._xpcontinues to work (property returns the same module the old attribute held).instance._xp = ...(none currently exist in the codebase) would now fail at runtime. Confirmed via grep: zero writes outside__init__.Why this matters
FitImagingnow pickles cleanly, unblocking PyAutoFit #1279 (Phase 4 subprocess visualization). The single-class fix on the abstract parent covers bothMeshGeometryRectangularandMeshGeometryDelaunay.🤖 Generated with Claude Code