Skip to content

Commit 3ef0d46

Browse files
Jammy2211claude
authored andcommitted
test: remove jax from conftest, gate jax test files with importorskip
Per the 'library unit tests stay numpy-only' rule, conftest must not import jax at all — not even via try/except. Changes: - conftest.py: remove `import jax`, drop the recreate fixture (moved into test_pytrees.py since that's the only consumer), add a find_spec-gated collect_ignore_glob that skips test_autofit/jax/ and test_jacobians.py when jax isn't installed. - jax/test_pytrees.py: move pytest.importorskip("jax") BEFORE the jax imports (was AFTER, so module load failed on no-jax). Inline the recreate fixture here. - graphical/functionality/test_jacobians.py: add pytest.importorskip at module top so the file is self-contained even if conftest's collect_ignore is bypassed. This supersedes 1692c8e (try/except hack). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 1692c8e commit 3ef0d46

3 files changed

Lines changed: 34 additions & 19 deletions

File tree

test_autofit/conftest.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
try:
2-
import jax
3-
except ImportError:
4-
jax = None
1+
import importlib.util
52
import multiprocessing
63
import os
74
import shutil
@@ -19,23 +16,22 @@
1916
from autofit.database.model import sa
2017
from autofit.non_linear.search import abstract_search
2118

19+
# Skip JAX-only tests when jax isn't installed. find_spec checks availability
20+
# WITHOUT importing jax, so this conftest stays numpy-only per the
21+
# "library unit tests stay numpy-only" rule.
22+
collect_ignore_glob = []
23+
if importlib.util.find_spec("jax") is None:
24+
collect_ignore_glob = [
25+
"jax/*.py",
26+
"graphical/functionality/test_jacobians.py",
27+
]
28+
2229
if sys.platform == "darwin":
2330
multiprocessing.set_start_method("fork")
2431

2532
directory = Path(__file__).parent
2633

2734

28-
@pytest.fixture(name="recreate")
29-
def recreate():
30-
31-
def _recreate(o):
32-
flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)]
33-
children, aux_data = flatten_func(o)
34-
return unflatten_func(aux_data, children)
35-
36-
return _recreate
37-
38-
3935
@pytest.fixture(autouse=True)
4036
def turn_off_gc(monkeypatch):
4137
monkeypatch.setattr(abstract_search, "gc", MagicMock())

test_autofit/graphical/functionality/test_jacobians.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from itertools import combinations
2-
import jax
32
import numpy as np
43
import pytest
54

5+
# Belt-and-suspenders: the top-level conftest also ignores this file when
6+
# jax isn't installed; importorskip keeps the file self-contained.
7+
jax = pytest.importorskip("jax")
8+
69
from autofit.mapper.variable import variables
710
from autofit.graphical.factor_graphs import (
811
Factor,

test_autofit/jax/test_pytrees.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
11
import numpy as np
22
import pytest
3-
import jax.numpy as jnp
3+
4+
# importorskip MUST run before any jax import — skips the whole module
5+
# cleanly when jax isn't installed (Python <3.11 in our [jax] extra gate).
6+
jax = pytest.importorskip("jax")
7+
jnp = pytest.importorskip("jax.numpy")
48
from jax.tree_util import register_pytree_node_class
59

610
import autofit as af
711
from autofit import UniformPrior
812

9-
jax = pytest.importorskip("jax")
10-
1113
UniformPrior = register_pytree_node_class(UniformPrior)
1214
GaussianPrior = register_pytree_node_class(af.GaussianPrior)
1315
TruncatedGaussianPrior = register_pytree_node_class(af.TruncatedGaussianPrior)
1416
Collection = register_pytree_node_class(af.Collection)
1517
Model = register_pytree_node_class(af.Model)
1618
ModelInstance = register_pytree_node_class(af.ModelInstance)
1719

20+
21+
@pytest.fixture(name="recreate")
22+
def make_recreate():
23+
"""jax-pytree-roundtrip fixture, scoped to this test module since it's
24+
the only consumer and it must not pollute the numpy-only top-level
25+
conftest."""
26+
27+
def _recreate(o):
28+
flatten_func, unflatten_func = jax._src.tree_util._registry[type(o)]
29+
children, aux_data = flatten_func(o)
30+
return unflatten_func(aux_data, children)
31+
32+
return _recreate
33+
1834
@pytest.fixture(name="gaussian")
1935
def make_gaussian():
2036
return af.ex.Gaussian(centre=1.0, sigma=1.0, normalization=1.0)

0 commit comments

Comments
 (0)