Commit 24bfb63

Jammy2211 and claude authored and committed
feat: replace Ludlow16 colossus pure_callback with JAX-native impl

Phase 2 swap-in for #403. The previous mcr_util.py wrapped a colossus call in jax.pure_callback to compute the Ludlow et al. 2016 mass-concentration relation. The callback couldn't be JIT-traced or differentiated and pinned colossus as a hard runtime dependency.

This PR:

- Adds autogalaxy/profiles/mass/dark/ludlow16.py, a JAX-native port of colossus.halo.concentration.modelLudlow16 (~400 lines: EH98 transfer + Heath '77 growth factor + Einasto gammainc + 200-point Ludlow c-solver, all xp-aware so the same function runs under both numpy and JAX).
- Replaces _ludlow16_cosmology_callback and ludlow16_cosmology_jax in mcr_util.py with a single ludlow16_cosmology(..., xp=np). Removes the jax.pure_callback and the colossus imports.
- Drops the if-xp-is-np branching in kappa_s_and_scale_radius_for_ludlow and kappa_s_scale_radius_and_core_radius_for_ludlow.
- Moves colossus from required to test/dev extras in pyproject.toml.
- Adds a 10-test unit suite, test_ludlow16.py (numpy-path regression vs colossus, skipped if colossus is unavailable).

Numbers (matching the Phase 1 prototype, PR #402):

    c200c max rel error vs colossus     : 7.5e-4
    end-to-end kappa_s max rel error    : 1.07e-3
    end-to-end conv/defl per-pixel max  : 8.21e-4

Intrinsic Ludlow16 scatter (~0.13 dex = ~35%) is ~350x larger.

Cross-implementation verification: re-running autolens_workspace_test subhalo.py against this PR's code (Scenarios C and D, the regression literals locked in via PR #92 from the colossus path) gives vmap = -1.349200e+09, matching the locked-in literal to rtol=1e-4. The JAX-native implementation reproduces the colossus pure_callback's downstream log-likelihood to the precision we care about.

Test tolerance updates: four existing NFW-MCR test files had pytest.approx tolerances of 1.0e-4 on scale_radius / truncation_radius literals. These were implicitly claiming colossus-level precision (the literals were generated with colossus). The JAX implementation differs from colossus by ~2e-4 at these points (sub-Ludlow-scatter), so the tolerances are loosened to 1.0e-3 (still 0.1%, well below the intrinsic Ludlow scatter the scatter_sigma parameter marginalises over). Affected: test_nfw_mcr.py, test_nfw_scatter.py, test_nfw_truncated_mcr.py, test_nfw_truncated_mcr_scatter.py.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
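
The tolerance argument in the commit message can be checked with a little arithmetic. The sketch below is illustrative only: the reference value 1.0 is a hypothetical stand-in rather than one of the real test literals, and `math.isclose` is used instead of `pytest.approx` so that only the standard library is needed (the two define relative tolerance slightly differently, which does not matter at this level).

```python
import math

# Hypothetical reference literal (a stand-in, not a real test value).
reference = 1.0
# A new value offset by the ~2e-4 relative difference quoted in the message.
new_value = reference * (1.0 + 2.0e-4)

# The old 1.0e-4 tolerance rejects the new value; the new 1.0e-3 one accepts it.
passes_old_tolerance = math.isclose(new_value, reference, rel_tol=1.0e-4)
passes_new_tolerance = math.isclose(new_value, reference, rel_tol=1.0e-3)

# 0.13 dex of scatter expressed as a fractional change: 10**0.13 - 1.
scatter_fraction = 10.0 ** 0.13 - 1.0
# How many times larger the scatter is than a 1.0e-3 relative error.
scatter_over_tolerance = scatter_fraction / 1.0e-3

print(passes_old_tolerance, passes_new_tolerance)
print(scatter_fraction, scatter_over_tolerance)
```

The scatter fraction comes out at roughly 35%, about 350 times a 1.0e-3 relative error, consistent with the figures quoted above.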
1 parent ae7ebd0 commit 24bfb63

8 files changed

Lines changed: 724 additions & 306 deletions

Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
"""
JAX-native Ludlow et al. 2016 mass-concentration relation.

Replaces the previous ``jax.pure_callback``-wrapped call to
``colossus.halo.concentration`` (formerly in ``mcr_util.py``). The
algorithm follows colossus' ``modelLudlow16`` (``concentration.py``
lines 1104-1192) and ``modelEisenstein98`` (``power_spectrum.py``
lines 476-608) line-for-line, but every operation is done in pure
``xp.*`` arithmetic so the same function runs under both numpy
(``xp=np``) and JAX (``xp=jnp``).

Validated in PR #402 (Phase 1 feasibility): max relative error in
c200c vs colossus over the lensing parameter grid
(log M ∈ [10, 14] Msun/h, z ∈ [0.1, 2.5]) is **7.5 × 10⁻⁴**, with
end-to-end downstream errors in convergence/deflection of about 8 × 10⁻⁴.
The 0.13 dex intrinsic Ludlow16 scatter is ~350× larger.

Units (matching colossus throughout):
    M200c is in Msun / h
    R     is in Mpc / h
    z     is dimensionless redshift
"""

import numpy as np


# Colossus' 'planck15' preset: the cosmology the previous callback used
# internally to compute concentration. autogalaxy's own Planck15 (Om0=0.3075)
# is used for everything else in mcr_util.ludlow16_cosmology; this constant is
# only for the concentration call. Keeping the split is an apples-to-apples
# swap with the previous behaviour; unifying the two is a separate decision.
PLANCK15_COSMOLOGY = dict(
    h=0.6774,
    Om0=0.3089,
    Ob0=0.0486,
    Tcmb0=2.7255,
    sigma8=0.8159,
    ns=0.9667,
)


def _gammainc(a, x, xp):
    if xp is np:
        from scipy.special import gammainc
    else:
        from jax.scipy.special import gammainc
    return gammainc(a, x)


def _erfc(arg, xp):
    if xp is np:
        from scipy.special import erfc
    else:
        from jax.scipy.special import erfc
    return erfc(arg)


# ---------------------------------------------------------------------------
# Eisenstein & Hu 1998 transfer function: direct port of
# colossus.cosmology.power_spectrum.modelEisenstein98.
# ---------------------------------------------------------------------------


def transfer_eh98(k, h, Om0, Ob0, Tcmb0, xp=np):
    """EH98 transfer function T(k) including baryon acoustic features."""
    omc = Om0 - Ob0
    ombom0 = Ob0 / Om0
    h2 = h ** 2
    om0h2 = Om0 * h2
    ombh2 = Ob0 * h2
    theta2p7 = Tcmb0 / 2.7
    theta2p72 = theta2p7 ** 2
    theta2p74 = theta2p72 ** 2

    kh = k * h

    zeq = 2.50e4 * om0h2 / theta2p74
    keq = 7.46e-2 * om0h2 / theta2p72

    b1d = 0.313 * om0h2 ** -0.419 * (1.0 + 0.607 * om0h2 ** 0.674)
    b2d = 0.238 * om0h2 ** 0.223
    zd = 1291.0 * om0h2 ** 0.251 / (1.0 + 0.659 * om0h2 ** 0.828) * (
        1.0 + b1d * ombh2 ** b2d
    )

    Rd = 31.5 * ombh2 / theta2p74 / (zd / 1e3)
    Req = 31.5 * ombh2 / theta2p74 / (zeq / 1e3)

    s = (
        2.0
        / 3.0
        / keq
        * xp.sqrt(6.0 / Req)
        * xp.log((xp.sqrt(1.0 + Rd) + xp.sqrt(Rd + Req)) / (1.0 + xp.sqrt(Req)))
    )

    ksilk = 1.6 * ombh2 ** 0.52 * om0h2 ** 0.73 * (1.0 + (10.4 * om0h2) ** -0.95)

    q = kh / 13.41 / keq

    a1 = (46.9 * om0h2) ** 0.670 * (1.0 + (32.1 * om0h2) ** -0.532)
    a2 = (12.0 * om0h2) ** 0.424 * (1.0 + (45.0 * om0h2) ** -0.582)
    ac = a1 ** (-ombom0) * a2 ** (-(ombom0 ** 3))

    b1 = 0.944 / (1.0 + (458.0 * om0h2) ** -0.708)
    b2 = (0.395 * om0h2) ** -0.0266
    bc = 1.0 / (1.0 + b1 * ((omc / Om0) ** b2 - 1.0))

    y = (1.0 + zeq) / (1.0 + zd)
    Gy = y * (
        -6.0 * xp.sqrt(1.0 + y)
        + (2.0 + 3.0 * y)
        * xp.log((xp.sqrt(1.0 + y) + 1.0) / (xp.sqrt(1.0 + y) - 1.0))
    )

    ab = 2.07 * keq * s * (1.0 + Rd) ** (-3.0 / 4.0) * Gy

    f = 1.0 / (1.0 + (kh * s / 5.4) ** 4)

    C = 14.2 / ac + 386.0 / (1.0 + 69.9 * q ** 1.08)
    T0t = xp.log(xp.e + 1.8 * bc * q) / (xp.log(xp.e + 1.8 * bc * q) + C * q * q)

    C1bc = 14.2 + 386.0 / (1.0 + 69.9 * q ** 1.08)
    T0t1bc = xp.log(xp.e + 1.8 * bc * q) / (
        xp.log(xp.e + 1.8 * bc * q) + C1bc * q * q
    )
    Tc = f * T0t1bc + (1.0 - f) * T0t

    bb = (
        0.5
        + ombom0
        + (3.0 - 2.0 * ombom0)
        * xp.sqrt((17.2 * om0h2) * (17.2 * om0h2) + 1.0)
    )

    bnode = 8.41 * om0h2 ** 0.435

    st = s / (1.0 + (bnode / kh / s) * (bnode / kh / s) * (bnode / kh / s)) ** (
        1.0 / 3.0
    )

    C11 = 14.2 + 386.0 / (1.0 + 69.9 * q ** 1.08)
    T0t11 = xp.log(xp.e + 1.8 * q) / (xp.log(xp.e + 1.8 * q) + C11 * q * q)
    Tb = (
        T0t11 / (1.0 + (kh * s / 5.2) ** 2)
        + ab / (1.0 + (bb / kh / s) ** 3) * xp.exp(-((kh / ksilk) ** 1.4))
    ) * xp.sin(kh * st) / (kh * st)

    return ombom0 * Tb + omc / Om0 * Tc


# ---------------------------------------------------------------------------
# sigma(R, z=0): RMS of mass within top-hat radius R, normalised to sigma8.
# ---------------------------------------------------------------------------


def _tophat_window(x, xp=np):
    """W(x) = 3 (sin x - x cos x) / x^3, with a safe small-x expansion."""
    small = x < 1.0e-3
    x2 = x * x
    safe_small = 1.0 - x2 / 10.0 + x2 * x2 / 280.0
    safe_large = 3.0 * (xp.sin(x) - x * xp.cos(x)) / xp.where(small, 1.0, x ** 3)
    return xp.where(small, safe_small, safe_large)


def _sigma2_unnormalised(
    R, h, Om0, Ob0, Tcmb0, ns,
    xp=np,
    k_log_min=-5.0, k_log_max=3.0, nk=256,
):
    """sigma^2(R) at z=0 for an unnormalised power spectrum P(k) = k^ns T(k)^2."""
    ln_k = xp.linspace(
        k_log_min * xp.log(xp.asarray(10.0)),
        k_log_max * xp.log(xp.asarray(10.0)),
        nk,
    )
    k = xp.exp(ln_k)

    Tk = transfer_eh98(k, h, Om0, Ob0, Tcmb0, xp=xp)
    Pk_unnorm = k ** ns * Tk ** 2

    R = xp.atleast_1d(R)
    kR = k[None, :] * R[:, None]
    W = _tophat_window(kR, xp=xp)

    integrand = k[None, :] ** 3 * Pk_unnorm[None, :] * W ** 2
    integrand = integrand / (2.0 * xp.pi ** 2)

    sigma2 = xp.trapezoid(integrand, ln_k, axis=-1)
    if sigma2.shape == (1,):
        return sigma2[0]
    return sigma2


def sigma_R(
    R, h, Om0, Ob0, Tcmb0, sigma8, ns,
    xp=np,
    k_log_min=-5.0, k_log_max=3.0, nk=256,
):
    """sigma(R, z=0), normalised so sigma(R=8 Mpc/h) = sigma8."""
    sigma2_unnorm = _sigma2_unnormalised(
        R, h, Om0, Ob0, Tcmb0, ns,
        xp=xp, k_log_min=k_log_min, k_log_max=k_log_max, nk=nk,
    )
    sigma2_8_unnorm = _sigma2_unnormalised(
        xp.asarray(8.0), h, Om0, Ob0, Tcmb0, ns,
        xp=xp, k_log_min=k_log_min, k_log_max=k_log_max, nk=nk,
    )
    norm = sigma8 ** 2 / sigma2_8_unnorm
    return xp.sqrt(norm * sigma2_unnorm)


# ---------------------------------------------------------------------------
# Linear growth factor D(z), flat LCDM (no relspecies).
# Integral form (Eisenstein & Hu 1999 Eq. 8 / Heath 1977).
# ---------------------------------------------------------------------------


def _E_lcdm(z, Om0, Ode0, xp=np):
    return xp.sqrt(Om0 * (1.0 + z) ** 3 + Ode0)


def _growth_unnormalised(z, Om0, Ode0, xp=np, nz=256, z_max=1.0e4):
    """D_+(z) un-normalised. Integrate (1+z') / E(z')^3 from z to z_max via u=ln(1+z')."""
    z_arr = xp.atleast_1d(z).astype(xp.float64)

    u_max = xp.log(xp.asarray(1.0 + z_max))
    u_low = xp.log(1.0 + z_arr)
    u_grid = (
        u_low[:, None]
        + (u_max - u_low)[:, None] * xp.linspace(0.0, 1.0, nz)[None, :]
    )
    zp = xp.exp(u_grid) - 1.0
    Ep = _E_lcdm(zp, Om0, Ode0, xp=xp)
    integrand = (1.0 + zp) ** 2 / Ep ** 3
    integral = xp.trapezoid(integrand, u_grid, axis=-1)

    D = _E_lcdm(z_arr, Om0, Ode0, xp=xp) * integral
    # z_arr is always at least 1-D after atleast_1d, so a scalar input shows up
    # as a single-element result (same convention as _sigma2_unnormalised).
    if D.shape == (1,):
        return D[0]
    return D


def growth_factor(z, Om0, Ode0, xp=np, nz=256, z_max=1.0e4):
    """D(z) / D(0), normalised growth factor for flat LCDM."""
    D_z = _growth_unnormalised(z, Om0, Ode0, xp=xp, nz=nz, z_max=z_max)
    D_0 = _growth_unnormalised(xp.asarray(0.0), Om0, Ode0, xp=xp, nz=nz, z_max=z_max)
    return D_z / D_0


# ---------------------------------------------------------------------------
# Einasto enclosed-mass ratio. For alpha = 0.18 (the value colossus uses
# internally in modelLudlow16), M(<r_s) / M(<c r_s) = P(3/alpha, 2/alpha) /
# P(3/alpha, (2/alpha) c^alpha), where P is the regularised lower incomplete
# gamma function. Independent of cosmology and halo mass.
# ---------------------------------------------------------------------------


_EINASTO_ALPHA = 0.18


def einasto_mass_ratio(c, xp=np, alpha=_EINASTO_ALPHA):
    """M(<r_s) / M(<c r_s) for an Einasto profile, dimensionless."""
    s = 3.0 / alpha
    x_inner = 2.0 / alpha
    x_outer = 2.0 / alpha * c ** alpha
    return _gammainc(s, x_inner, xp=xp) / _gammainc(s, x_outer, xp=xp)


# ---------------------------------------------------------------------------
# Concentration solver: vectorised port of modelLudlow16.
# ---------------------------------------------------------------------------


_C_LUDLOW = 650.0
_F_LUDLOW = 0.02
_DELTA_COLLAPSE = 1.68647019984  # matches colossus.utils.constants.DELTA_COLLAPSE


def _lagrangian_R(M, Om0, h, xp=np):
    """Lagrangian radius for mass M (Msun/h) → R (Mpc/h)."""
    # Critical density today: 2.77536627e11 Msun h^2 / Mpc^3.
    rho_crit_0 = 2.77536627e11
    rho_m_0 = Om0 * rho_crit_0
    return (3.0 * M / (4.0 * xp.pi * rho_m_0)) ** (1.0 / 3.0)


def ludlow16_concentration(
    M200c_Msun_per_h,
    z,
    h,
    Om0,
    Ob0,
    Tcmb0,
    sigma8,
    ns,
    xp=np,
    Ode0=None,
    c_array_size=200,
    sigma_nk=256,
    growth_nz=256,
):
    """
    JAX-native port of ``colossus.halo.concentration.modelLudlow16``.

    Assumes flat LCDM (``Ode0 = 1 - Om0`` if not supplied) and ignores
    relativistic species, matching the analytic LCDM branch in colossus.

    Parameters
    ----------
    M200c_Msun_per_h : float or scalar xp array
        Halo mass in Msun/h.
    z : float or scalar xp array
        Redshift.
    h, Om0, Ob0, Tcmb0, sigma8, ns : float
        Cosmology parameters. See ``PLANCK15_COSMOLOGY`` for the values
        matching colossus' built-in ``planck15`` preset.
    xp : module
        Numerical backend: ``numpy`` or ``jax.numpy``.

    Returns
    -------
    c200c : scalar xp array
    """
    if Ode0 is None:
        Ode0 = 1.0 - Om0

    M = xp.asarray(M200c_Msun_per_h, dtype=xp.float64)
    z = xp.asarray(z, dtype=xp.float64)

    c_array = xp.logspace(0.0, 2.0, c_array_size)

    M_ratio = einasto_mass_ratio(c_array, xp=xp)
    rho_f_rho_c = 200.0 * c_array ** 3 * M_ratio / _C_LUDLOW

    # Formation redshift (closed-form LCDM); entries with t1 <= 0 are invalid
    # (low-c, where the formation redshift becomes < -1) and are masked below.
    t1 = (rho_f_rho_c * (Om0 * (1.0 + z) ** 3 + Ode0) - Ode0) / Om0
    valid_c = t1 > 0.0
    t1_safe = xp.where(valid_c, t1, 1.0)
    zf = t1_safe ** (1.0 / 3.0) - 1.0

    R_fM = _lagrangian_R(_F_LUDLOW * M, Om0, h, xp=xp)
    R_M = _lagrangian_R(M, Om0, h, xp=xp)

    sigma_fM = sigma_R(R_fM, h, Om0, Ob0, Tcmb0, sigma8, ns, xp=xp, nk=sigma_nk)
    sigma_M = sigma_R(R_M, h, Om0, Ob0, Tcmb0, sigma8, ns, xp=xp, nk=sigma_nk)
    sigma2_fM = sigma_fM ** 2
    sigma2_M = sigma_M ** 2

    D_z = growth_factor(z, Om0, Ode0, xp=xp, nz=growth_nz)
    delta_z = _DELTA_COLLAPSE / D_z
    D_zf = growth_factor(zf, Om0, Ode0, xp=xp, nz=growth_nz)
    delta_zf = _DELTA_COLLAPSE / D_zf

    arg = (delta_zf - delta_z) / xp.sqrt(2.0 * (sigma2_fM - sigma2_M))
    rhs = _erfc(arg, xp=xp)

    # Solve M_ratio - rhs == 0 along c. Colossus trims c_array to entries
    # with t1 > 0 and then calls np.interp; here the un-trimmed array must
    # stay monotonically increasing in lhs_rhs for xp.interp. Pin invalid
    # entries (low c) to -10, below the lowest possible lhs_rhs (which lies
    # in [-2, 1]), so they sort beneath every valid entry.
    lhs_rhs = M_ratio - rhs
    lhs_rhs = xp.where(valid_c, lhs_rhs, -10.0)

    return xp.interp(0.0, lhs_rhs, c_array)
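
The pin-and-interpolate step that closes `ludlow16_concentration` can be illustrated on a toy problem. The sketch below uses a made-up residual, `log10(c) - 1`, whose root is known to be `c = 10`; it is not the Ludlow16 residual, and the validity cut at `c > 2` is likewise hypothetical. The point is that `np.where` keeps the array at a fixed shape, whereas trimming with a boolean mask yields a data-dependent shape that `jax.jit` cannot trace.

```python
import numpy as np

# Hypothetical monotonic residual g(c) = log10(c) - 1, with its root at c = 10.
c_grid = np.logspace(0.0, 2.0, 200)
residual = np.log10(c_grid) - 1.0

# Pretend the low-c entries are invalid (a stand-in for the t1 <= 0 entries).
valid = c_grid > 2.0

# Pin invalid entries to a sentinel below every valid residual, so the array
# stays non-decreasing and keeps its full length (no boolean trimming).
pinned = np.where(valid, residual, -10.0)

# Inverse interpolation: use the residual as the x-axis and read off c at 0.
root = np.interp(0.0, pinned, c_grid)
print(root)
```

The recovered root lands close to 10; the small offset comes from interpolating linearly between neighbouring points of the logarithmic grid.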

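
One way to sanity-check the growth-factor integral used in the file is the Einstein-de Sitter limit (`Om0 = 1`, `Ode0 = 0`), where `D(z) / D(0)` equals `1 / (1 + z)` exactly. The following is a simplified, scalar-only NumPy sketch of the same `u = ln(1 + z')` substitution, not the vectorised function from the diff; the helper names are local to this example, and the `np.trapz` fallback is an assumption for NumPy versions older than 2.0.

```python
import numpy as np

# np.trapezoid was added in NumPy 2.0; fall back to np.trapz on older versions.
trapezoid = getattr(np, "trapezoid", None) or np.trapz


def growth_unnormalised(z, Om0, Ode0, nz=256, z_max=1.0e4):
    """E(z) times the integral of (1+z')^2 / E(z')^3 du, with u = ln(1+z')."""
    u = np.linspace(np.log(1.0 + z), np.log(1.0 + z_max), nz)
    zp = np.exp(u) - 1.0
    E = np.sqrt(Om0 * (1.0 + zp) ** 3 + Ode0)
    integral = trapezoid((1.0 + zp) ** 2 / E ** 3, u)
    return np.sqrt(Om0 * (1.0 + z) ** 3 + Ode0) * integral


def growth_factor(z, Om0, Ode0):
    """D(z) / D(0) for flat LCDM."""
    return growth_unnormalised(z, Om0, Ode0) / growth_unnormalised(0.0, Om0, Ode0)


# Einstein-de Sitter limit: D(1) / D(0) should be close to 1 / (1 + 1) = 0.5.
d_eds = growth_factor(1.0, Om0=1.0, Ode0=0.0)
# Flat LCDM with the Om0 from PLANCK15_COSMOLOGY: dark energy suppresses late
# growth, so D(1) / D(0) should exceed the Einstein-de Sitter value.
d_lcdm = growth_factor(1.0, Om0=0.3089, Ode0=1.0 - 0.3089)
print(d_eds, d_lcdm)
```

With 256 trapezoid points the Einstein-de Sitter result agrees with 0.5 to better than one part in a thousand, which gives a rough feel for the quadrature error behind the default `nz=256`.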