Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions autolens/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from autolens.jax.registration import register_tracer_classes
106 changes: 106 additions & 0 deletions autolens/jax/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""Lazy JAX pytree registration for `Tracer` and the classes it contains.

When a user wraps a PyAutoLens call in their own ``@jax.jit`` and the call
receives a ``Tracer`` as a traced argument, every concrete class reachable
from the tracer must be registered as a JAX pytree node so JAX can flatten
and unflatten across the JIT boundary.

This module is the counterpart of ``AnalysisImaging._register_fit_imaging_pytrees``
for code paths that do not go through ``Analysis`` (point-source solving,
custom forward models, hand-built simulators). It is called automatically
by ``PointSolver(use_jax=True).solve(tracer, ...)`` on the first invocation
and by ``Simulator(use_jax=True).via_tracer_from(tracer, ...)`` in PyAutoLens
once Phase 2 ships the Simulator changes.

Mirrors PyAutoFit's ``autofit/jax/pytrees.py`` layout. Idempotent: re-registration
of a class is a silent no-op.
"""
from typing import Iterable


def register_tracer_classes(tracer) -> bool:
"""Register every concrete class reachable from ``tracer`` as a JAX pytree.

Walks ``tracer.galaxies`` and registers ``Galaxy`` plus each light /
mass / point profile class encountered. Also registers ``Tracer``
itself with ``no_flatten=("cosmology",)`` so the cosmology rides as
aux data across the JIT boundary (it is a per-fit constant).

Returns ``True`` if registration ran (or was already complete),
``False`` if JAX is not installed (in which case the call is a silent
no-op).
"""
try:
import jax # noqa: F401
except ImportError:
return False

from autoarray.abstract_ndarray import register_instance_pytree
from autolens.lens.tracer import Tracer

register_instance_pytree(Tracer, no_flatten=("cosmology",))

for galaxy in tracer.galaxies:
_register_object_classes(galaxy)

return True


def _register_object_classes(obj) -> None:
"""Walk an object recursively and register each non-builtin class it carries.

Used to register ``Galaxy`` plus every concrete profile class
(``Sersic``, ``Isothermal``, ``NFW``, ``Point``, ...) the galaxy holds.
Skips builtin types (numbers, strings, sequences) since those are not
user classes that JAX needs flatten/unflatten functions for.
"""
from autoarray.abstract_ndarray import register_instance_pytree

cls = type(obj)
if _is_builtin(cls):
return

register_instance_pytree(cls)

for value in _iter_attribute_values(obj):
_register_object_classes(value)


def _iter_attribute_values(obj) -> Iterable:
"""Yield each attribute value of ``obj`` worth recursing into.

Walks ``obj.__dict__`` (if present) and recurses one level into list /
tuple / dict containers to reach profile objects held in collections.
"""
if not hasattr(obj, "__dict__"):
return

for value in vars(obj).values():
if isinstance(value, (list, tuple)):
for item in value:
yield item
elif isinstance(value, dict):
for item in value.values():
yield item
else:
yield value


def _is_builtin(cls) -> bool:
"""True for primitive / container / standard-library / numerical-backend
types that should not be registered as JAX pytrees.

