File tree Expand file tree Collapse file tree
test_autolens/point/model Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1+ try :
2+ import jax
3+
4+ JAX_INSTALLED = True
5+ except ImportError :
6+ JAX_INSTALLED = False
7+
18import numpy as np
29import pytest
310
@@ -14,23 +21,23 @@ def noise_map():
1421 return np .array ([1.0 , 1.0 ])
1522
1623
17- def test_andrew_implementation (
18- data ,
19- noise_map ,
20- ):
24+ @pytest .fixture
25+ def fit (data , noise_map ):
2126 model_positions = np .array (
2227 [
2328 (- 1.0749 , - 1.1 ),
2429 (1.19117 , 1.175 ),
2530 ]
2631 )
2732
28- fit = Fit (
33+ return Fit (
2934 data = data ,
3035 noise_map = noise_map ,
3136 model_positions = model_positions ,
3237 )
3338
39+
40+ def test_andrew_implementation (fit ):
3441 assert np .allclose (
3542 fit .all_permutations_log_likelihoods (),
3643 [
@@ -41,6 +48,11 @@ def test_andrew_implementation(
4148 assert fit .log_likelihood () == - 4.40375330990644
4249
4350
51+ @pytest .mark .skipif (not JAX_INSTALLED , reason = "JAX is not installed" )
52+ def test_jax (fit ):
53+ assert jax .jit (fit .log_likelihood )() == - 4.40375330990644
54+
55+
4456def test_nan_model_positions (
4557 data ,
4658 noise_map ,
You can’t perform that action at this time.
0 commit comments