Skip to content

Commit aee5e94

Browse files
Jammy2211Jammy2211
authored andcommitted
fix(inversion): make AbstractMeshGeometry picklable
Carve-out from PyAutoFit #1279 Q2 (Phase 4 of the JAX visualization roadmap). A picklability spike found FitImaging cannot be pickled today because AbstractMeshGeometry.__init__ stores `self._xp = xp` — the literal numpy or jax.numpy module — and Python's pickle cannot serialise module references. This blocked sending a FitImaging over an mp.Process+Queue or ProcessPoolExecutor boundary, which is the production target for Phase 4 subprocess visualization. Replace the module-attribute pattern with `self._use_jax: bool` + `_xp` as a property derived from that flag. Same pattern already used in Analysis._xp (PyAutoFit) and AbstractMaker._xp (PyAutoArray decorators per CLAUDE.md). All existing `self._xp` reads continue to work transparently via the property. End-to-end verified: a populated FitImaging round-trips through pickle.dumps/loads with log_likelihood delta=0.00e+00 on both numpy and JAX backends. Pickle size ~4.6 MB for a Rectangular-adaptive- density pixelization fit. Closes #320. Carve-out from #1279.
1 parent 41465a3 commit aee5e94

2 files changed

Lines changed: 85 additions & 1 deletion

File tree

autoarray/inversion/mesh/mesh_geometry/abstract.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,11 @@ def __init__(
2020
# When non-None, rectangular geometry uses the spline-CDF helpers
2121
# instead of the linear-interp CDF (areas / edges transforms only).
2222
self.spline_deg = spline_deg
23-
self._xp = xp
23+
self._use_jax = xp is not np
24+
25+
@property
26+
def _xp(self):
27+
if self._use_jax:
28+
import jax.numpy as jnp
29+
return jnp
30+
return np
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Pickle round-trip tests for AbstractMeshGeometry subclasses.
2+
3+
Required by subprocess visualization (PyAutoFit #1279, Phase 4 of the JAX
4+
visualization roadmap). A populated FitImaging must be sendable over an
5+
mp.Process+Queue or ProcessPoolExecutor boundary — historically blocked
6+
by `self._xp = xp` (module attribute) on AbstractMeshGeometry, since
7+
Python's pickle cannot serialise module references.
8+
9+
Fix: `_xp` is now a property derived from a boolean `_use_jax` flag.
10+
This file is the regression test for that invariant.
11+
"""
12+
13+
import importlib.util
14+
import pickle
15+
16+
import numpy as np
17+
import pytest
18+
19+
from autoarray.inversion.mesh.mesh_geometry.rectangular import MeshGeometryRectangular
20+
from autoarray.inversion.mesh.mesh_geometry.delaunay import MeshGeometryDelaunay
21+
22+
23+
def _jax_installed() -> bool:
24+
return importlib.util.find_spec("jax") is not None
25+
26+
27+
@pytest.mark.parametrize("cls", [MeshGeometryRectangular, MeshGeometryDelaunay])
28+
def test_pickle_round_trip_numpy_backend(cls):
29+
"""A numpy-backed MeshGeometry instance must round-trip through pickle
30+
with `_xp` restored to the numpy module."""
31+
mg = cls.__new__(cls)
32+
mg._use_jax = False
33+
34+
restored = pickle.loads(pickle.dumps(mg))
35+
36+
assert restored._use_jax is False
37+
assert restored._xp is np
38+
39+
40+
@pytest.mark.skipif(not _jax_installed(), reason="jax not installed")
41+
@pytest.mark.parametrize("cls", [MeshGeometryRectangular, MeshGeometryDelaunay])
42+
def test_pickle_round_trip_jax_backend(cls):
43+
"""A JAX-backed MeshGeometry instance must round-trip through pickle
44+
with `_xp` restored to the jax.numpy module."""
45+
import jax.numpy as jnp
46+
47+
mg = cls.__new__(cls)
48+
mg._use_jax = True
49+
50+
restored = pickle.loads(pickle.dumps(mg))
51+
52+
assert restored._use_jax is True
53+
assert restored._xp is jnp
54+
55+
56+
def test_use_jax_inferred_from_xp_kwarg_in_init():
57+
"""The __init__ continues to accept an `xp=` kwarg but stores it as
58+
a boolean — modules never land on the instance."""
59+
import types
60+
61+
# Use __new__ + manual __init__ call with stubbed positional args to
62+
# avoid pulling in Mesh / Grid construction. The invariant under test
63+
# is post-__init__ state.
64+
mg = MeshGeometryRectangular.__new__(MeshGeometryRectangular)
65+
MeshGeometryRectangular.__init__(
66+
mg,
67+
mesh=None,
68+
mesh_grid=None,
69+
data_grid=None,
70+
xp=np,
71+
)
72+
assert mg._use_jax is False
73+
# No module attribute should exist on the instance.
74+
module_attrs = [
75+
name for name, val in vars(mg).items() if isinstance(val, types.ModuleType)
76+
]
77+
assert module_attrs == [], f"unexpected module attrs on instance: {module_attrs}"

0 commit comments

Comments
 (0)