@@ -36,23 +36,9 @@ class Analysis(ABC):
3636 def __init__ (
3737 self ,
3838 use_jax : bool = False ,
39- use_jax_for_visualization : Optional [ bool ] = None ,
39+ use_jax_for_visualization : bool = False ,
4040 ** kwargs ,
4141 ):
42- """
43- Parameters
44- ----------
45- use_jax
46- Run the likelihood through ``jax.jit`` for the fast path. When JAX
47- is unavailable this silently falls back to numpy with a warning.
48- use_jax_for_visualization
49- Whether ``fit_for_visualization`` should dispatch through the
50- ``jax.jit``-cached path. ``None`` (default) follows ``use_jax`` —
51- users who set ``use_jax=True`` automatically get JIT visualization.
52- Pass ``False`` to force the eager NumPy plotter even when
53- ``use_jax=True``; pass ``True`` to opt in explicitly. Passing
54- ``True`` while ``use_jax=False`` logs a warning and disables it.
55- """
5642 import os
5743 if os .environ .get ("PYAUTO_DISABLE_JAX" ) == "1" :
5844 use_jax = False
@@ -82,9 +68,6 @@ def __init__(
8268 use_jax = False
8369 use_jax_for_visualization = False
8470
85- if use_jax_for_visualization is None :
86- use_jax_for_visualization = use_jax
87-
8871 if use_jax_for_visualization and not use_jax :
8972 logger .warning (
9073 "use_jax_for_visualization=True requires use_jax=True; "
@@ -100,40 +83,28 @@ def fit_for_visualization(self, instance):
10083 """
10184 Build the fit used by the visualizer.
10285
103- Dispatch over ``self.fit_from`` with a ``jax.jit`` fast path that
104- follows ``use_jax`` by default:
86+ Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path:
10587
106- * ``self._use_jax_for_visualization`` is ``False`` — plain
107- ``self.fit_from(instance)``. Untouched by JAX. This is the
108- resolved state when ``use_jax=False`` (the parameter default),
109- or when the user explicitly passed
110- ``use_jax_for_visualization=False`` to opt out.
111- * ``self._use_jax_for_visualization`` is ``True`` — lazily construct
88+ * ``use_jax_for_visualization=False`` (default) — plain
89+ ``self.fit_from(instance)``. Untouched by JAX.
90+ * ``use_jax_for_visualization=True`` — lazily construct
11291 ``jax.jit(self.fit_from)`` on the first call and cache it on the
11392 instance as ``_jitted_fit_from``, then call that for every
11493 subsequent visualization. The first call pays the compile cost;
115- subsequent calls reuse the cached compiled function. This is the
116- resolved state when ``use_jax=True`` (the sentinel default
117- ``use_jax_for_visualization=None`` follows ``use_jax``).
94+ subsequent calls reuse the cached compiled function.
11895
11996 Caching is per-``Analysis`` instance so each analysis gets its own
12097 compiled function keyed off that instance's closed-over state
12198 (``self.dataset``, ``self.settings``, etc. — these ride as pytree
12299 aux data via ``register_instance_pytree(FitImaging, no_flatten=...)``
123100 in PyAutoLens).
124101
125- ``fit_from`` is defined by Analysis subclasses (e.g. ``AnalysisImaging``),
126- not the base class — this method is only callable on subclasses that
127- provide it. Downstream visualizers should prefer this over calling
128- ``fit_from`` directly so the JIT seam stays in one place.
129-
130102 For the JIT path to succeed, the ``Fit*`` return type (and every
131103 nested autoarray / galaxy / lens type it carries) must be pytree-
132104 registered. That wiring lives in each analysis subclass (see
133105 ``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens).
134- Variants that have not yet been pytree-audited must pass
135- ``use_jax_for_visualization=False`` explicitly when constructing
136- the analysis (or simply leave ``use_jax=False``).
106+ Variants that have not yet been pytree-audited must leave
107+ ``use_jax_for_visualization`` at its default of ``False``.
137108 """
138109 if not self ._use_jax_for_visualization :
139110 return self .fit_from (instance = instance )
0 commit comments