Skip to content

Commit 2ddfa62

Browse files
authored
Merge pull request #1280 from PyAutoLabs/feature/jax-viz-default-broken
revert: default use_jax_for_visualization to False (reverts #1278)
2 parents 035f93a + 0b895a6 commit 2ddfa62

2 files changed

Lines changed: 8 additions & 48 deletions

File tree

autofit/non_linear/analysis/analysis.py

Lines changed: 8 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -36,23 +36,9 @@ class Analysis(ABC):
3636
def __init__(
3737
self,
3838
use_jax: bool = False,
39-
use_jax_for_visualization: Optional[bool] = None,
39+
use_jax_for_visualization: bool = False,
4040
**kwargs,
4141
):
42-
"""
43-
Parameters
44-
----------
45-
use_jax
46-
Run the likelihood through ``jax.jit`` for the fast path. When JAX
47-
is unavailable this silently falls back to numpy with a warning.
48-
use_jax_for_visualization
49-
Whether ``fit_for_visualization`` should dispatch through the
50-
``jax.jit``-cached path. ``None`` (default) follows ``use_jax`` —
51-
users who set ``use_jax=True`` automatically get JIT visualization.
52-
Pass ``False`` to force the eager NumPy plotter even when
53-
``use_jax=True``; pass ``True`` to opt in explicitly. Passing
54-
``True`` while ``use_jax=False`` logs a warning and disables it.
55-
"""
5642
import os
5743
if os.environ.get("PYAUTO_DISABLE_JAX") == "1":
5844
use_jax = False
@@ -82,9 +68,6 @@ def __init__(
8268
use_jax = False
8369
use_jax_for_visualization = False
8470

85-
if use_jax_for_visualization is None:
86-
use_jax_for_visualization = use_jax
87-
8871
if use_jax_for_visualization and not use_jax:
8972
logger.warning(
9073
"use_jax_for_visualization=True requires use_jax=True; "
@@ -100,40 +83,28 @@ def fit_for_visualization(self, instance):
10083
"""
10184
Build the fit used by the visualizer.
10285
103-
Dispatch over ``self.fit_from`` with a ``jax.jit`` fast path that
104-
follows ``use_jax`` by default:
86+
Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path:
10587
106-
* ``self._use_jax_for_visualization`` is ``False`` — plain
107-
``self.fit_from(instance)``. Untouched by JAX. This is the
108-
resolved state when ``use_jax=False`` (the parameter default),
109-
or when the user explicitly passed
110-
``use_jax_for_visualization=False`` to opt out.
111-
* ``self._use_jax_for_visualization`` is ``True`` — lazily construct
88+
* ``use_jax_for_visualization=False`` (default) — plain
89+
``self.fit_from(instance)``. Untouched by JAX.
90+
* ``use_jax_for_visualization=True`` — lazily construct
11291
``jax.jit(self.fit_from)`` on the first call and cache it on the
11392
instance as ``_jitted_fit_from``, then call that for every
11493
subsequent visualization. The first call pays the compile cost;
115-
subsequent calls reuse the cached compiled function. This is the
116-
resolved state when ``use_jax=True`` (the sentinel default
117-
``use_jax_for_visualization=None`` follows ``use_jax``).
94+
subsequent calls reuse the cached compiled function.
11895
11996
Caching is per-``Analysis`` instance so each analysis gets its own
12097
compiled function keyed off that instance's closed-over state
12198
(``self.dataset``, ``self.settings``, etc. — these ride as pytree
12299
aux data via ``register_instance_pytree(FitImaging, no_flatten=...)``
123100
in PyAutoLens).
124101
125-
``fit_from`` is defined by Analysis subclasses (e.g. ``AnalysisImaging``),
126-
not the base class — this method is only callable on subclasses that
127-
provide it. Downstream visualizers should prefer this over calling
128-
``fit_from`` directly so the JIT seam stays in one place.
129-
130102
For the JIT path to succeed, the ``Fit*`` return type (and every
131103
nested autoarray / galaxy / lens type it carries) must be pytree-
132104
registered. That wiring lives in each analysis subclass (see
133105
``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens).
134-
Variants that have not yet been pytree-audited must pass
135-
``use_jax_for_visualization=False`` explicitly when constructing
136-
the analysis (or simply leave ``use_jax=False``).
106+
Variants that have not yet been pytree-audited must leave
107+
``use_jax_for_visualization`` at its default of ``False``.
137108
"""
138109
if not self._use_jax_for_visualization:
139110
return self.fit_from(instance=instance)

test_autofit/analysis/test_use_jax_for_visualization.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,6 @@ def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
6767
assert analysis._use_jax_for_visualization is False
6868

6969

70-
def test_pyauto_disable_jax_overrides_sentinel_default(monkeypatch):
71-
"""PYAUTO_DISABLE_JAX=1 must still force both off even when the user
72-
constructs Analysis(use_jax=True) and lets the sentinel resolve. This is
73-
a numpy-only check — JAX-conditional sentinel-resolution assertions live
74-
in autofit_workspace_test/scripts/jax_assertions/fitness_dispatch.py."""
75-
monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1")
76-
analysis = af.Analysis(use_jax=True)
77-
assert analysis._use_jax is False
78-
assert analysis._use_jax_for_visualization is False
79-
80-
8170
def test_fit_for_visualization_works_without_flag():
8271
analysis = _FittableAnalysis()
8372
result = analysis.fit_for_visualization(instance="sentinel")

0 commit comments

Comments
 (0)