Skip to content

Commit 8aa0397

Browse files
Jammy2211Jammy2211
authored andcommitted
feat: support vmapped Kaplinghat deflections
1 parent de5782c commit 8aa0397

2 files changed

Lines changed: 183 additions & 0 deletions

File tree

autogalaxy/profiles/mass/dark/kaplinghat.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,62 @@ def _nfw_mass_3d_within_radius_from(r, kappa_s, scale_radius):
6565
return 4.0 * np.pi * (kappa_s / scale_radius) * scale_radius**3 * mass_factor
6666

6767

68+
def _trapezoid_from(y, x, axis, xp):
69+
if hasattr(xp, "trapezoid"):
70+
return xp.trapezoid(y, x=x, axis=axis)
71+
return xp.trapz(y, x=x, axis=axis)
72+
73+
74+
def _interp_lane_emden_xp(x, xp):
75+
x_table, h_table, _ = _isothermal_lane_emden_table()
76+
return xp.interp(x, xp.asarray(x_table), xp.asarray(h_table))
77+
78+
79+
def _nfw_radial_deflection_from(r, kappa_s, scale_radius, xp):
80+
x = xp.maximum(r / scale_radius, 1.0e-8)
81+
82+
below = x < 1.0
83+
above = x > 1.0
84+
x_below = xp.minimum(x, 1.0 - 1.0e-8)
85+
x_above = xp.maximum(x, 1.0 + 1.0e-8)
86+
87+
f_below = xp.arccosh(1.0 / x_below) / xp.sqrt(1.0 - x_below**2)
88+
f_above = xp.arccos(1.0 / x_above) / xp.sqrt(x_above**2 - 1.0)
89+
f_at_one = 1.0
90+
f = xp.where(below, f_below, xp.where(above, f_above, f_at_one))
91+
92+
g = xp.log(x / 2.0) + f
93+
return 4.0 * kappa_s * scale_radius * g / x
94+
95+
96+
def _kaplinghat_density_3d_from_radius(
97+
radii,
98+
kappa_s,
99+
scale_radius,
100+
interaction_radius,
101+
central_density,
102+
isothermal_radius,
103+
xp,
104+
):
105+
x_nfw = xp.maximum(radii / scale_radius, 1.0e-12)
106+
nfw_density = (kappa_s / scale_radius) / (x_nfw * (1.0 + x_nfw) ** 2)
107+
108+
safe_isothermal_radius = xp.maximum(isothermal_radius, 1.0e-12)
109+
x_iso = xp.maximum(radii / safe_isothermal_radius, 1.0e-5)
110+
h = _interp_lane_emden_xp(x_iso, xp=xp)
111+
safe_central_density = xp.where(xp.isfinite(central_density), central_density, 0.0)
112+
iso_density = safe_central_density * xp.exp(-h)
113+
114+
use_iso = (
115+
(interaction_radius > 0.0)
116+
& (isothermal_radius > 0.0)
117+
& xp.isfinite(central_density)
118+
& (radii < interaction_radius)
119+
)
120+
121+
return xp.where(use_iso, iso_density, nfw_density)
122+
123+
68124
def _matched_isothermal_parameters_from(interaction_radius, kappa_s, scale_radius):
69125
rho_1 = _nfw_density_from(
70126
r=interaction_radius,
@@ -243,6 +299,57 @@ def radial_deflection_from_radius(self, radius):
243299
)[0]
244300
return 2.0 * mass_2d / radius
245301

