Skip to content

Commit 9f74b31

Browse files
committed
Enable JAX 64-bit mode for tests
Add jax.config.update("jax_enable_x64", True) at module level in conftest.py so all tests run with float64 precision. This fixes the pre-existing failure in test__curvature_matrix_via_psf_weighted_noise_two_methods_agree where float32 rounding produced a max absolute error of ~0.008, exceeding the 1e-4 tolerance. https://claude.ai/code/session_01B9sVEV54XWCa2LJw1C8gvv
1 parent e99a05c commit 9f74b31

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

test_autoarray/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import jax
12
import jax.numpy as jnp
23

4+
jax.config.update("jax_enable_x64", True)
5+
36

47
def pytest_configure():
58
_ = jnp.sum(jnp.array([0.0])) # Force backend init

0 commit comments

Comments
 (0)