Skip to content

Commit fb0dea5

Browse files
authored
Merge pull request #358 from PyAutoLabs/feature/lens-calc-hessian-richardson
fix: Richardson-extrapolate LensCalc numpy Hessian for JAX-path parity
2 parents 854e70e + eacdcd7 commit fb0dea5

2 files changed

Lines changed: 68 additions & 10 deletions

File tree

autogalaxy/operate/lens_calc.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,11 @@ def hessian_from(self, grid, xp=np) -> Tuple:
359359
360360
Two computational paths are available, selected via the `xp` parameter:
361361
362-
- **NumPy** (``xp=np``, default): finite-difference approximation. Deflection angles are
363-
evaluated at four shifted positions around each grid coordinate (±y, ±x) and the
364-
central difference is taken. JAX is not imported.
362+
- **NumPy** (``xp=np``, default): 2-point central finite-difference approximation,
363+
Richardson-extrapolated at step sizes ``h`` and ``h/2`` and combined as
364+
``(4 * H(h/2) - H(h)) / 3``. This cancels the leading ``O(h^2)`` truncation term,
365+
giving ``O(h^4)`` accuracy and matching the JAX path to float64 precision. JAX is
366+
not imported.
365367
366368
- **JAX** (``xp=jnp``): exact derivatives via ``jax.jacfwd`` applied to
367369
``deflections_yx_scalar``, vectorised over the grid with ``jnp.vectorize``.
@@ -377,9 +379,22 @@ def hessian_from(self, grid, xp=np) -> Tuple:
377379
used and the type of the returned arrays.
378380
"""
379381
if xp is np:
380-
return self._hessian_via_finite_difference(grid=grid)
382+
return self._hessian_via_richardson(grid=grid)
381383
return self._hessian_via_jax(grid=grid, xp=xp)
382384

385+
def _hessian_via_richardson(self, grid, buffer: float = 0.01) -> Tuple:
386+
yy_h, xy_h, yx_h, xx_h = self._hessian_via_finite_difference(
387+
grid=grid, buffer=buffer
388+
)
389+
yy_h2, xy_h2, yx_h2, xx_h2 = self._hessian_via_finite_difference(
390+
grid=grid, buffer=buffer / 2.0
391+
)
392+
hessian_yy = (4.0 * yy_h2 - yy_h) / 3.0
393+
hessian_xy = (4.0 * xy_h2 - xy_h) / 3.0
394+
hessian_yx = (4.0 * yx_h2 - yx_h) / 3.0
395+
hessian_xx = (4.0 * xx_h2 - xx_h) / 3.0
396+
return hessian_yy, hessian_xy, hessian_yx, hessian_xx
397+
383398
def _hessian_via_jax(self, grid, xp) -> Tuple:
384399
import jax
385400
import jax.numpy as jnp

test_autogalaxy/operate/test_deflections.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,10 @@ def test__hessian_from__diagonal_grid__correct_values():
102102
od = LensCalc.from_mass_obj(mp)
103103
hessian_yy, hessian_xy, hessian_yx, hessian_xx = od.hessian_from(grid=grid)
104104

105-
assert hessian_yy == pytest.approx(np.array([1.3883822, 0.694127]), 1.0e-4)
106-
assert hessian_xy == pytest.approx(np.array([-1.388124, -0.694094]), 1.0e-4)
107-
assert hessian_yx == pytest.approx(np.array([-1.388165, -0.694099]), 1.0e-4)
108-
assert hessian_xx == pytest.approx(np.array([1.3883824, 0.694127]), 1.0e-4)
105+
assert hessian_yy == pytest.approx(np.array([1.3882113, 0.6941056]), 1.0e-4)
106+
assert hessian_xy == pytest.approx(np.array([-1.3882113, -0.6941056]), 1.0e-4)
107+
assert hessian_yx == pytest.approx(np.array([-1.3882113, -0.6941056]), 1.0e-4)
108+
assert hessian_xx == pytest.approx(np.array([1.3882113, 0.6941056]), 1.0e-4)
109109

110110

111111
def test__hessian_from__axis_aligned_grid__correct_values():
@@ -152,8 +152,51 @@ def test__magnification_2d_via_hessian_from():
152152
od = LensCalc.from_mass_obj(mp)
153153
magnification = od.magnification_2d_via_hessian_from(grid=grid)
154154

155-
assert magnification.in_list[0] == pytest.approx(-0.56303, 1.0e-4)
156-
assert magnification.in_list[1] == pytest.approx(-2.57591, 1.0e-4)
155+
assert magnification.in_list[0] == pytest.approx(-0.5629291, 1.0e-4)
156+
assert magnification.in_list[1] == pytest.approx(-2.575917, 1.0e-4)
157+
158+
159+
def test__hessian_from__np_richardson_matches_jax_jacfwd_to_float64():
160+
import jax.numpy as jnp
161+
162+
grid = ag.Grid2DIrregular(values=[(0.5, 0.5), (1.0, 1.0), (0.7, -0.3)])
163+
164+
mp = ag.mp.Isothermal(
165+
centre=(0.0, 0.0), ell_comps=(0.05, -0.111111), einstein_radius=1.5
166+
)
167+
168+
od = LensCalc.from_mass_obj(mp)
169+
170+
np_hess = od.hessian_from(grid=grid, xp=np)
171+
jnp_hess = od.hessian_from(grid=grid, xp=jnp)
172+
173+
for np_component, jnp_component in zip(np_hess, jnp_hess):
174+
np.testing.assert_allclose(
175+
np.asarray(np_component),
176+
np.asarray(jnp_component),
177+
rtol=1.0e-8,
178+
)
179+
180+
181+
def test__magnification_2d_via_hessian_from__np_jnp_agree_to_float64():
182+
import jax.numpy as jnp
183+
184+
grid = ag.Grid2DIrregular(values=[(0.5, 0.5), (1.0, 1.0), (0.7, -0.3)])
185+
186+
mp = ag.mp.Isothermal(
187+
centre=(0.0, 0.0), ell_comps=(0.05, -0.111111), einstein_radius=1.5
188+
)
189+
190+
od = LensCalc.from_mass_obj(mp)
191+
192+
np_mag = od.magnification_2d_via_hessian_from(grid=grid, xp=np)
193+
jnp_mag = od.magnification_2d_via_hessian_from(grid=grid, xp=jnp)
194+
195+
np.testing.assert_allclose(
196+
np.asarray(np_mag.array),
197+
np.asarray(jnp_mag),
198+
rtol=1.0e-7,
199+
)
157200

158201

159202
def test__tangential_critical_curve_list_from__radius_matches_einstein_radius():

0 commit comments

Comments
 (0)