diff --git a/test_autofit/conftest.py b/test_autofit/conftest.py index 03895c6a1..2eeb884ae 100644 --- a/test_autofit/conftest.py +++ b/test_autofit/conftest.py @@ -1,4 +1,3 @@ -import importlib.util import multiprocessing import os import shutil @@ -16,16 +15,6 @@ from autofit.database.model import sa from autofit.non_linear.search import abstract_search -# Skip JAX-only tests when jax isn't installed. find_spec checks availability -# WITHOUT importing jax, so this conftest stays numpy-only per the -# "library unit tests stay numpy-only" rule. -collect_ignore_glob = [] -if importlib.util.find_spec("jax") is None: - collect_ignore_glob = [ - "jax/*.py", - "graphical/functionality/test_jacobians.py", - ] - if sys.platform == "darwin": multiprocessing.set_start_method("fork") diff --git a/test_autofit/graphical/functionality/test_jacobians.py b/test_autofit/graphical/functionality/test_jacobians.py deleted file mode 100644 index 30c0b3dda..000000000 --- a/test_autofit/graphical/functionality/test_jacobians.py +++ /dev/null @@ -1,144 +0,0 @@ -from itertools import combinations -import numpy as np -import pytest - -# Belt-and-suspenders: the top-level conftest also ignores this file when -# jax isn't installed; importorskip keeps the file self-contained. -jax = pytest.importorskip("jax") - -from autofit.mapper.variable import variables -from autofit.graphical.factor_graphs import ( - Factor, - FactorValue, -) - - -# def test_jacobian_equiv(): -# -# def linear(x, a, b, c): -# z = x.dot(a) + b -# return (z**2).sum(), z -# -# x_, a_, b_, c_, z_ = variables("x, a, b, c, z") -# x = np.arange(10.0).reshape(5, 2) -# a = np.arange(2.0).reshape(2, 1) -# b = np.ones(1) -# c = -1.0 -# -# factors = [ -# Factor( -# linear, -# x_, -# a_, -# b_, -# c_, -# factor_out=(FactorValue, z_), -# numerical_jacobian=False, -# ), -# Factor( -# linear, -# x_, -# a_, -# b_, -# c_, -# factor_out=(FactorValue, z_), -# numerical_jacobian=False, -# jacfwd=False, -# ), -# Factor( -# linear, -# x_, -# a_, -# b_, -# c_, -# factor_out=(FactorValue, z_), -# numerical_jacobian=False, -# vjp=True, -# ), -# Factor( -# linear, -# x_, -# a_, -# b_, -# c_, -# factor_out=(FactorValue, z_), -# numerical_jacobian=True, -# ), -# ] -# -# values = {x_: x, a_: a, b_: b, c_: c} -# outputs = [factor.func_jacobian(values) for factor in factors] -# -# tol = pytest.approx(0, abs=1e-4) -# pairs = combinations(outputs, 2) -# g0 = FactorValue(1.0, {z_: np.ones((5, 1))}) -# for (val1, jac1), (val2, jac2) in pairs: -# assert val1 == val2 -# -# # test with different ways of calculating gradients -# grad1, grad2 = jac1.grad(g0), jac2.grad(g0) -# assert (grad1 - grad2).norm() == tol -# grad1 = g0.to_dict() * jac1 -# assert (grad1 - grad2).norm() == tol -# grad2 = g0.to_dict() * jac2 -# assert (grad1 - grad2).norm() == tol -# -# grad1, grad2 = jac1.grad(val1), jac2.grad(val2) -# assert (grad1 - grad2).norm() == tol -# -# # test getting gradient with no args -# assert (jac1.grad() - jac2.grad()).norm() == tol -# -# -# def test_jac_model(): -# -# def linear(x, a, b): -# z = x.dot(a) + b -# return (z**2).sum(), z -# -# def likelihood(y, z): -# return ((y - z) ** 2).sum() -# -# def combined(x, y, a, b): -# like, z = linear(x, a, b) -# return like + likelihood(y, z) -# -# x_, a_, b_, y_, z_ = variables("x, a, b, y, z") -# x = np.arange(10.0).reshape(5, 2) -# a = np.arange(2.0).reshape(2, 1) -# b = np.ones(1) -# y = np.arange(0.0, 10.0, 2).reshape(5, 1) -# values = {x_: x, y_: y, a_: a, b_: b} -# linear_factor = Factor(linear, x_, a_, b_, factor_out=(FactorValue, z_), vjp=True) -# like_factor = Factor(likelihood, y_, z_, vjp=True) -# full_factor = Factor(combined, x_, y_, a_, b_, vjp=True) -# model_factor = like_factor * linear_factor -# -# x = np.arange(10.0).reshape(5, 2) -# a = np.arange(2.0).reshape(2, 1) -# b = np.ones(1) -# y = np.arange(0.0, 10.0, 2).reshape(5, 1) -# values = {x_: x, y_: y, a_: a, b_: b} -# -# # Fully working problem -# fval, jac = full_factor.func_jacobian(values) -# grad = jac.grad() -# -# model_val, model_jac = model_factor.func_jacobian(values) -# model_grad = model_jac.grad() -# -# linear_val, linear_jac = linear_factor.func_jacobian(values) -# like_val, like_jac = like_factor.func_jacobian( -# {**values, **linear_val.deterministic_values} -# ) -# combined_val = like_val + linear_val -# -# # Manually back propagate -# combined_grads = linear_jac.grad(like_jac.grad()) -# -# vals = (fval, model_val, combined_val) -# grads = (grad, model_grad, combined_grads) -# pairs = combinations(zip(vals, grads), 2) -# for (val1, grad1), (val2, grad2) in pairs: -# assert val1 == val2 -# assert (grad1 - grad2).norm() == pytest.approx(0, 1e-6) diff --git a/test_autofit/graphical/functionality/test_nested.py b/test_autofit/graphical/functionality/test_nested.py deleted file mode 100644 index 6ecfad756..000000000 --- a/test_autofit/graphical/functionality/test_nested.py +++ /dev/null @@ -1,263 +0,0 @@ -import collections - -import pytest - -from autofit.graphical import utils - -tree_util = pytest.importorskip("jax.tree_util") - -NTuple = collections.namedtuple("NTuple", "first, last") - - -def jax_nested_zip(tree, *rest): - leaves, treedef = tree_util.tree_flatten(tree) - return zip(leaves, *(treedef.flatten_up_to(r) for r in rest)) - - -def jax_key_to_val(key): - if isinstance(key, tree_util.SequenceKey): - return key.idx - elif isinstance(key, (tree_util.DictKey, tree_util.FlattenedIndexKey)): - return key.key - elif isinstance(key, tree_util.GetAttrKey): - return key.name - return key - - -def jax_path_to_key(path): - return tuple(map(jax_key_to_val, path)) - - -def test_nested_getitem(): - obj = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": [4, 5]})} - - assert utils.nested_get(obj, ("b",)) == 2 - assert utils.nested_get(obj, ("c", "a")) == 1 - assert utils.nested_get(obj, ("d", 0)) == 3 - assert utils.nested_get(obj, ("d", 1, "e", 1)) == 5 - - -def test_nested_setitem(): - obj = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": [4, 5]})} - - utils.nested_set(obj, ("b",), 3) - assert utils.nested_get(obj, ("b",)) - - utils.nested_set(obj, ("c", "a"), 2) - assert utils.nested_get(obj, ("c", "a")) == 2 - - utils.nested_set(obj, ("d", 1, "e", 1), 6) - assert utils.nested_get(obj, ("d", 1, "e", 1)) == 6 - - with pytest.raises(TypeError): - utils.nested_set(obj, ("d", 0), 4) - - -def test_nested_order(): - obj1 = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": [4, 5]})} - obj2 = {"a": 1, "b": 2, "d": (3, {"e": [4, 5]}), "c": {"b": 2, "a": 1}} - - assert all(v1 == v2 for (v1, v2) in utils.nested_zip(obj1, obj2)) - assert all(utils.nested_filter(lambda x, y: x == y, obj1, obj2)) - assert list(utils.nested_zip(obj1)) == list(utils.nested_zip(obj2)) - assert list(utils.nested_zip(obj1, obj2)) == list(jax_nested_zip(obj1, obj2)) - - obj1 = {"b": 2, "a": 1, "c": {"b": 2, "a": 1}, "d": (3, {"e": NTuple(4, 5)})} - obj2 = {"a": 1, "b": 2, "d": (3, {"e": NTuple(4, 5)}), "c": {"b": 2, "a": 1}} - - assert all(v1 == v2 for (v1, v2) in utils.nested_zip(obj1, obj2)) - assert all(utils.nested_filter(lambda x, y: x == y, obj1, obj2)) - assert list(utils.nested_zip(obj1)) == list(utils.nested_zip(obj2)) - assert list(utils.nested_zip(obj1, obj2)) == list(jax_nested_zip(obj1, obj2)) - - -def test_nested_items(): - obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} - - for (k1, v1), (p2, v2) in zip( - utils.nested_items(obj1), tree_util.tree_flatten_with_path(obj1)[0] - ): - assert k1 == jax_path_to_key(p2) - assert v1 == v2 - - -def test_nested_filter(): - obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} - assert list(utils.nested_filter(lambda x: x % 2 == 0, obj1)) == [(2,), (4,), (2,)] - - obj1 = {"b": 2, "a": 1, "c": (3, {"e": [4, 5]}), "d": {"b": 2, "a": 1}} - assert list(utils.nested_filter(lambda x: x % 2 == 0, obj1)) == [(2,), (4,), (2,)] - - -def test_nested_map(): - obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} - obj2 = {"a": 2, "b": 4, "c": (6, {"e": [8, 10]}), "d": {"a": 2, "b": 4}} - obj12 = utils.nested_map(lambda x: x * 2, obj1) - assert obj12 == obj2 - - obj3 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": (4, 5)})} - obj4 = {"a": 2, "b": 4, "c": (6, {"e": (8, 10)}), "d": {"a": 2, "b": 4}} - obj32 = utils.nested_map(lambda x: x * 2, obj3) - assert obj32 == obj4 - - obj5 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": NTuple(4, 5)})} - obj6 = {"a": 2, "b": 4, "c": (6, {"e": NTuple(8, 10)}), "d": {"a": 2, "b": 4}} - obj52 = utils.nested_map(lambda x: x * 2, obj5) - assert obj52 == obj6 == obj4 - - assert obj32 != obj2 - assert obj52 != obj2 - - assert all( - utils.nested_iter( - utils.nested_map(lambda a, b, c: a == b == c, obj1, obj3, obj5) - ) - ) - assert all( - utils.nested_iter( - utils.nested_map(lambda a, b, c: a == b == c, obj2, obj4, obj6) - ) - ) - assert all(map(lambda x: x[0] == x[1] == x[2], utils.nested_zip(obj1, obj3, obj5))) - assert all( - map(lambda x: x[0] == x[1] == x[2], utils.nested_zip(obj2, obj32, obj52)) - ) - - -def test_nested_update(): - assert utils.nested_update([1, (2, 3), [3, 2, {1, 2}]], {2: "a"}) == [ - 1, - ("a", 3), - [3, "a", {1, "a"}], - ] - assert utils.nested_update([1, NTuple(2, 3), [3, 2, {1, 2}]], {2: "a"}) == [ - 1, - ("a", 3), - [3, "a", {1, "a"}], - ] - assert isinstance( - utils.nested_update([1, NTuple(2, 3), [3, 2, {1, 2}]], {2: "a"})[1], NTuple - ) - assert utils.nested_update([{2: 2}], {2: "a"}) == [{2: "a"}] - - -def test_nested_items(): - obj1 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": [4, 5]})} - obj2 = {"a": 2, "b": 4, "c": (6, {"e": [8, 10]}), "d": {"a": 2, "b": 4}} - - obj3 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": (4, 5)})} - obj4 = {"a": 2, "b": 4, "c": (6, {"e": (8, 10)}), "d": {"a": 2, "b": 4}} - - obj5 = {"b": 2, "a": 1, "d": {"b": 2, "a": 1}, "c": (3, {"e": NTuple(4, 5)})} - obj6 = {"a": 2, "b": 4, "c": (6, {"e": NTuple(8, 10)}), "d": {"a": 2, "b": 4}} - - for path, val in utils.nested_items(obj1): - assert ( - utils.nested_getitem(obj2, path) - == utils.nested_getitem(obj4, path) - == val * 2 - ) - - for path, val in utils.nested_items(obj3): - assert ( - utils.nested_getitem(obj4, path) - == utils.nested_getitem(obj6, path) - == val * 2 - ) - - for path, val in utils.nested_items(obj5): - assert ( - utils.nested_getitem(obj6, path) - == utils.nested_getitem(obj2, path) - == val * 2 - ) - - assert list(utils.nested_items([NTuple(1, 2), {2: 5, 1: 3}])) == [ - ((0, 0), 1), - ((0, 1), 2), - ((1, 1), 3), - ((1, 2), 5), - ] - - assert list(utils.nested_items([1, (2, 3), [3, {"a": 1, "b": 2}]])) == list( - utils.nested_items([1, (2, 3), [3, {"b": 2, "a": 1}]]) - ) - assert list( - utils.nested_items( - [ - 1, - (2, 3), - [ - 3, - { - "b": 2, - "a": 1, - }, - ], - ] - ) - ) == list(utils.nested_items([1, (2, 3), [3, {"b": 2, "a": 1}]])) - - obj1 = [ - 1, - (2, 3), - [ - 3, - { - "b": 2, - "a": 1, - }, - ], - ] - obj2 = [ - 1, - (2, 3), - [ - 3, - { - "a": 1, - "b": 2, - }, - ], - ] - obj3 = [ - 1, - NTuple(2, 3), - [ - 3, - { - "a": 1, - "b": 2, - }, - ], - ] - - if hasattr(tree_util, "tree_flatten_with_path"): - jax_flat = tree_util.tree_flatten_with_path(obj1)[0] - af_flat = utils.nested_items(obj2) - - for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat): - jkey = jax_path_to_key(jpath) - assert jkey == akey - assert jval == aval - assert ( - utils.nested_get(obj2, jkey) - == utils.nested_get(obj1, jkey) - == utils.nested_get(obj2, akey) - == utils.nested_get(obj1, akey) - ) - - jax_flat = tree_util.tree_flatten_with_path(obj2)[0] - af_flat = utils.nested_items(obj3) - for (jpath, jval), (akey, aval) in zip(jax_flat, af_flat): - jkey = jax_path_to_key(jpath) - assert jkey == akey - assert jval == aval - assert ( - utils.nested_get(obj2, jkey) - == utils.nested_get(obj1, jkey) - == utils.nested_get(obj2, akey) - == utils.nested_get(obj1, akey) - == utils.nested_get(obj3, jkey) - == utils.nested_get(obj3, akey) - ) diff --git a/test_autofit/jax/__init__.py b/test_autofit/jax/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/test_autofit/jax/test_enable_pytrees.py b/test_autofit/jax/test_enable_pytrees.py deleted file mode 100644 index 076f2dd3d..000000000 --- a/test_autofit/jax/test_enable_pytrees.py +++ /dev/null @@ -1,187 +0,0 @@ -"""Tests for ``autofit.jax.enable_pytrees`` / ``register_model``.""" -import pytest - -jax = pytest.importorskip("jax") -jnp = pytest.importorskip("jax.numpy") - -import autofit as af -from autofit.jax import enable_pytrees, register_model - - -@pytest.fixture(name="model") -def make_model(): - return af.Model( - af.ex.Gaussian, - centre=af.GaussianPrior(mean=1.0, sigma=1.0), - normalization=af.GaussianPrior(mean=2.0, sigma=1.0), - sigma=af.GaussianPrior(mean=3.0, sigma=1.0), - ) - - -def test_enable_pytrees_returns_true_when_jax_available(): - assert enable_pytrees() is True - - -def test_register_model_round_trip(model): - register_model(model) - instance = model.instance_from_prior_medians() - - leaves, treedef = jax.tree_util.tree_flatten(instance) - rebuilt = jax.tree_util.tree_unflatten(treedef, leaves) - - assert isinstance(rebuilt, af.ex.Gaussian) - assert float(rebuilt.centre) == float(instance.centre) - assert float(rebuilt.normalization) == float(instance.normalization) - assert float(rebuilt.sigma) == float(instance.sigma) - - -def test_register_model_works_under_jit(model): - register_model(model) - instance = model.instance_from_prior_medians() - - @jax.jit - def total(inst): - return inst.centre + inst.normalization + inst.sigma - - assert float(total(instance)) == pytest.approx( - instance.centre + instance.normalization + instance.sigma - ) - - -def test_register_model_keeps_constants_static(): - """Constants from the model definition must NOT be traced under JIT. - - Galaxy.redshift is a common case: it sits in the constructor signature - but gets a fixed value via ``af.Constant``. If it became a JAX tracer, - code paths like ``sorted(galaxies, key=lambda g: g.redshift)`` would - blow up with ``TracerBoolConversionError``. The ``aux_data`` partition - in ``register_model`` keeps it as a Python float. - """ - class Holder: - def __init__(self, redshift, scale): - self.redshift = redshift - self.scale = scale - - model = af.Model( - Holder, - redshift=0.5, - scale=af.GaussianPrior(mean=1.0, sigma=1.0), - ) - register_model(model) - instance = model.instance_from_prior_medians() - assert isinstance(instance.redshift, float) - - @jax.jit - def use_redshift_for_control_flow(inst): - # `if` on a traced value would raise TracerBoolConversionError; - # this only works if redshift is kept static via aux_data. - if inst.redshift > 0: - return inst.scale * 2 - return inst.scale - - result = use_redshift_for_control_flow(instance) - assert float(result) == pytest.approx(2.0 * instance.scale) - - -def test_register_model_keeps_kwarg_constants_static(): - """Constant ``**kwargs`` attributes must stay in aux_data, not children. - - ``Galaxy.__init__(self, redshift, **kwargs)`` stores every kwarg via - ``setattr``. A concrete object passed as a kwarg (e.g. a ``Pixelization``) - is an instance attribute but NOT a constructor argument, so the old - flatten logic routed it to ``children`` and it became a JAX tracer. - Downstream ``isinstance(x, Pixelization)`` checks then returned False. - This test exercises the exact pattern. - """ - class Marker: - pass - - class KwargHolder: - def __init__(self, redshift, **kwargs): - self.redshift = redshift - for k, v in kwargs.items(): - setattr(self, k, v) - - marker = Marker() - model = af.Model( - KwargHolder, - redshift=0.5, - marker=marker, - scale=af.GaussianPrior(mean=1.0, sigma=1.0), - ) - register_model(model) - instance = model.instance_from_prior_medians() - assert instance.marker is marker - - @jax.jit - def use_marker_isinstance(inst): - # isinstance on a tracer would return False; this only works if - # `marker` is kept concrete via aux_data. - if isinstance(inst.marker, Marker): - return inst.scale * 2 - return inst.scale - - result = use_marker_isinstance(instance) - assert float(result) == pytest.approx(2.0 * instance.scale) - - -def test_register_model_traces_tuple_prior_attributes(): - """``TuplePrior``-backed attributes must be routed into JAX children so - gradients flow through paired priors like ``centre=(x, y)`` and - ``ell_comps=(e1, e2)``. - - Mirrors real-world MGE / Isothermal / ExternalShear usage where the - paired priors are the majority of the free parameters. Prior to the - fix, ``TuplePrior`` failed the ``(Prior, AbstractPriorModel)`` - isinstance check in ``register_model``, so the resolved tuple was - frozen in ``aux_data`` and ``jax.value_and_grad`` returned gradients - only for the non-tuple attributes. - """ - class Twin: - def __init__(self, centre, amplitude): - self.centre = centre - self.amplitude = amplitude - - model = af.Model( - Twin, - centre=af.TuplePrior( - centre_0=af.GaussianPrior(mean=0.5, sigma=1.0), - centre_1=af.GaussianPrior(mean=-0.5, sigma=1.0), - ), - amplitude=af.GaussianPrior(mean=1.0, sigma=1.0), - ) - register_model(model) - instance = model.instance_from_prior_medians() - params_tree = jax.tree_util.tree_map(jnp.asarray, instance) - - leaves = jax.tree_util.tree_leaves(params_tree) - assert len(leaves) == 3 # centre[0], centre[1], amplitude - - def loss(inst): - cx, cy = inst.centre - return cx * cx + cy * cy + inst.amplitude - - _, grad = jax.value_and_grad(loss)(params_tree) - flat_grad = jnp.concatenate( - [jnp.asarray(l).ravel() for l in jax.tree_util.tree_leaves(grad)] - ) - assert jnp.all(jnp.isfinite(flat_grad)) - assert flat_grad.size == 3 - - -def test_enable_pytrees_idempotent(): - assert enable_pytrees() is True - assert enable_pytrees() is True - - -def test_collection_round_trip(model): - register_model(model) - collection = af.Collection(g1=model, g2=model) - register_model(collection) - - instance = collection.instance_from_prior_medians() - leaves, treedef = jax.tree_util.tree_flatten(instance) - rebuilt = jax.tree_util.tree_unflatten(treedef, leaves) - - assert float(rebuilt.g1.centre) == float(instance.g1.centre) - assert float(rebuilt.g2.sigma) == float(instance.g2.sigma) diff --git a/test_autofit/jax/test_jit.py b/test_autofit/jax/test_jit.py deleted file mode 100644 index 9059ca7bb..000000000 --- a/test_autofit/jax/test_jit.py +++ /dev/null @@ -1,62 +0,0 @@ -import pickle - -import autofit as af - -from test_autofit.graphical.gaussian.model import Analysis, Gaussian, make_data -from test_autofit.graphical.gaussian import model as model_module - -import pytest - -# jax = pytest.importorskip("jax") -# -# -# -# @pytest.fixture(autouse=True, name="model") -# def make_model(): -# return af.Model(Gaussian) -# -# -# @pytest.fixture(name="analysis") -# def make_analysis(): -# import jax.numpy as jnp -# x = jnp.arange(100) -# y = make_data(Gaussian(centre=50.0, normalization=25.0, sigma=10.0), x) -# return Analysis(x, y) - - -# @pytest.fixture(name="instance") -# def make_instance(): -# return Gaussian() -# -# -# def test_jit_likelihood(analysis, instance): -# -# import jax -# -# instance = Gaussian() -# -# jitted = jax.jit(analysis.log_likelihood_function) -# -# assert jitted(instance) == analysis.log_likelihood_function(instance) - - -# def test_jit_dynesty_static( -# analysis, -# model, -# monkeypatch, -# ): -# monkeypatch.setattr( -# jax_wrapper, -# "use_jax", -# True, -# ) -# search = af.DynestyStatic( -# use_gradient=True, -# number_of_cores=1, -# maxcall=1, -# ) -# -# print(search.fit(model=model, analysis=analysis)) -# -# loaded = pickle.loads(pickle.dumps(search)) -# assert isinstance(loaded, af.DynestyStatic) diff --git a/test_autofit/jax/test_pytrees.py b/test_autofit/jax/test_pytrees.py deleted file mode 100644 index 89ca1dd36..000000000 --- a/test_autofit/jax/test_pytrees.py +++ /dev/null @@ -1,152 +0,0 @@ -import numpy as np -import pytest - -# importorskip MUST run before any jax import — skips the whole module -# cleanly when jax isn't installed (Python <3.11 in our [jax] extra gate). -jax = pytest.importorskip("jax") -jnp = pytest.importorskip("jax.numpy") -from jax.tree_util import register_pytree_node_class - -import autofit as af -from autofit import UniformPrior - -UniformPrior = register_pytree_node_class(UniformPrior) -GaussianPrior = register_pytree_node_class(af.GaussianPrior) -TruncatedGaussianPrior = register_pytree_node_class(af.TruncatedGaussianPrior) -Collection = register_pytree_node_class(af.Collection) -Model = register_pytree_node_class(af.Model) -ModelInstance = register_pytree_node_class(af.ModelInstance) - - -@pytest.fixture(name="recreate") -def make_recreate(): - """jax-pytree-roundtrip fixture, scoped to this test module since it's - the only consumer and it must not pollute the numpy-only top-level - conftest.""" - - def _recreate(o): - flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)] - children, aux_data = flatten_func(o) - return unflatten_func(aux_data, children) - - return _recreate - -@pytest.fixture(name="gaussian") -def make_gaussian(): - return af.ex.Gaussian(centre=1.0, sigma=1.0, normalization=1.0) - - -@pytest.fixture(autouse=True) -def patch_np(monkeypatch): - monkeypatch.setattr(af.example.model, "np", jnp) - - -def classic(gaussian, size=1000): - return list(map(gaussian.f, np.arange(size))) - - -def vmapped(gaussian, size=1000): - f = jax.vmap(gaussian.f) - return list(f(np.arange(size))) - - -def test_gaussian_prior(recreate): - - prior = TruncatedGaussianPrior(mean=1.0, sigma=1.0) - - new = recreate(prior) - - assert new.mean == prior.mean - assert new.sigma == prior.sigma - assert new.id == prior.id - - assert new.lower_limit == prior.lower_limit - assert new.upper_limit == prior.upper_limit - - -@pytest.fixture(name="model") -def _model(): - return Model( - af.ex.Gaussian, - centre=af.GaussianPrior(mean=1.0, sigma=1.0), - normalization=af.GaussianPrior(mean=1.0, sigma=1.0), - sigma=af.GaussianPrior(mean=1.0, sigma=1.0), - ) - - -def test_model(model, recreate): - new = recreate(model) - assert new.cls == af.ex.Gaussian - - centre = new.centre - assert centre.mean == model.centre.mean - assert centre.sigma == model.centre.sigma - assert centre.id == model.centre.id - - -# def test_instance(model, recreate): -# instance = model.instance_from_prior_medians() -# new = recreate(instance) -# -# assert isinstance(new, af.ex.Gaussian) -# -# assert new.centre == instance.centre -# assert new.normalization == instance.normalization -# assert new.sigma == instance.sigma - - -def test_uniform_prior(recreate): - prior = af.UniformPrior(lower_limit=0.0, upper_limit=1.0) - - new = recreate(prior) - - assert new.lower_limit == prior.lower_limit - assert new.upper_limit == prior.upper_limit - assert new.id == prior.id - - -def test_model_instance(model, recreate): - collection = Collection(gaussian=model) - instance = collection.instance_from_prior_medians() - new = recreate(instance) - - assert isinstance(new, ModelInstance) - assert isinstance(new.gaussian, af.ex.Gaussian) - - -def test_collection(model, recreate): - collection = Collection(gaussian=model) - new = recreate(collection) - - assert isinstance(new, Collection) - assert isinstance(new.gaussian, Model) - - assert new.gaussian.cls == af.ex.Gaussian - - centre = new.gaussian.centre - assert centre.mean == model.centre.mean - assert centre.sigma == model.centre.sigma - assert centre.id == model.centre.id - - -class KwargClass: - """ - @DynamicAttrs - """ - - def __init__(self, **kwargs): - self.__dict__.update(kwargs) - - -# def test_kwargs(recreate): -# -# model = Model(KwargClass, a=1, b=2) -# instance = model.instance_from_prior_medians() -# -# assert instance.a == 1 -# assert instance.b == 2 -# -# new = recreate(instance) -# -# assert new.a == instance.a -# assert new.b == instance.b