diff --git a/scripts/jax_assertions/__init__.py b/scripts/jax_assertions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/scripts/jax_assertions/enable_pytrees.py b/scripts/jax_assertions/enable_pytrees.py new file mode 100644 index 0000000..cf54337 --- /dev/null +++ b/scripts/jax_assertions/enable_pytrees.py @@ -0,0 +1,224 @@ +""" +Jax Assertions: enable_pytrees() / register_model() Public API +============================================================== + +Verifies the public ``autofit.jax.enable_pytrees`` / ``register_model`` path +that PyAutoGalaxy and PyAutoLens use to make their ``Model`` / +``ModelInstance`` types JAX-traceable. Covers: + +- Model roundtrip via the public API +- Compilation under ``jax.jit`` +- Constant attributes (positional float, ``**kwargs`` concrete object) stay + in ``aux_data`` and are NOT traced — protects ``galaxy.redshift`` sorting + and ``isinstance(x, Pixelization)`` checks under JIT +- ``TuplePrior`` paired-prior gradients flow through ``jax.value_and_grad`` +- ``enable_pytrees()`` is idempotent +- ``Collection`` round-trip preserves nested model identity + +Previously: ``test_autofit/jax/test_enable_pytrees.py``. +""" + +import jax +import jax.numpy as jnp +import numpy as np +import numpy.testing as npt + +import autofit as af +from autofit.jax import enable_pytrees, register_model + + +def make_model(): + return af.Model( + af.ex.Gaussian, + centre=af.GaussianPrior(mean=1.0, sigma=1.0), + normalization=af.GaussianPrior(mean=2.0, sigma=1.0), + sigma=af.GaussianPrior(mean=3.0, sigma=1.0), + ) + + +""" +__enable_pytrees() returns True When jax Available__ +""" +assert enable_pytrees() is True + +""" +__Model Round-Trip via Public API__ +""" +model = make_model() +register_model(model) +instance = model.instance_from_prior_medians() + +leaves, treedef = jax.tree_util.tree_flatten(instance) +rebuilt = jax.tree_util.tree_unflatten(treedef, leaves) + +assert isinstance(rebuilt, af.ex.Gaussian) +assert float(rebuilt.centre) == float(instance.centre) +assert float(rebuilt.normalization) == float(instance.normalization) +assert float(rebuilt.sigma) == float(instance.sigma) + +""" +__Model Computable Under jax.jit__ +""" +model = make_model() +register_model(model) +instance = model.instance_from_prior_medians() + + +@jax.jit +def total(inst): + return inst.centre + inst.normalization + inst.sigma + + +npt.assert_allclose( + float(total(instance)), + instance.centre + instance.normalization + instance.sigma, +) + +""" +__Positional Constant Stays Static (galaxy.redshift pattern)__ + +Constants from the model definition must NOT be traced under JIT. +``Galaxy.redshift`` is a common case: it sits in the constructor signature +but gets a fixed value via ``af.Constant``. If it became a JAX tracer, code +paths like ``sorted(galaxies, key=lambda g: g.redshift)`` would blow up with +``TracerBoolConversionError``. +""" + + +class Holder: + def __init__(self, redshift, scale): + self.redshift = redshift + self.scale = scale + + +model = af.Model( + Holder, + redshift=0.5, + scale=af.GaussianPrior(mean=1.0, sigma=1.0), +) +register_model(model) +instance = model.instance_from_prior_medians() +assert isinstance(instance.redshift, float) + + +@jax.jit +def use_redshift_for_control_flow(inst): + if inst.redshift > 0: + return inst.scale * 2 + return inst.scale + + +result = use_redshift_for_control_flow(instance) +npt.assert_allclose(float(result), 2.0 * instance.scale) + +""" +__Kwarg Constants Stay Static (Pixelization-as-kwarg pattern)__ + +Concrete objects passed as ``**kwargs`` must stay in ``aux_data``, not +``children``. ``Galaxy.__init__(self, redshift, **kwargs)`` stores every +kwarg via ``setattr``. A ``Pixelization`` passed as a kwarg used to be +routed to ``children`` and become a JAX tracer, so downstream +``isinstance(x, Pixelization)`` checks returned False. +""" + + +class Marker: + pass + + +class KwargHolder: + def __init__(self, redshift, **kwargs): + self.redshift = redshift + for k, v in kwargs.items(): + setattr(self, k, v) + + +marker = Marker() +model = af.Model( + KwargHolder, + redshift=0.5, + marker=marker, + scale=af.GaussianPrior(mean=1.0, sigma=1.0), +) +register_model(model) +instance = model.instance_from_prior_medians() +assert instance.marker is marker + + +@jax.jit +def use_marker_isinstance(inst): + if isinstance(inst.marker, Marker): + return inst.scale * 2 + return inst.scale + + +result = use_marker_isinstance(instance) +npt.assert_allclose(float(result), 2.0 * instance.scale) + +""" +__TuplePrior Paired Priors Flow Gradients Through jax.value_and_grad__ + +``TuplePrior``-backed attributes (``centre=(x, y)``, ``ell_comps=(e1, e2)``) +must be routed into JAX children so gradients flow. Prior to the fix, +``TuplePrior`` failed the ``(Prior, AbstractPriorModel)`` isinstance check +in ``register_model``, so the resolved tuple was frozen in ``aux_data`` and +``jax.value_and_grad`` returned gradients only for the non-tuple attributes. +""" + + +class Twin: + def __init__(self, centre, amplitude): + self.centre = centre + self.amplitude = amplitude + + +model = af.Model( + Twin, + centre=af.TuplePrior( + centre_0=af.GaussianPrior(mean=0.5, sigma=1.0), + centre_1=af.GaussianPrior(mean=-0.5, sigma=1.0), + ), + amplitude=af.GaussianPrior(mean=1.0, sigma=1.0), +) +register_model(model) +instance = model.instance_from_prior_medians() +params_tree = jax.tree_util.tree_map(jnp.asarray, instance) + +leaves = jax.tree_util.tree_leaves(params_tree) +assert len(leaves) == 3 # centre[0], centre[1], amplitude + + +def loss(inst): + cx, cy = inst.centre + return cx * cx + cy * cy + inst.amplitude + + +_, grad = jax.value_and_grad(loss)(params_tree) +flat_grad = jnp.concatenate( + [jnp.asarray(l).ravel() for l in jax.tree_util.tree_leaves(grad)] +) +assert jnp.all(jnp.isfinite(flat_grad)) +assert flat_grad.size == 3 + +""" +__enable_pytrees() Is Idempotent__ +""" +assert enable_pytrees() is True +assert enable_pytrees() is True + +""" +__Collection Round-Trip Preserves Nested Model Identity__ +""" +model = make_model() +register_model(model) +collection = af.Collection(g1=model, g2=model) +register_model(collection) + +instance = collection.instance_from_prior_medians() +leaves, treedef = jax.tree_util.tree_flatten(instance) +rebuilt = jax.tree_util.tree_unflatten(treedef, leaves) + +npt.assert_allclose(float(rebuilt.g1.centre), float(instance.g1.centre)) +npt.assert_allclose(float(rebuilt.g2.sigma), float(instance.g2.sigma)) + +print("enable_pytrees: all assertions passed") diff --git a/scripts/graphical/jax_assertions.py b/scripts/jax_assertions/fitness_dispatch.py similarity index 98% rename from scripts/graphical/jax_assertions.py rename to scripts/jax_assertions/fitness_dispatch.py index d270670..a0a1a48 100644 --- a/scripts/graphical/jax_assertions.py +++ b/scripts/jax_assertions/fitness_dispatch.py @@ -141,4 +141,4 @@ def assert_array_optimisation_returns_jnp_instance(): assert_pickle_strips_jax_cached_attrs() assert_fit_for_visualization_dispatches_through_jit_when_flag_set() assert_array_optimisation_returns_jnp_instance() - print("jax_assertions: all assertions passed") + print("fitness_dispatch: all assertions passed") diff --git a/scripts/jax_assertions/nested.py b/scripts/jax_assertions/nested.py new file mode 100644 index 0000000..9ea67e2 --- /dev/null +++ b/scripts/jax_assertions/nested.py @@ -0,0 +1,275 @@ +""" +Jax Assertions: Nested Tree Utils vs jax.tree_util +================================================== + +Verifies that ``autofit.graphical.utils.nested_*`` (autofit's recursive tree +walking utilities) agree with ``jax.tree_util`` as a reference implementation +across: + +- ``nested_get`` / ``nested_set`` indexing into dict/tuple/list trees +- ``nested_zip`` and ``nested_filter`` ordering matches ``tree_flatten`` +- ``nested_items`` paths match ``tree_flatten_with_path`` (after key + normalization) +- ``nested_map`` / ``nested_iter`` propagating across heterogeneous trees +- ``nested_update`` semantics with NamedTuples preserved + +The ``jax_*`` helpers below convert between jax's typed key path +(``SequenceKey`` / ``DictKey`` / ``GetAttrKey``) and autofit's plain tuple +key path so the two utility families can be compared directly. + +Previously: ``test_autofit/graphical/functionality/test_nested.py``. +""" + +import collections + +import jax.tree_util as tree_util + +from autofit.graphical import utils + +NTuple = collections.namedtuple("NTuple", "first, last") + + +def jax_nested_zip(tree, *rest): + leaves, treedef = tree_util.tree_flatten(tree) + return zip(leaves, *(treedef.flatten_up_to(r) for r in rest)) + + +def jax_key_to_val(key): + if isinstance(key, tree_util.SequenceKey): + return key.idx + elif isinstance(key, (tree_util.DictKey, tree_util.FlattenedIndexKey)): + return key.key + elif isinstance(key, tree_util.GetAttrKey): + return key.name + return key + + +def jax_path_to_key(path): + return tuple(map(jax_key_to_val, path)) + + +""" +__nested_get__ +""" +obj = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": [4, 5]})} + +assert utils.nested_get(obj, ("b",)) == 2 +assert utils.nested_get(obj, ("c", "a")) == 1 +assert utils.nested_get(obj, ("d", 0)) == 3 +assert utils.nested_get(obj, ("d", 1, "e", 1)) == 5 + +""" +__nested_set__ +""" +obj = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": [4, 5]})} + +utils.nested_set(obj, ("b",), 3) +assert utils.nested_get(obj, ("b",)) + +utils.nested_set(obj, ("c", "a"), 2) +assert utils.nested_get(obj, ("c", "a")) == 2 + +utils.nested_set(obj, ("d", 1, "e", 1), 6) +assert utils.nested_get(obj, ("d", 1, "e", 1)) == 6 + +# Setting into an immutable tuple member must raise. +try: + utils.nested_set(obj, ("d", 0), 4) +except TypeError: + pass +else: + raise AssertionError("nested_set into tuple should have raised TypeError") + +""" +__nested_zip and nested_filter Ordering Matches tree_flatten__ + +Verify autofit's nested_zip walks the tree in the same order as jax's +tree_flatten, and that nested_filter agrees with raw equality, both for +plain dict/tuple/list trees AND trees containing NamedTuples. +""" +obj1 = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": [4, 5]})} +obj2 = {"a": 1, "b": 2, "d": (3, {"e": [4, 5]}), "c": {"b": 2, "a": 1}} + +assert all(v1 == v2 for (v1, v2) in utils.nested_zip(obj1, obj2)) +assert all(utils.nested_filter(lambda x, y: x == y, obj1, obj2)) +assert list(utils.nested_zip(obj1)) == list(utils.nested_zip(obj2)) +assert list(utils.nested_zip(obj1, obj2)) == list(jax_nested_zip(obj1, obj2)) + +obj1 = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": NTuple(4, 5)})} +obj2 = {"a": 1, "b": 2, "d": (3, {"e": NTuple(4, 5)}), "c": {"b": 2, "a": 1}} + +assert all(v1 == v2 for (v1, v2) in utils.nested_zip(obj1, obj2)) +assert all(utils.nested_filter(lambda x, y: x == y, obj1, obj2)) +assert list(utils.nested_zip(obj1)) == list(utils.nested_zip(obj2)) +assert list(utils.nested_zip(obj1, obj2)) == list(jax_nested_zip(obj1, obj2)) + +""" +__nested_items Paths Match tree_flatten_with_path__ +""" +obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} + +for (k1, v1), (p2, v2) in zip( + utils.nested_items(obj1), tree_util.tree_flatten_with_path(obj1)[0] +): + assert k1 == jax_path_to_key(p2) + assert v1 == v2 + +""" +__nested_filter By Predicate__ +""" +obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} +assert list(utils.nested_filter(lambda x: x % 2 == 0, obj1)) == [(2,), (4,), (2,)] + +obj1 = {"b": 2, "a": 1, "c": (3, {"e": [4, 5]}), "d": {"b": 2, "a": 1}} +assert list(utils.nested_filter(lambda x: x % 2 == 0, obj1)) == [(2,), (4,), (2,)] + +""" +__nested_map Across Heterogeneous Trees__ +""" +obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} +obj2 = {"a": 2, "b": 4, "c": (6, {"e": [8, 10]}), "d": {"a": 2, "b": 4}} +obj12 = utils.nested_map(lambda x: x * 2, obj1) +assert obj12 == obj2 + +obj3 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": (4, 5)})} +obj4 = {"a": 2, "b": 4, "c": (6, {"e": (8, 10)}), "d": {"a": 2, "b": 4}} +obj32 = utils.nested_map(lambda x: x * 2, obj3) +assert obj32 == obj4 + +obj5 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": NTuple(4, 5)})} +obj6 = {"a": 2, "b": 4, "c": (6, {"e": NTuple(8, 10)}), "d": {"a": 2, "b": 4}} +obj52 = utils.nested_map(lambda x: x * 2, obj5) +assert obj52 == obj6 == obj4 + +assert obj32 != obj2 +assert obj52 != obj2 + +assert all( + utils.nested_iter( + utils.nested_map(lambda a, b, c: a == b == c, obj1, obj3, obj5) + ) +) +assert all( + utils.nested_iter( + utils.nested_map(lambda a, b, c: a == b == c, obj2, obj4, obj6) + ) +) +assert all(map(lambda x: x[0] == x[1] == x[2], utils.nested_zip(obj1, obj3, obj5))) +assert all( + map(lambda x: x[0] == x[1] == x[2], utils.nested_zip(obj2, obj32, obj52)) +) + +""" +__nested_update Preserves NamedTuple Type__ +""" +assert utils.nested_update([1, (2, 3), [3, 2, {1, 2}]], {2: "a"}) == [ + 1, + ("a", 3), + [3, "a", {1, "a"}], +] +assert utils.nested_update([1, NTuple(2, 3), [3, 2, {1, 2}]], {2: "a"}) == [ + 1, + ("a", 3), + [3, "a", {1, "a"}], +] +assert isinstance( + utils.nested_update([1, NTuple(2, 3), [3, 2, {1, 2}]], {2: "a"})[1], NTuple +) +assert utils.nested_update([{2: 2}], {2: "a"}) == [{2: "a"}] + +""" +__nested_items Cross-Tree Lookup__ + +(Original test file had two functions named `test_nested_items` — pytest +ran both due to last-defined wins; we run the second variant since it's +the more comprehensive one.) +""" +obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} +obj2 = {"a": 2, "b": 4, "c": (6, {"e": [8, 10]}), "d": {"a": 2, "b": 4}} +obj3 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": (4, 5)})} +obj4 = {"a": 2, "b": 4, "c": (6, {"e": (8, 10)}), "d": {"a": 2, "b": 4}} +obj5 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": NTuple(4, 5)})} +obj6 = {"a": 2, "b": 4, "c": (6, {"e": NTuple(8, 10)}), "d": {"a": 2, "b": 4}} + +for path, val in utils.nested_items(obj1): + assert ( + utils.nested_getitem(obj2, path) + == utils.nested_getitem(obj4, path) + == val * 2 + ) + +for path, val in utils.nested_items(obj3): + assert ( + utils.nested_getitem(obj4, path) + == utils.nested_getitem(obj6, path) + == val * 2 + ) + +for path, val in utils.nested_items(obj5): + assert ( + utils.nested_getitem(obj6, path) + == utils.nested_getitem(obj2, path) + == val * 2 + ) + +assert list(utils.nested_items([NTuple(1, 2), {2: 5, 1: 3}])) == [ + ((0, 0), 1), + ((0, 1), 2), + ((1, 1), 3), + ((1, 2), 5), +] + +assert list(utils.nested_items([1, (2, 3), [3, {"a": 1, "b": 2}]])) == list( + utils.nested_items([1, (2, 3), [3, {"b": 2, "a": 1}]]) +) +assert list( + utils.nested_items( + [ + 1, + (2, 3), + [ + 3, + { + "b": 2, + "a": 1, + }, + ], + ] + ) +) == list(utils.nested_items([1, (2, 3), [3, {"b": 2, "a": 1}]])) + +obj1 = [1, (2, 3), [3, {"b": 2, "a": 1}]] +obj2 = [1, (2, 3), [3, {"a": 1, "b": 2}]] +obj3 = [1, NTuple(2, 3), [3, {"a": 1, "b": 2}]] + +if hasattr(tree_util, "tree_flatten_with_path"): + jax_flat = tree_util.tree_flatten_with_path(obj1)[0] + af_flat = utils.nested_items(obj2) + + for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat): + jkey = jax_path_to_key(jpath) + assert jkey == akey + assert jval == aval + assert ( + utils.nested_get(obj2, jkey) + == utils.nested_get(obj1, jkey) + == utils.nested_get(obj2, akey) + == utils.nested_get(obj1, akey) + ) + + jax_flat = tree_util.tree_flatten_with_path(obj2)[0] + af_flat = utils.nested_items(obj3) + for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat): + jkey = jax_path_to_key(jpath) + assert jkey == akey + assert jval == aval + assert ( + utils.nested_get(obj2, jkey) + == utils.nested_get(obj1, jkey) + == utils.nested_get(obj2, akey) + == utils.nested_get(obj1, akey) + == utils.nested_get(obj3, jkey) + == utils.nested_get(obj3, akey) + ) + +print("nested: all assertions passed") diff --git a/scripts/jax_assertions/pytrees.py b/scripts/jax_assertions/pytrees.py new file mode 100644 index 0000000..bd7049d --- /dev/null +++ b/scripts/jax_assertions/pytrees.py @@ -0,0 +1,114 @@ +""" +Jax Assertions: Manual Pytree Roundtrip +======================================= + +Verifies that ``jax.tree_util.register_pytree_node_class`` correctly registers +PyAutoFit's ``Prior``, ``Model``, ``Collection``, and ``ModelInstance`` types +so that flatten/unflatten roundtrips preserve identity (``id``), bounds, and +nested structure. + +The test pattern manually pulls the ``flatten_func`` / ``unflatten_func`` out +of ``jax._src.tree_util._registry`` and round-trips an instance — testing +the legacy direct-registration path used before +``autofit.jax.enable_pytrees()`` existed (see ``enable_pytrees.py`` for the +public API path). + +Previously: ``test_autofit/jax/test_pytrees.py``. +""" + +import jax +import jax.numpy as jnp +import numpy as np + +import autofit as af +from autofit import UniformPrior +from jax.tree_util import register_pytree_node_class + +# Manual pytree registration for the legacy direct path. +UniformPrior = register_pytree_node_class(UniformPrior) +GaussianPrior = register_pytree_node_class(af.GaussianPrior) +TruncatedGaussianPrior = register_pytree_node_class(af.TruncatedGaussianPrior) +Collection = register_pytree_node_class(af.Collection) +Model = register_pytree_node_class(af.Model) +ModelInstance = register_pytree_node_class(af.ModelInstance) + + +def recreate(o): + """jax-pytree-roundtrip helper inlined from the original test fixture.""" + flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] + children, aux_data = flatten_func(o) + return unflatten_func(aux_data, children) + + +# Original fixture monkeypatched af.example.model.np = jnp via autouse. +# Replicate globally — the active tests below don't actually call gaussian.f, +# but keep the patch to match original semantics. +af.example.model.np = jnp + + +model = Model( + af.ex.Gaussian, + centre=af.GaussianPrior(mean=1.0, sigma=1.0), + normalization=af.GaussianPrior(mean=1.0, sigma=1.0), + sigma=af.GaussianPrior(mean=1.0, sigma=1.0), +) + +""" +__TruncatedGaussianPrior Roundtrip__ +""" +prior = TruncatedGaussianPrior(mean=1.0, sigma=1.0) +new = recreate(prior) + +assert new.mean == prior.mean +assert new.sigma == prior.sigma +assert new.id == prior.id +assert new.lower_limit == prior.lower_limit +assert new.upper_limit == prior.upper_limit + +""" +__Model Roundtrip__ +""" +new = recreate(model) +assert new.cls == af.ex.Gaussian + +centre = new.centre +assert centre.mean == model.centre.mean +assert centre.sigma == model.centre.sigma +assert centre.id == model.centre.id + +""" +__UniformPrior Roundtrip__ +""" +prior = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) +new = recreate(prior) + +assert new.lower_limit == prior.lower_limit +assert new.upper_limit == prior.upper_limit +assert new.id == prior.id + +""" +__ModelInstance Roundtrip__ +""" +collection = Collection(gaussian=model) +instance = collection.instance_from_prior_medians() +new = recreate(instance) + +assert isinstance(new, ModelInstance) +assert isinstance(new.gaussian, af.ex.Gaussian) + +""" +__Collection Roundtrip__ +""" +collection = Collection(gaussian=model) +new = recreate(collection) + +assert isinstance(new, Collection) +assert isinstance(new.gaussian, Model) +assert new.gaussian.cls == af.ex.Gaussian + +centre = new.gaussian.centre +assert centre.mean == model.centre.mean +assert centre.sigma == model.centre.sigma +assert centre.id == model.centre.id + +print("pytrees: all assertions passed")