Skip to content

Commit 602afbc

Browse files
Jammy2211claude
authored andcommitted
feat: add autofit.jax opt-in pytree registration API
Introduces autofit.jax.enable_pytrees() and autofit.jax.register_model(model) so callers can register PyAutoFit prior/model classes (and the user classes referenced by model.cls) with jax.tree_util without paying an eager JAX import at every "import autofit". register_model walks a Model / Collection and builds per-class flatten / unflatten functions that partition constructor arguments: constants (from model.direct_instance_tuples) go into JAX aux_data so they stay as concrete Python values inside jit traces, while prior-derived arguments become children (and therefore tracers). This keeps constants like Galaxy.redshift usable in Python control flow (e.g. sorted(...) by redshift) under JIT. Both functions are no-ops if JAX is not installed. Re-registration of a class that is already registered is tolerated silently. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 73cb28c commit 602afbc

3 files changed

Lines changed: 267 additions & 0 deletions

File tree

autofit/jax/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Opt-in JAX integration helpers for PyAutoFit.
2+
3+
The library defines pytree ``tree_flatten`` / ``tree_unflatten`` methods on
4+
its model and prior classes but, by design, does NOT register them with JAX
5+
at import time. Eager registration would force ``jax.tree_util`` to load on
6+
every ``import autofit`` and reintroduce the heavy JAX import that the
7+
2025-11 cleanup removed.
8+
9+
Call :func:`enable_pytrees` once before crossing a ``jax.jit`` or
10+
``jax.vmap`` boundary with PyAutoFit objects. The function is a no-op if JAX
11+
is not installed.
12+
"""
13+
14+
from .pytrees import enable_pytrees, register_model
15+
16+
__all__ = ["enable_pytrees", "register_model"]

autofit/jax/pytrees.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""Lazy JAX pytree registration for PyAutoFit's model and prior classes.
2+
3+
The classes themselves already carry the necessary ``tree_flatten`` /
4+
``tree_unflatten`` methods. This module simply registers them with
5+
``jax.tree_util`` on demand, via the lazy autoconf wrapper, so callers can
6+
pass ``Model`` / ``Collection`` / ``ModelInstance`` / prior instances
7+
through ``jax.jit`` and ``jax.vmap`` directly.
8+
"""
9+
10+
from autoconf.jax_wrapper import register_pytree_node, register_pytree_node_class
11+
12+
_ENABLED = False
13+
_REGISTERED_INSTANCE_CLASSES: set = set()
14+
15+
16+
def enable_pytrees() -> bool:
17+
"""Register PyAutoFit model and prior classes as JAX pytree nodes.
18+
19+
Returns ``True`` if registration was performed, ``False`` if JAX is not
20+
installed (in which case the call is a silent no-op). Safe to call more
21+
than once: subsequent calls return ``True`` without re-registering.
22+
"""
23+
global _ENABLED
24+
if _ENABLED:
25+
return True
26+
27+
try:
28+
import jax # noqa: F401
29+
except ImportError:
30+
return False
31+
32+
import autofit as af
33+
from autofit.mapper.prior_model.prior_model import Model
34+
from autofit.mapper.prior_model.collection import Collection
35+
36+
for cls in (
37+
af.GaussianPrior,
38+
af.UniformPrior,
39+
af.LogGaussianPrior,
40+
af.LogUniformPrior,
41+
af.TruncatedGaussianPrior,
42+
Model,
43+
Collection,
44+
af.ModelInstance,
45+
):
46+
try:
47+
register_pytree_node_class(cls)
48+
except ValueError:
49+
# Already registered (e.g. by another caller or a test fixture).
50+
# Re-registration is a JAX error but harmless for our purposes.
51+
pass
52+
53+
_ENABLED = True
54+
return True
55+
56+
57+
def register_model(model) -> bool:
58+
"""Register every concrete ``model.cls`` in ``model`` as a JAX pytree node.
59+
60+
PyAutoFit's ``Model.instance_flatten`` / ``instance_unflatten`` produce
61+
flatten/unflatten functions for instances of the user-defined class
62+
referenced by ``Model.cls`` (e.g. a ``Galaxy`` or ``Gaussian``). For JAX
63+
to recurse into those instances rather than treat them as opaque leaves,
64+
the user class itself must be registered with ``jax.tree_util``.
65+
66+
This walks ``model`` (a ``Model`` or ``Collection``) and registers each
67+
``cls`` it finds. Re-registering the same class is a silent no-op. Returns
68+
``True`` if registration ran (or was already complete), ``False`` if JAX
69+
is missing.
70+
"""
71+
if not enable_pytrees():
72+
return False
73+
74+
from autofit.mapper.prior_model.prior_model import Model
75+
from autofit.mapper.prior_model.collection import Collection
76+
77+
def _walk(node):
78+
if isinstance(node, Model):
79+
cls = node.cls
80+
if cls not in _REGISTERED_INSTANCE_CLASSES:
81+
flatten, unflatten = _build_instance_pytree_funcs(node)
82+
try:
83+
register_pytree_node(cls, flatten, unflatten)
84+
except ValueError:
85+
# Already registered elsewhere — keep going.
86+
pass
87+
_REGISTERED_INSTANCE_CLASSES.add(cls)
88+
for _, sub in node.direct_prior_model_tuples:
89+
_walk(sub)
90+
elif isinstance(node, Collection):
91+
for _, sub in node.direct_prior_model_tuples:
92+
_walk(sub)
93+
94+
_walk(model)
95+
return True
96+
97+
98+
def _build_instance_pytree_funcs(model):
99+
"""Build flatten/unflatten functions for instances of ``model.cls``.
100+
101+
Constants from the original model definition (e.g. ``Galaxy(redshift=0.5)``)
102+
are placed in the JAX ``aux_data`` so they remain concrete Python values
103+
inside a ``jax.jit`` trace. Only prior-derived constructor arguments are
104+
placed in ``children`` (and therefore become JAX tracers).
105+
106+
This is critical for code that uses constants for control flow — e.g.
107+
``sorted(galaxies, key=lambda g: g.redshift)`` in ``Tracer`` — which would
108+
otherwise raise ``TracerBoolConversionError`` under JIT.
109+
"""
110+
constructor_args = list(model.constructor_argument_names)
111+
constant_arg_names = [
112+
name for name in constructor_args if name in dict(model.direct_instance_tuples)
113+
]
114+
constant_values = {
115+
name: dict(model.direct_instance_tuples)[name] for name in constant_arg_names
116+
}
117+
dynamic_arg_names = [
118+
name for name in constructor_args if name not in constant_arg_names
119+
]
120+
121+
def flatten(instance):
122+
attribute_names = [
123+
name
124+
for name in model.direct_argument_names
125+
if hasattr(instance, name) and name not in constructor_args
126+
]
127+
children = (
128+
[getattr(instance, name) for name in dynamic_arg_names],
129+
[getattr(instance, name) for name in attribute_names],
130+
)
131+
aux = (
132+
tuple(dynamic_arg_names),
133+
tuple(constant_arg_names),
134+
tuple(constant_values[n] for n in constant_arg_names),
135+
tuple(attribute_names),
136+
)
137+
return children, aux
138+
139+
def unflatten(aux, children):
140+
dyn_names, const_names, const_vals, attr_names = aux
141+
dyn_vals, attr_vals = children
142+
kwargs = dict(zip(dyn_names, dyn_vals))
143+
kwargs.update(zip(const_names, const_vals))
144+
ordered = [kwargs[name] for name in constructor_args]
145+
instance = model.cls(*ordered)
146+
for name, value in zip(attr_names, attr_vals):
147+
setattr(instance, name, value)
148+
return instance
149+
150+
return flatten, unflatten
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Tests for ``autofit.jax.enable_pytrees`` / ``register_model``."""
2+
import pytest
3+
4+
jax = pytest.importorskip("jax")
5+
jnp = pytest.importorskip("jax.numpy")
6+
7+
import autofit as af
8+
from autofit.jax import enable_pytrees, register_model
9+
10+
11+
@pytest.fixture(name="model")
12+
def make_model():
13+
return af.Model(
14+
af.ex.Gaussian,
15+
centre=af.GaussianPrior(mean=1.0, sigma=1.0),
16+
normalization=af.GaussianPrior(mean=2.0, sigma=1.0),
17+
sigma=af.GaussianPrior(mean=3.0, sigma=1.0),
18+
)
19+
20+
21+
def test_enable_pytrees_returns_true_when_jax_available():
22+
assert enable_pytrees() is True
23+
24+
25+
def test_register_model_round_trip(model):
26+
register_model(model)
27+
instance = model.instance_from_prior_medians()
28+
29+
leaves, treedef = jax.tree_util.tree_flatten(instance)
30+
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
31+
32+
assert isinstance(rebuilt, af.ex.Gaussian)
33+
assert float(rebuilt.centre) == float(instance.centre)
34+
assert float(rebuilt.normalization) == float(instance.normalization)
35+
assert float(rebuilt.sigma) == float(instance.sigma)
36+
37+
38+
def test_register_model_works_under_jit(model):
39+
register_model(model)
40+
instance = model.instance_from_prior_medians()
41+
42+
@jax.jit
43+
def total(inst):
44+
return inst.centre + inst.normalization + inst.sigma
45+
46+
assert float(total(instance)) == pytest.approx(
47+
instance.centre + instance.normalization + instance.sigma
48+
)
49+
50+
51+
def test_register_model_keeps_constants_static():
52+
"""Constants from the model definition must NOT be traced under JIT.
53+
54+
Galaxy.redshift is a common case: it sits in the constructor signature
55+
but gets a fixed value via ``af.Constant``. If it became a JAX tracer,
56+
code paths like ``sorted(galaxies, key=lambda g: g.redshift)`` would
57+
blow up with ``TracerBoolConversionError``. The ``aux_data`` partition
58+
in ``register_model`` keeps it as a Python float.
59+
"""
60+
class Holder:
61+
def __init__(self, redshift, scale):
62+
self.redshift = redshift
63+
self.scale = scale
64+
65+
model = af.Model(
66+
Holder,
67+
redshift=0.5,
68+
scale=af.GaussianPrior(mean=1.0, sigma=1.0),
69+
)
70+
register_model(model)
71+
instance = model.instance_from_prior_medians()
72+
assert isinstance(instance.redshift, float)
73+
74+
@jax.jit
75+
def use_redshift_for_control_flow(inst):
76+
# `if` on a traced value would raise TracerBoolConversionError;
77+
# this only works if redshift is kept static via aux_data.
78+
if inst.redshift > 0:
79+
return inst.scale * 2
80+
return inst.scale
81+
82+
result = use_redshift_for_control_flow(instance)
83+
assert float(result) == pytest.approx(2.0 * instance.scale)
84+
85+
86+
def test_enable_pytrees_idempotent():
87+
assert enable_pytrees() is True
88+
assert enable_pytrees() is True
89+
90+
91+
def test_collection_round_trip(model):
92+
register_model(model)
93+
collection = af.Collection(g1=model, g2=model)
94+
register_model(collection)
95+
96+
instance = collection.instance_from_prior_medians()
97+
leaves, treedef = jax.tree_util.tree_flatten(instance)
98+
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
99+
100+
assert float(rebuilt.g1.centre) == float(instance.g1.centre)
101+
assert float(rebuilt.g2.sigma) == float(instance.g2.sigma)

0 commit comments

Comments
 (0)