Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions scripts/jax_assertions/fitness_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,42 @@ def assert_fit_for_visualization_dispatches_through_jit_when_flag_set():
assert jnp.allclose(result_2, jnp.asarray(6.0))


def assert_use_jax_true_implicitly_turns_on_visualization():
"""Sentinel default: ``Analysis(use_jax=True)`` with no explicit
``use_jax_for_visualization`` argument resolves to the JIT visualization
path. Validates that PyAutoFit's ``Analysis.__init__`` treats
``use_jax_for_visualization=None`` as "follow ``use_jax``"."""
analysis = af.Analysis(use_jax=True)
assert analysis._use_jax is True
assert analysis._use_jax_for_visualization is True


def assert_explicit_false_opts_out_when_use_jax_true():
"""Users can force the eager NumPy plotter alongside JAX likelihoods by
passing ``use_jax_for_visualization=False`` explicitly."""
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=False)
assert analysis._use_jax is True
assert analysis._use_jax_for_visualization is False


def assert_explicit_none_resolves_to_use_jax():
"""Passing ``None`` explicitly is identical to omitting the argument."""
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=None)
assert analysis._use_jax_for_visualization is True


def assert_use_jax_true_jit_dispatch_via_sentinel_default():
"""End-to-end: ``Analysis(use_jax=True)`` with sentinel default still wires
the JIT dispatch in ``fit_for_visualization`` — the previous assertion at
the top of this section did the equivalent with the flag set explicitly;
this one confirms the same behaviour when relying on the new default."""
analysis = _JitFittableAnalysis(use_jax=True)
assert getattr(analysis, "_jitted_fit_from", None) is None
result = analysis.fit_for_visualization(instance=2.0)
assert analysis._jitted_fit_from is not None
assert jnp.allclose(result, jnp.asarray(4.0))


class _ArrayAnalysis(af.Analysis):
def log_likelihood_function(self, instance):
return -float(
Expand Down Expand Up @@ -140,5 +176,9 @@ def assert_array_optimisation_returns_jnp_instance():
assert_vmap_takes_precedence_over_jit()
assert_pickle_strips_jax_cached_attrs()
assert_fit_for_visualization_dispatches_through_jit_when_flag_set()
assert_use_jax_true_implicitly_turns_on_visualization()
assert_explicit_false_opts_out_when_use_jax_true()
assert_explicit_none_resolves_to_use_jax()
assert_use_jax_true_jit_dispatch_via_sentinel_default()
assert_array_optimisation_returns_jnp_instance()
print("fitness_dispatch: all assertions passed")
Loading