Long-form reference for the grid decorators, the xp (NumPy/JAX) backend
pattern, and how autoarray types cross the jax.jit boundary. The per-repo
AGENTS.md files keep only a short summary and link here. This is the single
canonical source for the detail — PyAutoGalaxy and PyAutoLens point at it
rather than re-explaining.
Everything below is grounded in the installed source under
autoarray/, autogalaxy/, and autolens/. Where a class or function is
named, it exists in the current tree.
autoarray/structures/decorators/ contains the output-wrapping decorators
used on all grid-consuming functions. They ensure the type of the output
structure matches the type of the input grid.
Import them as aa.decorators.*. (aa.grid_dec still resolves as a
deprecated alias — autoarray/__init__.py defines
from .structures import decorators as grid_dec # deprecated alias — but
every shipped profile uses aa.decorators.*, so write that form.)

| Decorator | Grid2D input → | Grid2DIrregular input → |
|---|---|---|
| @aa.decorators.to_array | Array2D | ArrayIrregular |
| @aa.decorators.to_grid | Grid2D | Grid2DIrregular |
| @aa.decorators.to_vector_yx | VectorYX2D | VectorYX2DIrregular |

All three share AbstractMaker (decorators/abstract.py). The decorator:
- Wraps the function in a wrapper(obj, grid, xp=np, *args, **kwargs) signature.
- Instantiates the relevant *Maker class with the function, object, grid, and xp.
- AbstractMaker.result checks the grid type and calls the appropriate via_grid_2d / via_grid_2d_irr method to wrap the raw result.
The function body receives the grid as-is and must return a raw array (not an autoarray wrapper). The decorator does the wrapping:

```python
@aa.decorators.to_array
def convergence_2d_from(self, grid, xp=np, **kwargs):
    # grid is Grid2D or Grid2DIrregular; access raw values via grid.array[:, 0]
    y = grid.array[:, 0]
    x = grid.array[:, 1]
    return xp.sqrt(y**2 + x**2)  # return raw array; decorator wraps it
```

AbstractMaker stores use_jax = xp is not np and exposes _xp (either jnp
or np), but the wrapping step always runs regardless of xp.
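
The wrapping mechanism described above can be illustrated with a minimal sketch. It does not use autoarray: every class and function name ending in Sketch is a hypothetical stand-in, and the real AbstractMaker dispatches through via_grid_2d / via_grid_2d_irr rather than the dictionary lookup assumed here.

```python
import functools

import numpy as np


class GridUniformSketch:
    """Hypothetical stand-in for a Grid2D-like input."""

    def __init__(self, values):
        self.array = np.asarray(values, dtype=float)


class GridIrregularSketch:
    """Hypothetical stand-in for a Grid2DIrregular-like input."""

    def __init__(self, values):
        self.array = np.asarray(values, dtype=float)


class ArrayUniformSketch:
    """Hypothetical stand-in for an Array2D-like output."""

    def __init__(self, values):
        self.array = values


class ArrayIrregularSketch:
    """Hypothetical stand-in for an ArrayIrregular-like output."""

    def __init__(self, values):
        self.array = values


# Assumed mapping from input grid type to output wrapper type.
_OUTPUT_TYPE = {
    GridUniformSketch: ArrayUniformSketch,
    GridIrregularSketch: ArrayIrregularSketch,
}


def to_array_sketch(func):
    """Wrap the raw result in the output type that matches the input grid type."""

    @functools.wraps(func)
    def wrapper(obj, grid, xp=np, *args, **kwargs):
        raw = func(obj, grid, xp, *args, **kwargs)  # function body returns a raw array
        return _OUTPUT_TYPE[type(grid)](raw)  # wrapping runs regardless of xp

    return wrapper


class ProfileSketch:
    @to_array_sketch
    def radius_from(self, grid, xp=np, **kwargs):
        y = grid.array[:, 0]
        x = grid.array[:, 1]
        return xp.sqrt(y**2 + x**2)
```

Calling ProfileSketch().radius_from(...) with a GridIrregularSketch returns an ArrayIrregularSketch, and with a GridUniformSketch an ArrayUniformSketch, which mirrors the table of input and output types.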
Access the raw underlying array with .array:

```python
# Correct: works for both numpy and jax backends
y = grid.array[:, 0]
x = grid.array[:, 1]

# Also works for simple slicing (returns raw array via __getitem__)
y = grid[:, 0]
x = grid[:, 1]
```

Prefer grid.array[:, 0]: after @transform the grid is still an autoarray
object, and .array is the safe way to extract the underlying data for both
numpy and jax backends.
@aa.decorators.transform shifts and rotates the input grid to the profile's
reference frame before passing it to the function. It calls
obj.transformed_to_reference_frame_grid_from(grid, xp) (itself decorated with
@to_grid) and passes the result as the grid argument. After transformation
the grid is still an autoarray object; .array still works. Some call sites
pass rotate_back=True (e.g. @aa.decorators.transform(rotate_back=True)).
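
As a rough illustration of what "shift and rotate to the reference frame" means, here is a minimal sketch on raw (N, 2) arrays of (y, x) coordinates. It is not the library implementation: the real decorator delegates to obj.transformed_to_reference_frame_grid_from(grid, xp) and works on autoarray grids. The attribute names centre and angle (degrees, counter-clockwise from the x-axis) are assumptions made for this example, and rotate_back is omitted.

```python
import functools

import numpy as np


def transform_sketch(func):
    """Move (y, x) coordinates into the object's frame before calling func."""

    @functools.wraps(func)
    def wrapper(obj, grid, xp=np, *args, **kwargs):
        # Shift so the object's centre sits at the origin.
        shifted = grid - xp.asarray(obj.centre)
        y, x = shifted[:, 0], shifted[:, 1]
        # Rotate by -angle so the object's major axis lies along +x.
        theta = xp.radians(obj.angle)
        cos_t, sin_t = xp.cos(theta), xp.sin(theta)
        y_rot = y * cos_t - x * sin_t
        x_rot = x * cos_t + y * sin_t
        transformed = xp.stack([y_rot, x_rot], axis=-1)
        return func(obj, transformed, xp, *args, **kwargs)

    return wrapper


class EllipseSketch:
    """Hypothetical profile with a centre (y, x) and a position angle in degrees."""

    def __init__(self, centre, angle):
        self.centre = centre
        self.angle = angle

    @transform_sketch
    def grid_in_frame(self, grid, xp=np):
        return grid  # returns the coordinates the function body would see
```

For a profile centred at (1.0, 2.0) with angle 90 degrees, the point (2.0, 2.0) lies one unit along the profile's major axis, so the function body sees it at approximately (y, x) = (0.0, 1.0).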
Decorators apply bottom-up (innermost first). The canonical order for mass/light profile methods is:

```python
@aa.decorators.to_array   # outermost: wraps output
@aa.decorators.transform  # innermost: transforms grid input
def convergence_2d_from(self, grid, xp=np, **kwargs):
    ...
```

All data structures inherit from AbstractNDArray (abstract_ndarray.py).
Key subclasses: Array2D, ArrayIrregular, Grid2D, Grid2DIrregular,
VectorYX2D, VectorYX2DIrregular.
AbstractNDArray provides arithmetic operators (__add__, __sub__,
__rsub__, …) so operations between autoarray objects and raw scalars/arrays
return a new autoarray of the same type. The .array property returns the raw
underlying numpy.ndarray or jax.Array:

```python
arr = aa.ArrayIrregular(values=[1.0, 2.0])
arr.array   # raw numpy (or jax) array
arr._array  # same, internal attribute
```

The constructor unwraps nested autoarray objects automatically
(while isinstance(array, AbstractNDArray): array = array.array).
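
The operator and unwrapping behaviour can be sketched with a small NumPy-only wrapper. NDArraySketch and ArraySketch are hypothetical names, not autoarray classes, and the real AbstractNDArray supports more operators and the JAX backend.

```python
import numpy as np


class NDArraySketch:
    """Hypothetical wrapper around a raw array (NumPy only in this sketch)."""

    def __init__(self, array):
        # Unwrap nested wrappers so .array is always a raw array.
        while isinstance(array, NDArraySketch):
            array = array.array
        self._array = np.asarray(array)

    @property
    def array(self):
        return self._array

    @staticmethod
    def _raw(other):
        return other.array if isinstance(other, NDArraySketch) else other

    def __add__(self, other):
        return type(self)(self._array + self._raw(other))

    def __sub__(self, other):
        return type(self)(self._array - self._raw(other))

    def __rsub__(self, other):
        # Handles `scalar - wrapper`, which Python routes here.
        return type(self)(self._raw(other) - self._array)


class ArraySketch(NDArraySketch):
    """Subclass used to show that operators preserve the concrete type."""
```

With this sketch, 3.0 - ArraySketch([1.0, 2.0]) is routed to __rsub__ and returns another ArraySketch, and ArraySketch(ArraySketch([1.0])).array is a plain numpy.ndarray rather than a nested wrapper.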
The codebase is designed so that NumPy is the default everywhere and JAX is
opt-in. JAX is never imported at module level — only locally inside functions
when explicitly requested. The xp parameter is the single point of control:
- xp=np (default throughout): pure NumPy path, no JAX dependency at runtime.
- xp=jnp: JAX path; jax / jax.numpy imported locally inside the function.
When adding a new function that should support JAX:
- Default the parameter to xp=np.
- Guard any JAX imports with if xp is not np: and import jax / jax.numpy locally inside that branch.
- Add the NumPy implementation as the default path.
- Add a JAX implementation in the guarded branch (e.g. jax.jacfwd, jnp.vectorize).
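
A minimal sketch of a function that follows these steps is shown below. The function name derivative_from and its step parameter are hypothetical. The NumPy path uses a central finite difference, while the guarded branch assumes JAX is installed and uses jax.grad with jax.vmap, chosen here only for brevity.

```python
import numpy as np


def derivative_from(func, x, xp=np, step=1e-6):
    """Elementwise derivative of a scalar function `func` at each point of a 1D array `x`."""
    if xp is not np:
        # JAX is imported only inside the guarded branch, never at module level.
        import jax

        # func must be built from JAX-traceable operations on this path.
        return jax.vmap(jax.grad(func))(x)

    # Default NumPy path: central finite difference.
    return (func(x + step) - func(x - step)) / (2.0 * step)
```

On the default path, derivative_from(lambda v: v**3, np.array([1.0, 2.0])) approximates the analytic derivative 3 v**2 at each point. Importing the module that defines this function never imports JAX.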
Adding xp=np to a method body and swapping np.* for xp.* is only half
the work. Every nested call inside that body — self.X(), obj.X(), a helper
in convert.py, an inherited @property, or a sibling method — must also
receive xp=xp if it can route to numpy operations on what would otherwise be
JAX tracers. Otherwise the inner call silently defaults to xp=np and fails
when a tracer reaches an np.* op.
Concrete hazards seen in this codebase:
- @property chains that hardcode np. A property takes no kwargs, so an xp-aware caller must either inline the computation under if xp is not np: or convert the property to a method. Read every @property you call from xp-aware code; if it does np.sqrt(...), it is a hazard.
- Inherited methods. A method may accept xp but a call site forgets to pass it. Within xp-aware functions, grep for self.X( / obj.X( and verify xp=xp is threaded.
- convert.py helpers. Helpers like axis_ratio_and_angle_from, angle_from, multipole_comps_from all take xp=np; call sites must thread it. They also use Python & on JAX bool tracers, which silently calls __array__(); replace with xp.logical_and.
- @cached_property on traced arrays. Caches a tracer in self.__dict__, which is invalid across vmap batch elements (different batches share the cache). Use plain @property for any value that depends on JAX-traced inputs.
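
The un-threaded xp failure mode can be reproduced without JAX. The sketch below uses a hypothetical RecordingBackend that forwards to NumPy and records which functions were requested through it; all function names are illustrative and do not exist in the codebase.

```python
import numpy as np


class RecordingBackend:
    """Stand-in for an alternative xp module: forwards to numpy and logs each lookup."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Only reached for names not found on the instance, e.g. sqrt or exp.
        self.calls.append(name)
        return getattr(np, name)


def radius_from(y, x, xp=np):
    return xp.sqrt(y**2 + x**2)


def profile_unthreaded(y, x, xp=np):
    r = radius_from(y, x)  # bug: helper silently falls back to xp=np
    return xp.exp(-r)


def profile_threaded(y, x, xp=np):
    r = radius_from(y, x, xp=xp)  # xp is threaded into the nested call
    return xp.exp(-r)
```

Passing a RecordingBackend as xp to profile_unthreaded records only exp, because the sqrt inside the helper went to plain NumPy. With profile_threaded both sqrt and exp are recorded. On a real JAX path, the un-threaded sqrt is the call that would receive a tracer and fail.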
Autoarray types are registered as JAX pytrees (see
abstract_ndarray._register_as_pytree and register_instance_pytree), so a
wrapper can be returned from a jitted function once its class is registered.
Two patterns coexist depending on what the function returns:
Functions intended to be called directly inside jax.jit as the outermost op,
where no wrapper is needed on the JAX path, guard their autoarray wrapping:

```python
def convergence_2d_via_hessian_from(self, grid, xp=np):
    # hessian_yy and hessian_xx are computed earlier in the method (omitted here)
    convergence = 0.5 * (hessian_yy + hessian_xx)
    if xp is np:
        return aa.ArrayIrregular(values=convergence)  # numpy: wrapped
    return convergence  # jax: raw jax.Array
```

All LensCalc hessian-derived methods (convergence_2d_via_hessian_from,
shear_yx_2d_via_hessian_from, magnification_2d_via_hessian_from,
magnification_2d_from, tangential_eigen_value_from,
radial_eigen_value_from) use this pattern in
autogalaxy/operate/lens_calc.py and return raw jax.Array on the JAX path.
Intermediate helpers (e.g. deflections_yx_2d_from) do not need the guard
— their autoarray wrappers are consumed by downstream Python before any JIT
boundary.
Functions that must return a real autoarray wrapper (or a structured object built from them) rely on JAX pytree registration:
- AbstractNDArray auto-registers its subclass with jax.tree_util the first time an instance is built with xp=jnp, via autoarray.abstract_ndarray._register_as_pytree.
- Higher-level types (FitImaging, Tracer, DatasetModel) use autoarray.abstract_ndarray.register_instance_pytree(cls, no_flatten=...), which flattens __dict__ and carries no_flatten names through aux_data for per-analysis constants (dataset, settings, cosmology).
- AnalysisImaging._register_fit_imaging_pytrees wires these up when use_jax=True, so jax.jit(analysis.fit_from)(instance) returns a real FitImaging with jax.Array leaves.
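
The idea of flattening __dict__ while carrying no_flatten names through aux_data can be sketched in plain Python. This is an assumption about the general shape of such a helper, not the body of register_instance_pytree; flatten_instance, unflatten_instance and FitSketch are hypothetical names.

```python
def flatten_instance(instance, no_flatten=()):
    """Split __dict__ into dynamic children and static aux_data."""
    names = sorted(instance.__dict__)
    dynamic_names = tuple(n for n in names if n not in no_flatten)
    children = tuple(instance.__dict__[n] for n in dynamic_names)
    # Static entries travel in aux_data; JAX expects aux_data to be hashable
    # and comparable, so this sketch assumes the static values are too.
    static = tuple((n, instance.__dict__[n]) for n in names if n in no_flatten)
    return children, (dynamic_names, static)


def unflatten_instance(cls, aux_data, children):
    """Rebuild an instance from aux_data and (possibly traced) children."""
    dynamic_names, static = aux_data
    instance = cls.__new__(cls)  # bypass __init__; state is restored directly
    instance.__dict__.update(zip(dynamic_names, children))
    instance.__dict__.update(static)
    return instance


class FitSketch:
    """Hypothetical higher-level type with one dynamic and one constant attribute."""

    def __init__(self, model_image, dataset):
        self.model_image = model_image
        self.dataset = dataset
```

With JAX installed, functions of this shape could be passed to jax.tree_util.register_pytree_node(FitSketch, flatten_func, unflatten_func), where flatten_func returns (children, aux_data) and unflatten_func takes (aux_data, children). The sketch itself runs without JAX.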
Library unit tests (test_autoarray/, test_autogalaxy/, test_autolens/)
always run on the NumPy path. No xp=jnp JAX assertion belongs in a library
unit test. A JAX / xp change is validated only by the parity scripts in the
*_workspace_test repos.
jax.jit(fn)(concrete_instance) is NOT a sufficient JAX trace check. A
ModelInstance with concrete float parameters propagates as floats through
np.* ops without raising — an un-threaded xp bug stays hidden. Use
jax.vmap(fitness)(jnp.array(params)) (or Fitness._vmap on autofit's
wrapper) instead: vmap forces tracer propagation through every leaf and exposes
un-threaded xp sites.
When adding a JAX path to an Analysis class, the workspace_test parity script
must include both a jax.jit(analysis.fit_from)(instance) round-trip
and a fitness._vmap(parameters) batch evaluation. PyAutoArray has no
autoarray_workspace_test of its own; array-level JAX changes are exercised
downstream in autogalaxy_workspace_test/scripts/jax_likelihood_functions/ and
the autolens_workspace_test equivalents.