You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
fix(jax): make jax usage truly optional across library
The Python version matrix CI was failing on 3.9 and 3.10 with
`ModuleNotFoundError: No module named 'jax'` raised from autofit
library code, even though the [jax] extra is gated `python_version
>= '3.11'`. Three eager / mis-polarized jax paths in autofit/ caused
this:
* `factor.py:_set_jacobians` did `import jax` at the top of the
function, but only the `jacfwd` / `jacobian` fallbacks need it.
The default `numerical_jacobian=True` path that every plain
`graph.Factor(...)` constructor hits never touches jax. Push
`import jax` down into the two branches that actually use it.
* `messages/normal.py:value_for` and `mapper/prior_model/array.py`
both routed plain Python `float` / `int` to `import jax.numpy`
(their `else` branch). After this fix the isinstance checks include
`(np.ndarray, np.float64, float, int, list)` so only genuine jax
values fall through to the jax branch.
Five unit tests in `test_autofit/` exercised jax-only behaviour
(`Fitness(use_jax_jit=True)`, `_JitFittableAnalysis` JIT dispatch,
`af.Array` returning a `jnp.ndarray`). They are moved out of the
unit suite to `autofit_workspace_test/scripts/graphical/jax_assertions.py`
in line with the policy that all library code paths must support
pure numpy.
Verified locally:
* `import autofit` succeeds in a fresh venv with no jax installed.
* `pytest test_autofit/` — 1225 passed, 1 skipped, 0 failed without
jax; 1245 passed, 0 failed with jax.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
0 commit comments