Skip to content

Commit 9558988

Browse files
Jammy2211claude
authored andcommitted
fix(jax): make jax usage truly optional across library
The Python version matrix CI was failing on 3.9 and 3.10 with `ModuleNotFoundError: No module named 'jax'` raised from autofit library code, even though the [jax] extra is gated `python_version >= '3.11'`. Three eager / mis-polarized jax paths in autofit/ caused this: * `factor.py:_set_jacobians` did `import jax` at the top of the function, but only the `jacfwd` / `jacobian` fallbacks need it. The default `numerical_jacobian=True` path that every plain `graph.Factor(...)` constructor hits never touches jax. Push `import jax` down into the two branches that actually use it. * `messages/normal.py:value_for` and `mapper/prior_model/array.py` both routed plain Python `float` / `int` to `import jax.numpy` (their `else` branch). After this fix the isinstance checks include `(np.ndarray, np.float64, float, int, list)` so only genuine jax values fall through to the jax branch. Five unit tests in `test_autofit/` exercised jax-only behaviour (`Fitness(use_jax_jit=True)`, `_JitFittableAnalysis` JIT dispatch, `af.Array` returning a `jnp.ndarray`). They are moved out of the unit suite to `autofit_workspace_test/scripts/graphical/jax_assertions.py` in line with the policy that all library code paths must support pure numpy. Verified locally: * `import autofit` succeeds in a fresh venv with no jax installed. * `pytest test_autofit/` — 1225 passed, 1 skipped, 0 failed without jax; 1245 passed, 0 failed with jax. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 3ef0d46 commit 9558988

6 files changed

Lines changed: 5 additions & 115 deletions

File tree