302+
@staticmethod
303+
def radial_deflection_from(r, params, xp):
304+
kappa_s = params[0]
305+
scale_radius = params[1]
306+
interaction_radius = params[2]
307+
central_density = params[3]
308+
isothermal_radius = params[4]
309+
310+
r = xp.asarray(r)
311+
r_safe = xp.maximum(r, 1.0e-8)
312+
313+
z_max = xp.maximum(500.0 * scale_radius, 50.0 * interaction_radius)
314+
z_unit = xp.linspace(1.0e-5, 1.0, 160)
315+
z = z_max * z_unit**3
316+
u = xp.linspace(0.0, 1.0, 64)
317+
318+
projected_radii = xp.maximum(r_safe[:, None] * u[None, :], 1.0e-6)
319+
three_d_radii = xp.sqrt(
320+
projected_radii[:, :, None] ** 2 + z[None, None, :] ** 2
321+
)
322+
323+
density = _kaplinghat_density_3d_from_radius(
324+
radii=three_d_radii,
325+
kappa_s=kappa_s,
326+
scale_radius=scale_radius,
327+
interaction_radius=interaction_radius,
328+
central_density=central_density,
329+
isothermal_radius=isothermal_radius,
330+
xp=xp,
331+
)
332+
convergence = 2.0 * _trapezoid_from(density, x=z, axis=-1, xp=xp)
333+
334+
mass_integral = r_safe**2 * _trapezoid_from(
335+
convergence * u[None, :], x=u, axis=-1, xp=xp
336+
)
337+
numerical = 2.0 * mass_integral / r_safe
338+
analytic_nfw = _nfw_radial_deflection_from(
339+
r=r_safe,
340+
kappa_s=kappa_s,
341+
scale_radius=scale_radius,
342+
xp=xp,
343+
)
344+
345+
radial_deflection = xp.where(
346+
interaction_radius > 0.0,
347+
numerical,
348+
analytic_nfw,
349+
)
350+
351+
return xp.where(r > 1.0e-8, radial_deflection, 0.0)
352+
246353
@aa.decorators.to_vector_yx
247354
@aa.decorators.transform
248355
def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, xp=np, **kwargs):

test_autogalaxy/profiles/mass/dark/test_kaplinghat.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,82 @@ def test__lensing_quantities_are_finite_and_positive():
5959
assert deflections[0] == pytest.approx(np.array([0.0, 0.0]), abs=1.0e-8)
6060

6161

62+
def test__vmapped_deflections_match_instance_path_for_zero_interaction():
63+
profile = ag.mp.KaplinghatCoredNFWSph(
64+
centre=(0.1, -0.2),
65+
kappa_s=0.2,
66+
scale_radius=2.0,
67+
interaction_radius=0.0,
68+
)
69+
grid = np.array([[0.5, 0.2], [1.0, -0.2], [2.0, 1.0]])
70+
71+
params = np.array(
72+
[
73+
[
74+
profile.centre[0],
75+
profile.centre[1],
76+
profile.kappa_s,
77+
profile.scale_radius,
78+
profile.interaction_radius,
79+
profile.central_density,
80+
profile.isothermal_radius,
81+
]
82+
]
83+
)
84+
mask = np.array([True])
85+
86+
vmapped = ag.mp.KaplinghatCoredNFWSph.vmapped_deflections_from(
87+
grid=grid,
88+
params_batch=params,
89+
mask=mask,
90+
)
91+
92+
np.testing.assert_allclose(
93+
np.asarray(vmapped),
94+
profile.deflections_yx_2d_from(grid=ag.Grid2DIrregular(grid)).array,
95+
rtol=1.0e-6,
96+
atol=1.0e-8,
97+
)
98+
99+
100+
def test__vmapped_deflections_match_instance_path_for_sidm_core():
101+
profile = ag.mp.KaplinghatCoredNFWSph(
102+
centre=(0.1, -0.2),
103+
kappa_s=0.2,
104+
scale_radius=2.0,
105+
interaction_radius=0.5,
106+
)
107+
grid = np.array([[0.5, 0.2], [1.0, -0.2], [2.0, 1.0]])
108+
109+
params = np.array(
110+
[
111+
[
112+
profile.centre[0],
113+
profile.centre[1],
114+
profile.kappa_s,
115+
profile.scale_radius,
116+
profile.interaction_radius,
117+
profile.central_density,
118+
profile.isothermal_radius,
119+
]
120+
]
121+
)
122+
mask = np.array([True])
123+
124+
vmapped = ag.mp.KaplinghatCoredNFWSph.vmapped_deflections_from(
125+
grid=grid,
126+
params_batch=params,
127+
mask=mask,
128+
)
129+
130+
np.testing.assert_allclose(
131+
np.asarray(vmapped),
132+
profile.deflections_yx_2d_from(grid=ag.Grid2DIrregular(grid)).array,
133+
rtol=5.0e-2,
134+
atol=1.0e-3,
135+
)
136+
137+
62138
def test__mcr_constructor_reduces_to_nfw_when_interaction_is_zero():
63139
kaplinghat = ag.mp.KaplinghatCoredNFWMCRLudlowSph(
64140
centre=(1.0, 2.0),

0 commit comments

Comments
 (0)