Catches the JAX tracer types (``DynamicJaxprTracer`` et al.) explicitly:
if a walker happens to recurse into JAX-traced state (because it ran
inside a ``jax.jit`` trace), registering ``type(tracer)`` would make every
subsequent tracer-flatten call route into our dict-based flatten, which
fails because tracers do not have ``__dict__``.
"""
if cls is type(None):
return True
module = cls.__module__
if module == "builtins":
return True
if module.startswith(("numpy", "jax", "jaxlib")):
return True
return False
31 changes: 26 additions & 5 deletions autolens/point/solver/point_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ def solve(
self,
tracer: Tracer,
source_plane_coordinate: Tuple[float, float],
xp=np,
xp=None,
plane_redshift: Optional[float] = None,
remove_infinities: bool = True,
remove_infinities: Optional[bool] = None,
) -> aa.Grid2DIrregular:
"""
Solve for the image plane coordinates that are traced to the source plane coordinate.
Expand All @@ -66,14 +66,21 @@ def solve(
but could be a coordinate in another plane is `plane_redshift` is input.
xp
The array module (``numpy`` or ``jax.numpy``) the solve runs in. ``AnalysisPoint``
passes ``jax.numpy`` when ``use_jax=True`` is set on the analysis.
passes ``jax.numpy`` when ``use_jax=True`` is set on the analysis. When ``None`` (the
default), falls back to ``self._xp`` — which is ``jnp`` if the solver was constructed
with ``use_jax=True`` and ``np`` otherwise. Pass explicitly to override.
plane_redshift
The redshift of the plane coordinate, which for multi-plane systems may not be the source-plane.
remove_infinities
Whether to strip the ``inf`` sentinel rows from the output. When ``None`` (the default),
defaults to ``True`` on the NumPy path and ``False`` on the JAX path. The JAX path
keeps the padded static shape so the output crosses a ``jax.jit`` boundary cleanly;
strip the infinities outside the jit if needed.

Returns
-------
A ``Grid2DIrregular`` of image-plane coordinates, always numpy-backed even when the
solver uses a JAX backend internally.
A ``Grid2DIrregular`` of image-plane coordinates. NumPy-backed on the default path,
``jax.Array``-backed when ``use_jax=True`` (or ``xp=jnp``).

Notes
-----
Expand All @@ -87,6 +94,20 @@ def solve(
inside a ``jax.jit`` trace, so a plain numpy-backed ``Grid2DIrregular`` is safe
here even when the surrounding analysis uses ``xp=jnp``.
"""
if xp is None:
xp = self._xp

if remove_infinities is None:
remove_infinities = not self.use_jax

# NOTE: pytree registration is the user's responsibility (call
# `autolens.jax.register_tracer_classes(tracer)` once before wrapping
# in @jax.jit). Auto-registering inside solve() doesn't help because
# JAX flattens function arguments at trace time — before entering
# this method — so registration must run before the first jitted
# call. See the `lens_calc.py` workspace guide for the canonical
# JIT-it-yourself pattern.

if os.environ.get("PYAUTO_SMALL_DATASETS") == "1":
return aa.Grid2DIrregular(values=[(1.0, 0.0), (0.0, 1.0)])

Expand Down
30 changes: 30 additions & 0 deletions autolens/point/solver/shape_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
pixel_scale_precision: float,
magnification_threshold=0.1,
neighbor_degree: int = 1,
use_jax: bool = False,
):
"""
Determine the image plane coordinates that are traced to be a source plane coordinate.
Expand All @@ -67,6 +68,12 @@ def __init__(
neighbor_degree
The number of times recursively add neighbors for the triangles that contain
the source plane coordinate.
use_jax
If ``True``, ``.solve()`` defaults to ``xp=jnp`` and ``remove_infinities=False``
(the JAX-static-shape contract), and registers ``Tracer`` plus the concrete
galaxy / profile classes it carries as JAX pytrees on the first call. The
user wraps the call in their own ``@jax.jit`` — see the ``lens_calc.py``
workspace guide for the canonical pattern.
"""
self.y_min = y_min
self.y_max = y_max
Expand All @@ -76,6 +83,18 @@ def __init__(
self.pixel_scale_precision = pixel_scale_precision
self.magnification_threshold = magnification_threshold
self.neighbor_degree = neighbor_degree
self.use_jax = use_jax

@property
def _xp(self):
"""The array module the solver runs against by default. ``jnp`` when
``use_jax=True``, ``np`` otherwise. ``.solve()`` falls back to this when
the caller does not pass ``xp=`` explicitly."""
if self.use_jax:
import jax.numpy as jnp

return jnp
return np

def _initial_triangles(self, xp):
"""
Expand Down Expand Up @@ -107,6 +126,7 @@ def for_grid(
pixel_scale_precision: float,
magnification_threshold=0.1,
neighbor_degree: int = 1,
use_jax: bool = False,
):
"""
Create a solver for a given grid.
Expand All @@ -123,6 +143,8 @@ def for_grid(
The threshold for the magnification under which multiple images are filtered.
neighbor_degree
The number of times recursively add neighbors for the triangles that contain
use_jax
Forwarded to the constructor; see ``__init__``.

Returns
-------
Expand All @@ -142,6 +164,7 @@ def for_grid(
pixel_scale_precision=pixel_scale_precision,
magnification_threshold=magnification_threshold,
neighbor_degree=neighbor_degree,
use_jax=use_jax,
)

@classmethod
Expand All @@ -155,6 +178,7 @@ def for_limits_and_scale(
pixel_scale_precision: float = 0.001,
magnification_threshold=0.1,
neighbor_degree: int = 1,
use_jax: bool = False,
):
"""
Create a solver for an explicit image-plane extent.
Expand All @@ -171,6 +195,8 @@ def for_limits_and_scale(
The threshold for the magnification under which multiple images are filtered.
neighbor_degree
The number of times recursively add neighbors for the triangles that contain
use_jax
Forwarded to the constructor; see ``__init__``.

Returns
-------
Expand All @@ -185,6 +211,7 @@ def for_limits_and_scale(
pixel_scale_precision=pixel_scale_precision,
magnification_threshold=magnification_threshold,
neighbor_degree=neighbor_degree,
use_jax=use_jax,
)

@property
Expand Down Expand Up @@ -384,6 +411,7 @@ def tree_flatten(self):
self.pixel_scale_precision,
self.magnification_threshold,
self.neighbor_degree,
self.use_jax,
)

@classmethod
Expand All @@ -397,6 +425,7 @@ def tree_unflatten(cls, aux_data, children):
pixel_scale_precision,
magnification_threshold,
neighbor_degree,
use_jax,
) = aux_data
return cls(
y_min=y_min,
Expand All @@ -407,6 +436,7 @@ def tree_unflatten(cls, aux_data, children):
pixel_scale_precision=pixel_scale_precision,
magnification_threshold=magnification_threshold,
neighbor_degree=neighbor_degree,
use_jax=use_jax,
)


Expand Down
73 changes: 73 additions & 0 deletions test_autolens/point/triangles/test_use_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Unit tests for the ``PointSolver(use_jax=True)`` constructor wiring.

Per the PyAutoArray dependency-graph rule, library unit tests stay NumPy-only —
cross-xp numerical parity for the actual JAX execution path lives in
``autolens_workspace_test/scripts/point_source/solver_use_jax_parity.py``.
These tests cover only the constructor wiring, the default fallbacks
(``xp=None``, ``remove_infinities=None``), and the pytree-tree-flatten roundtrip.
"""
import numpy as np

import autolens as al


def test_use_jax_defaults_false():
solver = al.PointSolver.for_grid(
grid=al.Grid2D.uniform(shape_native=(10, 10), pixel_scales=1.0),
pixel_scale_precision=0.01,
)
assert solver.use_jax is False
assert solver._xp is np


def test_use_jax_flag_threads_through_for_grid():
solver = al.PointSolver.for_grid(
grid=al.Grid2D.uniform(shape_native=(10, 10), pixel_scales=1.0),
pixel_scale_precision=0.01,
use_jax=True,
)
assert solver.use_jax is True


def test_use_jax_flag_threads_through_for_limits_and_scale():
solver = al.PointSolver.for_limits_and_scale(
y_min=-1.0,
y_max=1.0,
x_min=-1.0,
x_max=1.0,
scale=0.1,
pixel_scale_precision=0.01,
use_jax=True,
)
assert solver.use_jax is True


def test_tree_flatten_roundtrips_use_jax():
solver = al.PointSolver.for_grid(
grid=al.Grid2D.uniform(shape_native=(10, 10), pixel_scales=1.0),
pixel_scale_precision=0.01,
use_jax=True,
)
_children, aux = solver.tree_flatten()
rebuilt = al.PointSolver.tree_unflatten(aux, _children)
assert rebuilt.use_jax is True
assert rebuilt.y_min == solver.y_min
assert rebuilt.scale == solver.scale


def test_solve_default_remove_infinities_numpy(grid):
"""On the NumPy path the solve() default still removes infinities (back-compat)."""
solver = al.PointSolver.for_grid(
grid=grid, pixel_scale_precision=0.5, magnification_threshold=1e-8
)
tracer = al.Tracer(
galaxies=[
al.Galaxy(
redshift=0.5,
mass=al.mp.Isothermal(centre=(0.0, 0.0), einstein_radius=1.0),
)
]
)
result = solver.solve(tracer, source_plane_coordinate=(0.0, 0.0))
# Removing infinities means no row should contain inf on the default numpy path.
assert not np.isinf(np.asarray(result.array)).any()
Loading