autofit/graphical/factor_graphs/factor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,6 @@ def _set_jacobians(
284284
numerical_jacobian=True,
285285
jacfwd=True,
286286
):
287-
import jax
288-
289287
self._vjp = vjp
290288
self._jacfwd = jacfwd
291289
if vjp or factor_vjp:
@@ -302,8 +300,10 @@ def _set_jacobians(
302300
elif numerical_jacobian:
303301
self._factor_jacobian = self._numerical_factor_jacobian
304302
elif jacfwd:
303+
import jax
305304
self._jacobian = jax.jacfwd(self._factor, range(self.n_args))
306305
else:
306+
import jax
307307
self._jacobian = jax.jacobian(self._factor, range(self.n_args))
308308

309309
def _factor_value(self, raw_fval) -> FactorValue:

autofit/mapper/prior_model/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,15 +87,15 @@ def _instance_for_arguments(
8787
pass
8888

8989
if make_array:
90-
if isinstance(value, np.ndarray) or isinstance(value, np.float64):
90+
if isinstance(value, (np.ndarray, np.float64, float, int)):
9191
array = np.zeros(self.shape)
9292
make_array = False
9393
else:
9494
import jax.numpy as jnp
9595
array = jnp.zeros(self.shape)
9696
make_array = False
9797

98-
if isinstance(value, np.ndarray) or isinstance(value, np.float64):
98+
if isinstance(value, (np.ndarray, np.float64, float, int)):
9999
array[index] = value
100100
else:
101101
array = array.at[index].set(value)

autofit/messages/normal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def value_for(self, unit: float) -> float:
417417
>>> prior = af.GaussianPrior(mean=1.0, sigma=2.0)
418418
>>> physical_value = prior.value_for(unit=0.5)
419419
"""
420-
if isinstance(unit, np.ndarray) or isinstance(unit, np.float64):
420+
if isinstance(unit, (np.ndarray, np.float64, float, int, list)):
421421
from scipy.special import erfinv as scipy_erfinv
422422
inv = scipy_erfinv(1 - 2.0 * (1.0 - unit))
423423
else:

test_autofit/analysis/test_use_jax_for_visualization.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,6 @@ def fit_from(self, instance):
2020
return ("fit", instance)
2121

2222

23-
class _JitFittableAnalysis(af.Analysis):
24-
"""Analysis with a ``fit_from`` returning a JIT-traceable array.
25-
26-
``_FittableAnalysis.fit_from`` returns a Python tuple with a string literal,
27-
which is not tracer-compatible. For the JIT-enabled dispatch path we need a
28-
``fit_from`` whose output is entirely JAX-compatible.
29-
"""
30-
31-
def __init__(self, **kwargs):
32-
super().__init__(**kwargs)
33-
self.fit_from_calls = 0
34-
35-
def log_likelihood_function(self, instance):
36-
return 0.0
37-
38-
def fit_from(self, instance):
39-
import jax.numpy as jnp
40-
41-
self.fit_from_calls += 1
42-
return jnp.asarray(instance) * 2.0
43-
44-
4523
def test_default_flag_is_false():
4624
analysis = af.Analysis()
4725
assert analysis._use_jax is False
@@ -70,23 +48,6 @@ def test_pyauto_disable_jax_env_var_clears_both_flags(monkeypatch):
7048
assert analysis._use_jax_for_visualization is False
7149

7250

73-
def test_fit_for_visualization_dispatches_through_jit_when_flag_set():
74-
import jax.numpy as jnp
75-
76-
analysis = _JitFittableAnalysis(use_jax=True, use_jax_for_visualization=True)
77-
78-
assert getattr(analysis, "_jitted_fit_from", None) is None
79-
80-
result_1 = analysis.fit_for_visualization(instance=1.0)
81-
assert analysis._jitted_fit_from is not None
82-
assert jnp.allclose(result_1, jnp.asarray(2.0))
83-
84-
jitted_after_first = analysis._jitted_fit_from
85-
result_2 = analysis.fit_for_visualization(instance=3.0)
86-
assert analysis._jitted_fit_from is jitted_after_first
87-
assert jnp.allclose(result_2, jnp.asarray(6.0))
88-
89-
9051
def test_fit_for_visualization_works_without_flag():
9152
analysis = _FittableAnalysis()
9253
result = analysis.fit_for_visualization(instance="sentinel")

test_autofit/mapper/test_array.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -159,41 +159,3 @@ def test_tree_flatten(array):
159159
).all()
160160

161161

162-
class Analysis(af.Analysis):
163-
def log_likelihood_function(self, instance):
164-
return -float(
165-
np.mean(
166-
(
167-
np.array(
168-
[
169-
[0.1, 0.2],
170-
[0.3, 0.4],
171-
]
172-
)
173-
- instance
174-
)
175-
** 2
176-
)
177-
)
178-
179-
180-
def test_optimisation():
181-
182-
import jax.numpy as jnp
183-
184-
array = af.Array(
185-
shape=(2, 2),
186-
prior=af.UniformPrior(
187-
lower_limit=0.0,
188-
upper_limit=1.0,
189-
),
190-
)
191-
result = af.DynestyStatic().fit(model=array, analysis=Analysis())
192-
193-
posterior = result.model
194-
array[0, 0] = posterior[0, 0]
195-
array[0, 1] = posterior[0, 1]
196-
197-
result = af.DynestyStatic().fit(model=array, analysis=Analysis())
198-
199-
assert isinstance(result.instance, jnp.ndarray)

test_autofit/non_linear/test_fitness_jax_dispatch.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -23,39 +23,6 @@ def test_default_dispatch_is_call():
2323
assert fitness.use_jax_vmap is False
2424

2525

26-
def test_jit_dispatch_sets_call_to_jit():
27-
fitness = _make_fitness(use_jax_jit=True)
28-
assert fitness.use_jax_jit is True
29-
assert fitness._call is fitness._jit
30-
31-
32-
def test_vmap_takes_precedence_over_jit():
33-
fitness = _make_fitness(use_jax_jit=True, use_jax_vmap=True)
34-
assert fitness._call is fitness._vmap
35-
36-
37-
def test_pickle_strips_jax_cached_attrs():
38-
"""
39-
Dynesty's checkpoint writes pickle the loglikelihood. JAX-compiled
40-
callables (jax.jit / jax.vmap / jax.grad) carry C++ XLA state that
41-
cannot roundtrip through pickle. ``Fitness.__getstate__`` must drop
42-
them; ``Fitness.__setstate__`` re-derives the dispatch on resume.
43-
"""
44-
fitness = _make_fitness(use_jax_jit=True)
45-
46-
state = fitness.__getstate__()
47-
assert "_call" not in state
48-
assert "_jit" not in state
49-
assert "_vmap" not in state
50-
assert "_grad" not in state
51-
52-
blob = pickle.dumps(fitness)
53-
restored = pickle.loads(blob)
54-
55-
assert restored.use_jax_jit is True
56-
assert restored._call is restored._jit
57-
58-
5926
def test_pickle_default_path_unchanged():
6027
fitness = _make_fitness()
6128

0 commit comments

Comments
 (0)