|
| 1 | +"""Pickle round-trip tests for AbstractMeshGeometry subclasses. |
| 2 | +
|
| 3 | +Required by subprocess visualization (PyAutoFit #1279, Phase 4 of the JAX |
| 4 | +visualization roadmap). A populated FitImaging must be sendable over an |
| 5 | +mp.Process+Queue or ProcessPoolExecutor boundary — historically blocked |
| 6 | +by `self._xp = xp` (module attribute) on AbstractMeshGeometry, since |
| 7 | +Python's pickle cannot serialise module references. |
| 8 | +
|
| 9 | +Fix: `_xp` is now a property derived from a boolean `_use_jax` flag. |
| 10 | +This file is the regression test for that invariant. |
| 11 | +""" |
| 12 | + |
| 13 | +import importlib.util |
| 14 | +import pickle |
| 15 | + |
| 16 | +import numpy as np |
| 17 | +import pytest |
| 18 | + |
| 19 | +from autoarray.inversion.mesh.mesh_geometry.rectangular import MeshGeometryRectangular |
| 20 | +from autoarray.inversion.mesh.mesh_geometry.delaunay import MeshGeometryDelaunay |
| 21 | + |
| 22 | + |
| 23 | +def _jax_installed() -> bool: |
| 24 | + return importlib.util.find_spec("jax") is not None |
| 25 | + |
| 26 | + |
| 27 | +@pytest.mark.parametrize("cls", [MeshGeometryRectangular, MeshGeometryDelaunay]) |
| 28 | +def test_pickle_round_trip_numpy_backend(cls): |
| 29 | + """A numpy-backed MeshGeometry instance must round-trip through pickle |
| 30 | + with `_xp` restored to the numpy module.""" |
| 31 | + mg = cls.__new__(cls) |
| 32 | + mg._use_jax = False |
| 33 | + |
| 34 | + restored = pickle.loads(pickle.dumps(mg)) |
| 35 | + |
| 36 | + assert restored._use_jax is False |
| 37 | + assert restored._xp is np |
| 38 | + |
| 39 | + |
| 40 | +@pytest.mark.skipif(not _jax_installed(), reason="jax not installed") |
| 41 | +@pytest.mark.parametrize("cls", [MeshGeometryRectangular, MeshGeometryDelaunay]) |
| 42 | +def test_pickle_round_trip_jax_backend(cls): |
| 43 | + """A JAX-backed MeshGeometry instance must round-trip through pickle |
| 44 | + with `_xp` restored to the jax.numpy module.""" |
| 45 | + import jax.numpy as jnp |
| 46 | + |
| 47 | + mg = cls.__new__(cls) |
| 48 | + mg._use_jax = True |
| 49 | + |
| 50 | + restored = pickle.loads(pickle.dumps(mg)) |
| 51 | + |
| 52 | + assert restored._use_jax is True |
| 53 | + assert restored._xp is jnp |
| 54 | + |
| 55 | + |
| 56 | +def test_use_jax_inferred_from_xp_kwarg_in_init(): |
| 57 | + """The __init__ continues to accept an `xp=` kwarg but stores it as |
| 58 | + a boolean — modules never land on the instance.""" |
| 59 | + import types |
| 60 | + |
| 61 | + # Use __new__ + manual __init__ call with stubbed positional args to |
| 62 | + # avoid pulling in Mesh / Grid construction. The invariant under test |
| 63 | + # is post-__init__ state. |
| 64 | + mg = MeshGeometryRectangular.__new__(MeshGeometryRectangular) |
| 65 | + MeshGeometryRectangular.__init__( |
| 66 | + mg, |
| 67 | + mesh=None, |
| 68 | + mesh_grid=None, |
| 69 | + data_grid=None, |
| 70 | + xp=np, |
| 71 | + ) |
| 72 | + assert mg._use_jax is False |
| 73 | + # No module attribute should exist on the instance. |
| 74 | + module_attrs = [ |
| 75 | + name for name, val in vars(mg).items() if isinstance(val, types.ModuleType) |
| 76 | + ] |
| 77 | + assert module_attrs == [], f"unexpected module attrs on instance: {module_attrs}" |
0 commit comments