Skip to content

Commit aa01bd8

Browse files
authored
Merge pull request #1297 from PyAutoLabs/feature/unify-jax-visualization
Remove use_jax_for_visualization; add visualization warmup
2 parents 1bfd0de + cb01204 commit aa01bd8

3 files changed

Lines changed: 53 additions & 83 deletions

File tree

autofit/non_linear/analysis/analysis.py

Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,12 @@ class Analysis(ABC):
5858
def __init__(
5959
self,
6060
use_jax: bool = False,
61-
use_jax_for_visualization: bool = False,
6261
**kwargs,
6362
):
6463
import os
6564
if os.environ.get("PYAUTO_DISABLE_JAX") == "1":
6665
use_jax = False
67-
use_jax_for_visualization = False
6866

69-
# If the user requested JAX but it isn't installed (e.g. Python <3.11
70-
# without the [jax] extra), fall back to numpy with a loud warning
71-
# rather than crashing later when the analysis tries to jit-compile.
7267
if use_jax:
7368
import importlib.util
7469
import warnings
@@ -88,55 +83,20 @@ def __init__(
8883
stacklevel=2,
8984
)
9085
use_jax = False
91-
use_jax_for_visualization = False
92-
93-
if use_jax_for_visualization and not use_jax:
94-
logger.warning(
95-
"use_jax_for_visualization=True requires use_jax=True; "
96-
"disabling use_jax_for_visualization."
97-
)
98-
use_jax_for_visualization = False
9986

10087
self._use_jax = use_jax
101-
self._use_jax_for_visualization = use_jax_for_visualization
10288
self.kwargs = kwargs
10389

10490
def fit_for_visualization(self, instance):
10591
"""
10692
Build the fit used by the visualizer.
10793
108-
Dispatch over ``self.fit_from`` with an opt-in ``jax.jit`` fast path:
109-
110-
* ``use_jax_for_visualization=False`` (default) — plain
111-
``self.fit_from(instance)``. Untouched by JAX.
112-
* ``use_jax_for_visualization=True`` — lazily construct
113-
``jax.jit(self.fit_from)`` on the first call and cache it on the
114-
instance as ``_jitted_fit_from``, then call that for every
115-
subsequent visualization. The first call pays the compile cost;
116-
subsequent calls reuse the cached compiled function.
117-
118-
Caching is per-``Analysis`` instance so each analysis gets its own
119-
compiled function keyed off that instance's closed-over state
120-
(``self.dataset``, ``self.settings``, etc. — these ride as pytree
121-
aux data via ``register_instance_pytree(FitImaging, no_flatten=...)``
122-
in PyAutoLens).
123-
124-
For the JIT path to succeed, the ``Fit*`` return type (and every
125-
nested autoarray / galaxy / lens type it carries) must be pytree-
126-
registered. That wiring lives in each analysis subclass (see
127-
``AnalysisImaging._register_fit_imaging_pytrees`` in PyAutoLens).
128-
Variants that have not yet been pytree-audited must leave
129-
``use_jax_for_visualization`` at its default of ``False``.
94+
Delegates to ``self.fit_from(instance)``. When ``use_jax=True``,
95+
the profile evaluations inside ``fit_from`` dispatch to JAX via
96+
the decorator chain. The per-function JIT caches warm up on the
97+
first call and are reused on all subsequent quick updates.
13098
"""
131-
if not self._use_jax_for_visualization:
132-
return self.fit_from(instance=instance)
133-
134-
if getattr(self, "_jitted_fit_from", None) is None:
135-
import jax
136-
137-
self._jitted_fit_from = jax.jit(self.fit_from)
138-
139-
return self._jitted_fit_from(instance)
99+
return self.fit_from(instance=instance)
140100

