diff --git a/autolens/jax/__init__.py b/autolens/jax/__init__.py new file mode 100644 index 000000000..0a03f2a4e --- /dev/null +++ b/autolens/jax/__init__.py @@ -0,0 +1 @@ +from autolens.jax.registration import register_tracer_classes diff --git a/autolens/jax/registration.py b/autolens/jax/registration.py new file mode 100644 index 000000000..78573fa01 --- /dev/null +++ b/autolens/jax/registration.py @@ -0,0 +1,106 @@ +"""Lazy JAX pytree registration for `Tracer` and the classes it contains. + +When a user wraps a PyAutoLens call in their own ``@jax.jit`` and the call +receives a ``Tracer`` as a traced argument, every concrete class reachable +from the tracer must be registered as a JAX pytree node so JAX can flatten +and unflatten across the JIT boundary. + +This module is the counterpart of ``AnalysisImaging._register_fit_imaging_pytrees`` +for code paths that do not go through ``Analysis`` (point-source solving, +custom forward models, hand-built simulators). It is called automatically +by ``PointSolver(use_jax=True).solve(tracer, ...)`` on the first invocation +and by ``Simulator(use_jax=True).via_tracer_from(tracer, ...)`` in PyAutoLens +once Phase 2 ships the Simulator changes. + +Mirrors PyAutoFit's ``autofit/jax/pytrees.py`` layout. Idempotent: re-registration +of a class is a silent no-op. +""" +from typing import Iterable + + +def register_tracer_classes(tracer) -> bool: + """Register every concrete class reachable from ``tracer`` as a JAX pytree. + + Walks ``tracer.galaxies`` and registers ``Galaxy`` plus each light / + mass / point profile class encountered. Also registers ``Tracer`` + itself with ``no_flatten=("cosmology",)`` so the cosmology rides as + aux data across the JIT boundary (it is a per-fit constant). + + Returns ``True`` if registration ran (or was already complete), + ``False`` if JAX is not installed (in which case the call is a silent + no-op). + """ + try: + import jax # noqa: F401 + except ImportError: + return False + + from autoarray.abstract_ndarray import register_instance_pytree + from autolens.lens.tracer import Tracer + + register_instance_pytree(Tracer, no_flatten=("cosmology",)) + + for galaxy in tracer.galaxies: + _register_object_classes(galaxy) + + return True + + +def _register_object_classes(obj) -> None: + """Walk an object recursively and register each non-builtin class it carries. + + Used to register ``Galaxy`` plus every concrete profile class + (``Sersic``, ``Isothermal``, ``NFW``, ``Point``, ...) the galaxy holds. + Skips builtin types (numbers, strings, sequences) since those are not + user classes that JAX needs flatten/unflatten functions for. + """ + from autoarray.abstract_ndarray import register_instance_pytree + + cls = type(obj) + if _is_builtin(cls): + return + + register_instance_pytree(cls) + + for value in _iter_attribute_values(obj): + _register_object_classes(value) + + +def _iter_attribute_values(obj) -> Iterable: + """Yield each attribute value of ``obj`` worth recursing into. + + Walks ``obj.__dict__`` (if present) and recurses one level into list / + tuple / dict containers to reach profile objects held in collections. + """ + if not hasattr(obj, "__dict__"): + return + + for value in vars(obj).values(): + if isinstance(value, (list, tuple)): + for item in value: + yield item + elif isinstance(value, dict): + for item in value.values(): + yield item + else: + yield value + + +def _is_builtin(cls) -> bool: + """True for primitive / container / standard-library / numerical-backend + types that should not be registered as JAX pytrees. + + Catches the JAX tracer types (``DynamicJaxprTracer`` et al.) explicitly: + if a walker happens to recurse into JAX-traced state (because it ran + inside a ``jax.jit`` trace), registering ``type(tracer)`` would make every + subsequent tracer-flatten call route into our dict-based flatten, which + fails because tracers do not have ``__dict__``. + """ + if cls is type(None): + return True + module = cls.__module__ + if module == "builtins": + return True + if module.startswith(("numpy", "jax", "jaxlib")): + return True + return False diff --git a/autolens/point/solver/point_solver.py b/autolens/point/solver/point_solver.py index 526823680..30bb8e24c 100644 --- a/autolens/point/solver/point_solver.py +++ b/autolens/point/solver/point_solver.py @@ -37,9 +37,9 @@ def solve( self, tracer: Tracer, source_plane_coordinate: Tuple[float, float], - xp=np, + xp=None, plane_redshift: Optional[float] = None, - remove_infinities: bool = True, + remove_infinities: Optional[bool] = None, ) -> aa.Grid2DIrregular: """ Solve for the image plane coordinates that are traced to the source plane coordinate. @@ -66,14 +66,21 @@ def solve( but could be a coordinate in another plane is `plane_redshift` is input. xp The array module (``numpy`` or ``jax.numpy``) the solve runs in. ``AnalysisPoint`` - passes ``jax.numpy`` when ``use_jax=True`` is set on the analysis. + passes ``jax.numpy`` when ``use_jax=True`` is set on the analysis. When ``None`` (the + default), falls back to ``self._xp`` — which is ``jnp`` if the solver was constructed + with ``use_jax=True`` and ``np`` otherwise. Pass explicitly to override. plane_redshift The redshift of the plane coordinate, which for multi-plane systems may not be the source-plane. + remove_infinities + Whether to strip the ``inf`` sentinel rows from the output. When ``None`` (the default), + defaults to ``True`` on the NumPy path and ``False`` on the JAX path. The JAX path + keeps the padded static shape so the output crosses a ``jax.jit`` boundary cleanly; + strip the infinities outside the jit if needed. Returns ------- - A ``Grid2DIrregular`` of image-plane coordinates, always numpy-backed even when the - solver uses a JAX backend internally. + A ``Grid2DIrregular`` of image-plane coordinates. NumPy-backed on the default path, + ``jax.Array``-backed when ``use_jax=True`` (or ``xp=jnp``). Notes ----- @@ -87,6 +94,20 @@ def solve( inside a ``jax.jit`` trace, so a plain numpy-backed ``Grid2DIrregular`` is safe here even when the surrounding analysis uses ``xp=jnp``. """ + if xp is None: + xp = self._xp + + if remove_infinities is None: + remove_infinities = not self.use_jax + + # NOTE: pytree registration is the user's responsibility (call + # `autolens.jax.register_tracer_classes(tracer)` once before wrapping + # in @jax.jit). Auto-registering inside solve() doesn't help because + # JAX flattens function arguments at trace time — before entering + # this method — so registration must run before the first jitted + # call. See the `lens_calc.py` workspace guide for the canonical + # JIT-it-yourself pattern. + if os.environ.get("PYAUTO_SMALL_DATASETS") == "1": return aa.Grid2DIrregular(values=[(1.0, 0.0), (0.0, 1.0)]) diff --git a/autolens/point/solver/shape_solver.py b/autolens/point/solver/shape_solver.py index b006cc7db..dc58886cd 100644 --- a/autolens/point/solver/shape_solver.py +++ b/autolens/point/solver/shape_solver.py @@ -42,6 +42,7 @@ def __init__( pixel_scale_precision: float, magnification_threshold=0.1, neighbor_degree: int = 1, + use_jax: bool = False, ): """ Determine the image plane coordinates that are traced to be a source plane coordinate. @@ -67,6 +68,12 @@ def __init__( neighbor_degree The number of times recursively add neighbors for the triangles that contain the source plane coordinate. + use_jax + If ``True``, ``.solve()`` defaults to ``xp=jnp`` and ``remove_infinities=False`` + (the JAX-static-shape contract), and registers ``Tracer`` plus the concrete + galaxy / profile classes it carries as JAX pytrees on the first call. The + user wraps the call in their own ``@jax.jit`` — see the ``lens_calc.py`` + workspace guide for the canonical pattern. """ self.y_min = y_min self.y_max = y_max @@ -76,6 +83,18 @@ def __init__( self.pixel_scale_precision = pixel_scale_precision self.magnification_threshold = magnification_threshold self.neighbor_degree = neighbor_degree + self.use_jax = use_jax + + @property + def _xp(self): + """The array module the solver runs against by default. ``jnp`` when + ``use_jax=True``, ``np`` otherwise. ``.solve()`` falls back to this when + the caller does not pass ``xp=`` explicitly.""" + if self.use_jax: + import jax.numpy as jnp + + return jnp + return np def _initial_triangles(self, xp): """ @@ -107,6 +126,7 @@ def for_grid( pixel_scale_precision: float, magnification_threshold=0.1, neighbor_degree: int = 1, + use_jax: bool = False, ): """ Create a solver for a given grid. @@ -123,6 +143,8 @@ def for_grid( The threshold for the magnification under which multiple images are filtered. neighbor_degree The number of times recursively add neighbors for the triangles that contain + use_jax + Forwarded to the constructor; see ``__init__``. Returns ------- @@ -142,6 +164,7 @@ def for_grid( pixel_scale_precision=pixel_scale_precision, magnification_threshold=magnification_threshold, neighbor_degree=neighbor_degree, + use_jax=use_jax, ) @classmethod @@ -155,6 +178,7 @@ def for_limits_and_scale( pixel_scale_precision: float = 0.001, magnification_threshold=0.1, neighbor_degree: int = 1, + use_jax: bool = False, ): """ Create a solver for an explicit image-plane extent. @@ -171,6 +195,8 @@ def for_limits_and_scale( The threshold for the magnification under which multiple images are filtered. neighbor_degree The number of times recursively add neighbors for the triangles that contain + use_jax + Forwarded to the constructor; see ``__init__``. Returns ------- @@ -185,6 +211,7 @@ def for_limits_and_scale( pixel_scale_precision=pixel_scale_precision, magnification_threshold=magnification_threshold, neighbor_degree=neighbor_degree, + use_jax=use_jax, ) @property @@ -384,6 +411,7 @@ def tree_flatten(self): self.pixel_scale_precision, self.magnification_threshold, self.neighbor_degree, + self.use_jax, ) @classmethod @@ -397,6 +425,7 @@ def tree_unflatten(cls, aux_data, children): pixel_scale_precision, magnification_threshold, neighbor_degree, + use_jax, ) = aux_data return cls( y_min=y_min, @@ -407,6 +436,7 @@ def tree_unflatten(cls, aux_data, children): pixel_scale_precision=pixel_scale_precision, magnification_threshold=magnification_threshold, neighbor_degree=neighbor_degree, + use_jax=use_jax, ) diff --git a/test_autolens/point/triangles/test_use_jax.py b/test_autolens/point/triangles/test_use_jax.py new file mode 100644 index 000000000..2aa2e60f8 --- /dev/null +++ b/test_autolens/point/triangles/test_use_jax.py @@ -0,0 +1,73 @@ +"""Unit tests for the ``PointSolver(use_jax=True)`` constructor wiring. + +Per the PyAutoArray dependency-graph rule, library unit tests stay NumPy-only — +cross-xp numerical parity for the actual JAX execution path lives in +``autolens_workspace_test/scripts/point_source/solver_use_jax_parity.py``. +These tests cover only the constructor wiring, the default fallbacks +(``xp=None``, ``remove_infinities=None``), and the pytree-tree-flatten roundtrip. +""" +import numpy as np + +import autolens as al + + +def test_use_jax_defaults_false(): + solver = al.PointSolver.for_grid( + grid=al.Grid2D.uniform(shape_native=(10, 10), pixel_scales=1.0), + pixel_scale_precision=0.01, + ) + assert solver.use_jax is False + assert solver._xp is np + + +def test_use_jax_flag_threads_through_for_grid(): + solver = al.PointSolver.for_grid( + grid=al.Grid2D.uniform(shape_native=(10, 10), pixel_scales=1.0), + pixel_scale_precision=0.01, + use_jax=True, + ) + assert solver.use_jax is True + + +def test_use_jax_flag_threads_through_for_limits_and_scale(): + solver = al.PointSolver.for_limits_and_scale( + y_min=-1.0, + y_max=1.0, + x_min=-1.0, + x_max=1.0, + scale=0.1, + pixel_scale_precision=0.01, + use_jax=True, + ) + assert solver.use_jax is True + + +def test_tree_flatten_roundtrips_use_jax(): + solver = al.PointSolver.for_grid( + grid=al.Grid2D.uniform(shape_native=(10, 10), pixel_scales=1.0), + pixel_scale_precision=0.01, + use_jax=True, + ) + _children, aux = solver.tree_flatten() + rebuilt = al.PointSolver.tree_unflatten(aux, _children) + assert rebuilt.use_jax is True + assert rebuilt.y_min == solver.y_min + assert rebuilt.scale == solver.scale + + +def test_solve_default_remove_infinities_numpy(grid): + """On the NumPy path the solve() default still removes infinities (back-compat).""" + solver = al.PointSolver.for_grid( + grid=grid, pixel_scale_precision=0.5, magnification_threshold=1e-8 + ) + tracer = al.Tracer( + galaxies=[ + al.Galaxy( + redshift=0.5, + mass=al.mp.Isothermal(centre=(0.0, 0.0), einstein_radius=1.0), + ) + ] + ) + result = solver.solve(tracer, source_plane_coordinate=(0.0, 0.0)) + # Removing infinities means no row should contain inf on the default numpy path. + assert not np.isinf(np.asarray(result.array)).any()