diff --git a/scripts/jax_assertions/fitness_dispatch.py b/scripts/jax_assertions/fitness_dispatch.py index a0a1a48..1285a8b 100644 --- a/scripts/jax_assertions/fitness_dispatch.py +++ b/scripts/jax_assertions/fitness_dispatch.py @@ -91,6 +91,42 @@ def assert_fit_for_visualization_dispatches_through_jit_when_flag_set(): assert jnp.allclose(result_2, jnp.asarray(6.0)) +def assert_use_jax_true_implicitly_turns_on_visualization(): + """Sentinel default: ``Analysis(use_jax=True)`` with no explicit + ``use_jax_for_visualization`` argument resolves to the JIT visualization + path. Validates that PyAutoFit's ``Analysis.__init__`` treats + ``use_jax_for_visualization=None`` as "follow ``use_jax``".""" + analysis = af.Analysis(use_jax=True) + assert analysis._use_jax is True + assert analysis._use_jax_for_visualization is True + + +def assert_explicit_false_opts_out_when_use_jax_true(): + """Users can force the eager NumPy plotter alongside JAX likelihoods by + passing ``use_jax_for_visualization=False`` explicitly.""" + analysis = af.Analysis(use_jax=True, use_jax_for_visualization=False) + assert analysis._use_jax is True + assert analysis._use_jax_for_visualization is False + + +def assert_explicit_none_resolves_to_use_jax(): + """Passing ``None`` explicitly is identical to omitting the argument.""" + analysis = af.Analysis(use_jax=True, use_jax_for_visualization=None) + assert analysis._use_jax_for_visualization is True + + +def assert_use_jax_true_jit_dispatch_via_sentinel_default(): + """End-to-end: ``Analysis(use_jax=True)`` with sentinel default still wires + the JIT dispatch in ``fit_for_visualization`` — the previous assertion at + the top of this section did the equivalent with the flag set explicitly; + this one confirms the same behaviour when relying on the new default.""" + analysis = _JitFittableAnalysis(use_jax=True) + assert getattr(analysis, "_jitted_fit_from", None) is None + result = analysis.fit_for_visualization(instance=2.0) + assert analysis._jitted_fit_from is not None + assert jnp.allclose(result, jnp.asarray(4.0)) + + class _ArrayAnalysis(af.Analysis): def log_likelihood_function(self, instance): return -float( @@ -140,5 +176,9 @@ def assert_array_optimisation_returns_jnp_instance(): assert_vmap_takes_precedence_over_jit() assert_pickle_strips_jax_cached_attrs() assert_fit_for_visualization_dispatches_through_jit_when_flag_set() + assert_use_jax_true_implicitly_turns_on_visualization() + assert_explicit_false_opts_out_when_use_jax_true() + assert_explicit_none_resolves_to_use_jax() + assert_use_jax_true_jit_dispatch_via_sentinel_default() assert_array_optimisation_returns_jnp_instance() print("fitness_dispatch: all assertions passed")