Skip to content

Commit 568c3eb

Browse files
Jammy2211claude
authored andcommitted
feat(analysis): fall back to numpy when use_jax=True but jax not installed
Workspace example scripts pass use_jax=True by default, which crashes on Python 3.9/3.10 (where JAX isn't installed under the [jax] extra gate). Detect missing jax via importlib.util.find_spec, emit a loud banner-style UserWarning, and silently downgrade to use_jax=False. The fit still runs (just slower). Resolves Category B from PyAutoBuild#74 (matrix run smoke fails on 3.9/3.10 across autogalaxy_workspace and autolens_workspace). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9512e44 commit 568c3eb

1 file changed

Lines changed: 24 additions & 0 deletions

File tree

autofit/non_linear/analysis/analysis.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,30 @@ def __init__(
4444
use_jax = False
4545
use_jax_for_visualization = False
4646

47+
# If the user requested JAX but it isn't installed (e.g. Python <3.11
48+
# without the [jax] extra), fall back to numpy with a loud warning
49+
# rather than crashing later when the analysis tries to jit-compile.
50+
if use_jax:
51+
import importlib.util
52+
import warnings
53+
if importlib.util.find_spec("jax") is None:
54+
warnings.warn(
55+
"\n"
56+
"+----------------------------------------------------------------------+\n"
57+
"| use_jax=True was requested but JAX is not installed. |\n"
58+
"| |\n"
59+
"| Falling back to numpy. The fit will run, but JAX acceleration |\n"
60+
"| (typically 10-100x for large lens models) is unavailable. |\n"
61+
"| |\n"
62+
"| To enable JAX, install on Python 3.11+ via your library's [jax] |\n"
63+
"| extra, e.g.: pip install autolens[jax] |\n"
64+
"+----------------------------------------------------------------------+",
65+
UserWarning,
66+
stacklevel=2,
67+
)
68+
use_jax = False
69+
use_jax_for_visualization = False
70+
4771
if use_jax_for_visualization and not use_jax:
4872
logger.warning(
4973
"use_jax_for_visualization=True requires use_jax=True; "

0 commit comments

Comments
 (0)