1- """Tests for the ``use_jax_for_visualization`` flag on ``Analysis``."""
1+ """Tests for the visualization path on ``Analysis``.
2+
3+ The ``use_jax_for_visualization`` flag has been removed — visualization
4+ now always follows ``use_jax``. These tests verify the simplified
5+ ``fit_for_visualization`` dispatch and the ``supports_jax_visualization``
6+ property.
7+ """
28
39import importlib .util
410
814
915
1016def _jax_installed () -> bool :
11- """Check jax availability without importing it (per numpy-only rule)."""
1217 return importlib .util .find_spec ("jax" ) is not None
1318
1419
1520class _FittableAnalysis (af .Analysis ):
16- """Minimal Analysis subclass with a trivial ``fit_from`` for dispatch tests."""
1721
1822 def __init__ (self , ** kwargs ):
1923 super ().__init__ (** kwargs )
@@ -27,52 +31,37 @@ def fit_from(self, instance):
2731 return ("fit" , instance )
2832
2933
30- def test_default_flag_is_false ():
34+ def test_default_flags ():
3135 analysis = af .Analysis ()
3236 assert analysis ._use_jax is False
33- assert analysis ._use_jax_for_visualization is False
3437 assert analysis .supports_jax_visualization is False
3538
3639
37- def test_flag_requires_use_jax (caplog ):
38- with caplog .at_level ("WARNING" ):
39- analysis = af .Analysis (use_jax = False , use_jax_for_visualization = True )
40- assert analysis ._use_jax_for_visualization is False
41- assert any ("requires use_jax=True" in r .message for r in caplog .records )
42-
43-
44- @pytest .mark .skipif (not _jax_installed (), reason = "jax not installed; fallback path tested below" )
45- def test_flag_accepted_when_use_jax_true ():
46- analysis = af .Analysis (use_jax = True , use_jax_for_visualization = True )
40+ @pytest .mark .skipif (not _jax_installed (), reason = "jax not installed" )
41+ def test_use_jax_enables_jax_visualization ():
42+ analysis = af .Analysis (use_jax = True )
4743 assert analysis ._use_jax is True
48- assert analysis ._use_jax_for_visualization is True
4944 assert analysis .supports_jax_visualization is True
5045
5146
52- @pytest .mark .skipif (_jax_installed (), reason = "jax installed; happy path tested above" )
53- def test_use_jax_true_falls_back_to_numpy_when_jax_missing (recwarn ):
54- """When jax isn't installed, use_jax=True should silently downgrade
55- to use_jax=False after emitting a UserWarning. Affects 3.9/3.10
56- where the [jax] extra is gated out."""
57- analysis = af .Analysis (use_jax = True , use_jax_for_visualization = True )
47+ @pytest .mark .skipif (_jax_installed (), reason = "jax installed" )
48+ def test_use_jax_true_falls_back_when_jax_missing (recwarn ):
49+ analysis = af .Analysis (use_jax = True )
5850 assert analysis ._use_jax is False
59- assert analysis ._use_jax_for_visualization is False
6051 assert any ("JAX is not installed" in str (w .message ) for w in recwarn )
6152
6253
63- def test_pyauto_disable_jax_env_var_clears_both_flags (monkeypatch ):
54+ def test_pyauto_disable_jax_env_var (monkeypatch ):
6455 monkeypatch .setenv ("PYAUTO_DISABLE_JAX" , "1" )
65- analysis = af .Analysis (use_jax = True , use_jax_for_visualization = True )
56+ analysis = af .Analysis (use_jax = True )
6657 assert analysis ._use_jax is False
67- assert analysis ._use_jax_for_visualization is False
6858
6959
70- def test_fit_for_visualization_works_without_flag ():
60+ def test_fit_for_visualization_delegates_to_fit_from ():
7161 analysis = _FittableAnalysis ()
7262 result = analysis .fit_for_visualization (instance = "sentinel" )
7363 assert result == ("fit" , "sentinel" )
7464 assert analysis .fit_from_calls == 1
75- assert getattr (analysis , "_jitted_fit_from" , None ) is None
7665
7766
7867def test_subclass_can_override_supports_jax_visualization ():
@@ -82,5 +71,4 @@ def supports_jax_visualization(self):
8271 return True
8372
8473 analysis = ForcedAnalysis ()
85- assert analysis ._use_jax_for_visualization is False
8674 assert analysis .supports_jax_visualization is True
0 commit comments