141101
def __getattr__(self, item: str):
142102
"""
@@ -444,15 +404,8 @@ def supports_background_update(self) -> bool:
444404

445405
@property
446406
def supports_jax_visualization(self) -> bool:
447-
"""
448-
Whether the visualizer can work directly with JAX arrays.
449-
450-
Derived from the ``use_jax_for_visualization`` flag passed at
451-
construction time. Subclasses may override to force a specific
452-
answer (e.g. an Analysis that has been audited to support JAX
453-
visualization unconditionally).
454-
"""
455-
return self._use_jax_for_visualization
407+
"""Whether the visualizer can work directly with JAX arrays."""
408+
return self._use_jax
456409

457410
def perform_quick_update(self, paths, instance):
458411
raise NotImplementedError

autofit/non_linear/fitness.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,35 @@ def __init__(
163163
if self.paths is not None:
164164
self.check_log_likelihood(fitness=self)
165165

166+
if (
167+
self.iterations_per_quick_update is not None
168+
and self._xp.__name__.startswith("jax")
169+
):
170+
self._warmup_visualization()
171+
172+
def _warmup_visualization(self):
173+
"""Pre-compile the JAX operations used by ``fit_for_visualization``.
174+
175+
The first call to ``fit_for_visualization`` triggers ~200 small
176+
per-function JAX JIT compilations (one per profile method per
177+
decorator). Running them here moves that cost to search setup
178+
so every quick update during sampling is fast.
179+
"""
180+
logger.info(
181+
"Warming up visualization (one-time JAX compilation)..."
182+
)
183+
try:
184+
instance = self.model.instance_from_prior_medians()
185+
fit = self.analysis.fit_for_visualization(instance=instance)
186+
_ = fit.model_data
187+
except Exception:
188+
logger.warning(
189+
"Visualization warm-up failed (non-fatal); "
190+
"first quick update may be slow."
191+
)
192+
else:
193+
logger.info("Visualization warm-up complete.")
194+
166195
@property
167196
def _xp(self):
168197
return self.analysis._xp
Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
"""Tests for the ``use_jax_for_visualization`` flag on ``Analysis``."""
1+
"""Tests for the visualization path on ``Analysis``.
2+
3+
The ``use_jax_for_visualization`` flag has been removed — visualization
4+
now always follows ``use_jax``. These tests verify the simplified
5+
``fit_for_visualization`` dispatch and the ``supports_jax_visualization``
6+
property.
7+
"""
28

39
import importlib.util
410

@@ -8,12 +14,10 @@
814

915

1016
def _jax_installed() -> bool:
11-
"""Check jax availability without importing it (per numpy-only rule)."""
1217
return importlib.util.find_spec("jax") is not None
1318

1419

1520
class _FittableAnalysis(af.Analysis):
16-
"""Minimal Analysis subclass with a trivial ``fit_from`` for dispatch tests."""
1721

1822
def __init__(self, **kwargs):
1923
super().__init__(**kwargs)
@@ -27,52 +31,37 @@ def fit_from(self, instance):
2731
return ("fit", instance)
2832

2933

30-
def test_default_flag_is_false():
34+
def test_default_flags():
3135
analysis = af.Analysis()
3236
assert analysis._use_jax is False
33-
assert analysis._use_jax_for_visualization is False
3437
assert analysis.supports_jax_visualization is False
3538

3639

37-
def test_flag_requires_use_jax(caplog):
38-
with caplog.at_level("WARNING"):
39-
analysis = af.Analysis(use_jax=False, use_jax_for_visualization=True)
40-
assert analysis._use_jax_for_visualization is False
41-
assert any("requires use_jax=True" in r.message for r in caplog.records)
42-
43-
44-
@pytest.mark.skipif(not _jax_installed(), reason="jax not installed; fallback path tested below")
45-
def test_flag_accepted_when_use_jax_true():
46-
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
40+
@pytest.mark.skipif(not _jax_installed(), reason="jax not installed")
41+
def test_use_jax_enables_jax_visualization():
42+
analysis = af.Analysis(use_jax=True)
4743
assert analysis._use_jax is True
48-
assert analysis._use_jax_for_visualization is True
4944
assert analysis.supports_jax_visualization is True
5045

5146

52-
@pytest.mark.skipif(_jax_installed(), reason="jax installed; happy path tested above")
53-
def test_use_jax_true_falls_back_to_numpy_when_jax_missing(recwarn):
54-
"""When jax isn't installed, use_jax=True should silently downgrade
55-
to use_jax=False after emitting a UserWarning. Affects 3.9/3.10
56-
where the [jax] extra is gated out."""
57-
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
47+
@pytest.mark.skipif(_jax_installed(), reason="jax installed")
48+
def test_use_jax_true_falls_back_when_jax_missing(recwarn):
49+
analysis = af.Analysis(use_jax=True)
5850
assert analysis._use_jax is False
59-
assert analysis._use_jax_for_visualization is False
6051
assert any("JAX is not installed" in str(w.message) for w in recwarn)
6152

6253

63-
def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
54+
def test_pyauto_disable_jax_env_var(monkeypatch):
6455
monkeypatch.setenv("PYAUTO_DISABLE_JAX", "1")
65-
analysis = af.Analysis(use_jax=True, use_jax_for_visualization=True)
56+
analysis = af.Analysis(use_jax=True)
6657
assert analysis._use_jax is False
67-
assert analysis._use_jax_for_visualization is False
6858

6959

70-
def test_fit_for_visualization_works_without_flag():
60+
def test_fit_for_visualization_delegates_to_fit_from():
7161
analysis = _FittableAnalysis()
7262
result = analysis.fit_for_visualization(instance="sentinel")
7363
assert result == ("fit", "sentinel")
7464
assert analysis.fit_from_calls == 1
75-
assert getattr(analysis, "_jitted_fit_from", None) is None
7665

7766

7867
def test_subclass_can_override_supports_jax_visualization():
@@ -82,5 +71,4 @@ def supports_jax_visualization(self):
8271
return True
8372

8473
analysis = ForcedAnalysis()
85-
assert analysis._use_jax_for_visualization is False
8674
assert analysis.supports_jax_visualization is True

0 commit comments

Comments
 (0)