@@ -102,10 +102,10 @@ def test__hessian_from__diagonal_grid__correct_values():
102102 od = LensCalc .from_mass_obj (mp )
103103 hessian_yy , hessian_xy , hessian_yx , hessian_xx = od .hessian_from (grid = grid )
104104
105- assert hessian_yy == pytest .approx (np .array ([1.3883822 , 0.694127 ]), 1.0e-4 )
106- assert hessian_xy == pytest .approx (np .array ([- 1.388124 , - 0.694094 ]), 1.0e-4 )
107- assert hessian_yx == pytest .approx (np .array ([- 1.388165 , - 0.694099 ]), 1.0e-4 )
108- assert hessian_xx == pytest .approx (np .array ([1.3883824 , 0.694127 ]), 1.0e-4 )
105+ assert hessian_yy == pytest .approx (np .array ([1.3882113 , 0.6941056 ]), 1.0e-4 )
106+ assert hessian_xy == pytest .approx (np .array ([- 1.3882113 , - 0.6941056 ]), 1.0e-4 )
107+ assert hessian_yx == pytest .approx (np .array ([- 1.3882113 , - 0.6941056 ]), 1.0e-4 )
108+ assert hessian_xx == pytest .approx (np .array ([1.3882113 , 0.6941056 ]), 1.0e-4 )
109109
110110
111111def test__hessian_from__axis_aligned_grid__correct_values ():
@@ -152,8 +152,51 @@ def test__magnification_2d_via_hessian_from():
152152 od = LensCalc .from_mass_obj (mp )
153153 magnification = od .magnification_2d_via_hessian_from (grid = grid )
154154
155- assert magnification .in_list [0 ] == pytest .approx (- 0.56303 , 1.0e-4 )
156- assert magnification .in_list [1 ] == pytest .approx (- 2.57591 , 1.0e-4 )
155+ assert magnification .in_list [0 ] == pytest .approx (- 0.5629291 , 1.0e-4 )
156+ assert magnification .in_list [1 ] == pytest .approx (- 2.575917 , 1.0e-4 )
157+
158+
159+ def test__hessian_from__np_richardson_matches_jax_jacfwd_to_float64 ():
160+ import jax .numpy as jnp
161+
162+ grid = ag .Grid2DIrregular (values = [(0.5 , 0.5 ), (1.0 , 1.0 ), (0.7 , - 0.3 )])
163+
164+ mp = ag .mp .Isothermal (
165+ centre = (0.0 , 0.0 ), ell_comps = (0.05 , - 0.111111 ), einstein_radius = 1.5
166+ )
167+
168+ od = LensCalc .from_mass_obj (mp )
169+
170+ np_hess = od .hessian_from (grid = grid , xp = np )
171+ jnp_hess = od .hessian_from (grid = grid , xp = jnp )
172+
173+ for np_component , jnp_component in zip (np_hess , jnp_hess ):
174+ np .testing .assert_allclose (
175+ np .asarray (np_component ),
176+ np .asarray (jnp_component ),
177+ rtol = 1.0e-8 ,
178+ )
179+
180+
181+ def test__magnification_2d_via_hessian_from__np_jnp_agree_to_float64 ():
182+ import jax .numpy as jnp
183+
184+ grid = ag .Grid2DIrregular (values = [(0.5 , 0.5 ), (1.0 , 1.0 ), (0.7 , - 0.3 )])
185+
186+ mp = ag .mp .Isothermal (
187+ centre = (0.0 , 0.0 ), ell_comps = (0.05 , - 0.111111 ), einstein_radius = 1.5
188+ )
189+
190+ od = LensCalc .from_mass_obj (mp )
191+
192+ np_mag = od .magnification_2d_via_hessian_from (grid = grid , xp = np )
193+ jnp_mag = od .magnification_2d_via_hessian_from (grid = grid , xp = jnp )
194+
195+ np .testing .assert_allclose (
196+ np .asarray (np_mag .array ),
197+ np .asarray (jnp_mag ),
198+ rtol = 1.0e-7 ,
199+ )
157200
158201
159202def test__tangential_critical_curve_list_from__radius_matches_einstein_radius ():
0 commit comments