Skip to content

Commit 277e0c5

Browse files
authored
Merge pull request #406 from PyAutoLabs/feature/ludlow16-jax-native
feat: replace Ludlow16 colossus pure_callback with JAX-native impl (#403)
2 parents ae7ebd0 + 9b50b07 commit 277e0c5

8 files changed

Lines changed: 737 additions & 306 deletions

File tree

Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
"""
JAX-native Ludlow et al. 2016 mass-concentration relation.

Replaces the previous ``jax.pure_callback``-wrapped call to
``colossus.halo.concentration`` (formerly in ``mcr_util.py``). The
algorithm follows colossus' ``modelLudlow16`` (``concentration.py``
lines 1104-1192) and ``modelEisenstein98`` (``power_spectrum.py``
lines 476-608) line-for-line, but every operation is done in pure
``xp.*`` arithmetic so the same function runs under both numpy
(``xp=np``) and JAX (``xp=jnp``).

Validated in PR #402 (Phase 1 feasibility): max relative error in
c200c vs colossus over the lensing parameter grid
(log M ∈ [10, 14] Msun/h, z ∈ [0.1, 2.5]) is **7.5 × 10⁻⁴**, with
end-to-end downstream errors in convergence/deflection of ≤ 8 × 10⁻⁴.
The 0.13 dex intrinsic Ludlow16 scatter is ~350× larger.

Units (matching colossus throughout):
    M200c is in Msun / h
    R is in Mpc / h
    z is dimensionless redshift
"""
23+
24+
import numpy as np
25+
26+
27+
# Colossus' 'planck15' preset — the cosmology the previous callback used
# internally to compute concentration. autogalaxy's own Planck15 (Om0=0.3075)
# is used for everything else in mcr_util.ludlow16_cosmology; this constant is
# only for the concentration call. Keeping the split is an apples-to-apples
# swap with the previous behaviour; unifying the two is a separate decision.
#
# The keys match the keyword names of ``ludlow16_concentration`` so the dict
# can be splatted straight into that call.
PLANCK15_COSMOLOGY = dict(
    h=0.6774,
    Om0=0.3089,
    Ob0=0.0486,
    Tcmb0=2.7255,
    sigma8=0.8159,
    ns=0.9667,
)
40+
41+
42+
def _gammainc(a, x, xp):
    """Regularised lower incomplete gamma function P(a, x) for backend ``xp``.

    Dispatches to ``scipy.special.gammainc`` when ``xp`` is numpy and to
    ``jax.scipy.special.gammainc`` for any other backend. The import is
    function-local so JAX is only imported when a non-numpy backend is used.

    Parameters
    ----------
    a, x : float or array
        Shape parameter and upper integration limit, forwarded unchanged.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    """
    if xp is np:
        from scipy.special import gammainc
    else:
        from jax.scipy.special import gammainc
    return gammainc(a, x)
48+
49+
50+
def _erfc(arg, xp):
    """Complementary error function ``erfc(arg)`` for backend ``xp``.

    Dispatches to ``scipy.special.erfc`` when ``xp`` is numpy and to
    ``jax.scipy.special.erfc`` for any other backend. The import is
    function-local so JAX is only imported when a non-numpy backend is used.

    Parameters
    ----------
    arg : float or array
        Argument, forwarded unchanged.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    """
    if xp is np:
        from scipy.special import erfc
    else:
        from jax.scipy.special import erfc
    return erfc(arg)
56+
57+
58+
def _trapezoid_last_axis(y, x, xp):
    """Trapezoidal-rule integration along the last axis of ``y``.

    Equivalent to ``xp.trapezoid(y, x, axis=-1)`` but also works on NumPy
    releases before 2.0, which only provide ``trapz`` (deprecated from 2.0
    onwards in favour of ``trapezoid``).

    Parameters
    ----------
    y : array
        Integrand samples; integration runs over the last axis.
    x : array
        Sample positions. May be 1-D (shared across the leading axes of
        ``y``) or fully shaped to match ``y``; ordinary broadcasting applies.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.

    Returns
    -------
    array
        ``y`` with its last axis integrated out.
    """
    # Width of each interval and the mean integrand height across it.
    dx = x[..., 1:] - x[..., :-1]
    y_avg = 0.5 * (y[..., :-1] + y[..., 1:])
    return xp.sum(y_avg * dx, axis=-1)
69+
70+
71+
# ---------------------------------------------------------------------------
# Eisenstein & Hu 1998 transfer function — direct port of
# colossus.cosmology.power_spectrum.modelEisenstein98.
# ---------------------------------------------------------------------------
75+
76+
77+
def transfer_eh98(k, h, Om0, Ob0, Tcmb0, xp=np):
    """EH98 transfer function T(k) including baryon acoustic features.

    The result is the baryon-fraction-weighted sum of a baryon transfer
    function ``Tb`` and a cold-dark-matter transfer function ``Tc``.

    Parameters
    ----------
    k : float or array
        Wavenumber(s). ``k`` is multiplied by ``h`` before entering the
        fitting formulae, so it is presumably expected in h/Mpc (the colossus
        convention) — confirm against the caller. ``k`` must be non-zero:
        several terms divide by ``k``.
    h : float
        Dimensionless Hubble parameter.
    Om0, Ob0 : float
        Total matter and baryon density parameters today.
    Tcmb0 : float
        CMB temperature today; enters only through ``Tcmb0 / 2.7``.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.

    Returns
    -------
    array
        T(k), with the same shape as ``k``.
    """
    omc = Om0 - Ob0
    ombom0 = Ob0 / Om0
    h2 = h ** 2
    om0h2 = Om0 * h2
    ombh2 = Ob0 * h2
    theta2p7 = Tcmb0 / 2.7
    theta2p72 = theta2p7 ** 2
    theta2p74 = theta2p72 ** 2

    kh = k * h

    # Matter-radiation equality: redshift and horizon-scale wavenumber.
    zeq = 2.50e4 * om0h2 / theta2p74
    keq = 7.46e-2 * om0h2 / theta2p72

    # Drag-epoch redshift.
    b1d = 0.313 * om0h2 ** -0.419 * (1.0 + 0.607 * om0h2 ** 0.674)
    b2d = 0.238 * om0h2 ** 0.223
    zd = 1291.0 * om0h2 ** 0.251 / (1.0 + 0.659 * om0h2 ** 0.828) * (
        1.0 + b1d * ombh2 ** b2d
    )

    # Baryon-to-photon momentum density ratio at the drag epoch and equality.
    Rd = 31.5 * ombh2 / theta2p74 / (zd / 1e3)
    Req = 31.5 * ombh2 / theta2p74 / (zeq / 1e3)

    # Sound horizon at the drag epoch.
    s = (
        2.0
        / 3.0
        / keq
        * xp.sqrt(6.0 / Req)
        * xp.log((xp.sqrt(1.0 + Rd) + xp.sqrt(Rd + Req)) / (1.0 + xp.sqrt(Req)))
    )

    # Silk damping scale.
    ksilk = 1.6 * ombh2 ** 0.52 * om0h2 ** 0.73 * (1.0 + (10.4 * om0h2) ** -0.95)

    q = kh / 13.41 / keq

    # CDM suppression (ac) and shift (bc) parameters.
    a1 = (46.9 * om0h2) ** 0.670 * (1.0 + (32.1 * om0h2) ** -0.532)
    a2 = (12.0 * om0h2) ** 0.424 * (1.0 + (45.0 * om0h2) ** -0.582)
    ac = a1 ** (-ombom0) * a2 ** (-(ombom0 ** 3))

    b1 = 0.944 / (1.0 + (458.0 * om0h2) ** -0.708)
    b2 = (0.395 * om0h2) ** -0.0266
    bc = 1.0 / (1.0 + b1 * ((omc / Om0) ** b2 - 1.0))

    # Baryon amplitude ab via G(y), y = (1 + zeq) / (1 + zd).
    y = (1.0 + zeq) / (1.0 + zd)
    Gy = y * (
        -6.0 * xp.sqrt(1.0 + y)
        + (2.0 + 3.0 * y)
        * xp.log((xp.sqrt(1.0 + y) + 1.0) / (xp.sqrt(1.0 + y) - 1.0))
    )

    ab = 2.07 * keq * s * (1.0 + Rd) ** (-3.0 / 4.0) * Gy

    # Interpolation weight between the two CDM branches.
    f = 1.0 / (1.0 + (kh * s / 5.4) ** 4)

    # Pieces shared by the three T0-tilde evaluations below. ``C_unit`` is the
    # C coefficient with unit alpha; the original code spelled it out twice
    # (as C1bc and C11) with the identical expression.
    log_bc = xp.log(xp.e + 1.8 * bc * q)
    C_unit = 14.2 + 386.0 / (1.0 + 69.9 * q ** 1.08)

    # T0-tilde(k, ac, bc)
    C = 14.2 / ac + 386.0 / (1.0 + 69.9 * q ** 1.08)
    T0t = log_bc / (log_bc + C * q * q)

    # T0-tilde(k, 1, bc)
    T0t1bc = log_bc / (log_bc + C_unit * q * q)
    Tc = f * T0t1bc + (1.0 - f) * T0t

    bb = (
        0.5
        + ombom0
        + (3.0 - 2.0 * ombom0)
        * xp.sqrt((17.2 * om0h2) * (17.2 * om0h2) + 1.0)
    )

    bnode = 8.41 * om0h2 ** 0.435

    # Effective sound horizon, shifted by the node parameter. The cube is
    # written as an explicit triple product, as in the original port.
    st = s / (1.0 + (bnode / kh / s) * (bnode / kh / s) * (bnode / kh / s)) ** (
        1.0 / 3.0
    )

    # T0-tilde(k, 1, 1)
    log_11 = xp.log(xp.e + 1.8 * q)
    T0t11 = log_11 / (log_11 + C_unit * q * q)
    Tb = (
        T0t11 / (1.0 + (kh * s / 5.2) ** 2)
        + ab / (1.0 + (bb / kh / s) ** 3) * xp.exp(-((kh / ksilk) ** 1.4))
    ) * xp.sin(kh * st) / (kh * st)

    return ombom0 * Tb + omc / Om0 * Tc
163+
164+
165+
# ---------------------------------------------------------------------------
# sigma(R, z=0): RMS of mass within top-hat radius R, normalised to sigma8.
# ---------------------------------------------------------------------------
168+
169+
170+
def _tophat_window(x, xp=np):
    """W(x) = 3 (sin x - x cos x) / x^3, with a safe small-x expansion.

    For ``|x| < 1e-3`` the closed form loses precision (and is 0/0 at
    ``x = 0``), so the Taylor series ``1 - x^2/10 + x^4/280`` is used instead.
    The switch is made on ``|x|`` because W is an even function; testing the
    signed value would send every negative argument to the series, which is
    only accurate near zero.

    Parameters
    ----------
    x : array
        Dimensionless argument (``k * R`` in this module).
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.

    Returns
    -------
    array
        W(x), same shape as ``x``.
    """
    small = xp.abs(x) < 1.0e-3
    x2 = x * x
    safe_small = 1.0 - x2 / 10.0 + x2 * x2 / 280.0
    # Both branches are always evaluated, so swap in a dummy denominator where
    # the series is selected to avoid dividing by ~0.
    safe_large = 3.0 * (xp.sin(x) - x * xp.cos(x)) / xp.where(small, 1.0, x ** 3)
    return xp.where(small, safe_small, safe_large)
177+
178+
179+
def _sigma2_unnormalised(
    R, h, Om0, Ob0, Tcmb0, ns,
    xp=np,
    k_log_min=-5.0, k_log_max=3.0, nk=256,
):
    """sigma^2(R) at z=0 for an unnormalised power spectrum P(k) = k^ns T(k)^2.

    Evaluates ``1 / (2 pi^2) * integral dln(k) k^3 P(k) W(kR)^2`` with the
    trapezoidal rule on a grid that is uniform in ``ln k``.

    Parameters
    ----------
    R : float or 1-D array
        Top-hat radius/radii. Scalars are promoted to a length-1 array.
    h, Om0, Ob0, Tcmb0 : float
        Cosmology parameters forwarded to ``transfer_eh98``.
    ns : float
        Exponent of the primordial power law ``k^ns``.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    k_log_min, k_log_max : float
        log10 of the lower/upper integration limits in ``k``.
    nk : int
        Number of ``k`` samples.

    Returns
    -------
    array
        sigma^2 per radius. A result of shape ``(1,)`` — i.e. a scalar or a
        length-1 ``R`` — is returned as a 0-d value.
    """
    ln10 = xp.log(xp.asarray(10.0))
    ln_k = xp.linspace(k_log_min * ln10, k_log_max * ln10, nk)
    k = xp.exp(ln_k)

    Tk = transfer_eh98(k, h, Om0, Ob0, Tcmb0, xp=xp)
    Pk_unnorm = k ** ns * Tk ** 2

    # Build a (n_R, nk) grid of k*R so every radius is integrated at once.
    R = xp.atleast_1d(R)
    kR = k[None, :] * R[:, None]
    W = _tophat_window(kR, xp=xp)

    # k^3 (not k^2) because the integration variable is ln k: dk = k dln(k).
    integrand = k[None, :] ** 3 * Pk_unnorm[None, :] * W ** 2
    integrand = integrand / (2.0 * xp.pi ** 2)

    sigma2 = _trapezoid_last_axis(integrand, ln_k, xp=xp)
    if sigma2.shape == (1,):
        return sigma2[0]
    return sigma2
206+
207+
208+
def sigma_R(
    R, h, Om0, Ob0, Tcmb0, sigma8, ns,
    xp=np,
    k_log_min=-5.0, k_log_max=3.0, nk=256,
):
    """sigma(R, z=0), normalised so sigma(R=8 Mpc/h) = sigma8.

    Parameters
    ----------
    R : float or 1-D array
        Top-hat radius/radii in Mpc/h.
    h, Om0, Ob0, Tcmb0, ns : float
        Cosmology parameters forwarded to ``_sigma2_unnormalised``.
    sigma8 : float
        Target value of sigma at R = 8 Mpc/h.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    k_log_min, k_log_max, nk
        Integration grid settings, forwarded to ``_sigma2_unnormalised``.

    Returns
    -------
    array
        sigma(R), with the shape returned by ``_sigma2_unnormalised``.
    """
    sigma2_unnorm = _sigma2_unnormalised(
        R, h, Om0, Ob0, Tcmb0, ns,
        xp=xp, k_log_min=k_log_min, k_log_max=k_log_max, nk=nk,
    )
    # Same integral at the reference radius; it fixes the power-spectrum
    # amplitude. Uses the same k grid so discretisation errors largely cancel.
    sigma2_8_unnorm = _sigma2_unnormalised(
        xp.asarray(8.0), h, Om0, Ob0, Tcmb0, ns,
        xp=xp, k_log_min=k_log_min, k_log_max=k_log_max, nk=nk,
    )
    norm = sigma8 ** 2 / sigma2_8_unnorm
    return xp.sqrt(norm * sigma2_unnorm)
224+
225+
226+
# ---------------------------------------------------------------------------
# Linear growth factor D(z), flat LCDM (no relspecies).
# Integral form (Eisenstein & Hu 1999 Eq. 8 / Heath 1977).
# ---------------------------------------------------------------------------
230+
231+
232+
def _E_lcdm(z, Om0, Ode0, xp=np):
    """Dimensionless Hubble rate E(z) = sqrt(Om0 (1+z)^3 + Ode0).

    Contains matter and a constant dark-energy term only; there is no
    radiation or curvature contribution.
    """
    return xp.sqrt(Om0 * (1.0 + z) ** 3 + Ode0)
234+
235+
236+
def _growth_unnormalised(z, Om0, Ode0, xp=np, nz=256, z_max=1.0e4):
    """D_+(z) un-normalised: E(z) * integral of (1+z') / E(z')^3 from z to z_max.

    The integral is evaluated with the trapezoidal rule in ``u = ln(1+z')``,
    which turns the integrand into ``(1+z')^2 / E(z')^3``. Each requested
    redshift gets its own ``nz``-point grid running from ``ln(1+z)`` to
    ``ln(1+z_max)``.

    Parameters
    ----------
    z : float or 1-D array
        Redshift(s) at which to evaluate the growth function.
    Om0, Ode0 : float
        Matter and dark-energy density parameters today.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    nz : int
        Number of integration samples per redshift.
    z_max : float
        Upper integration limit, standing in for infinity.

    Returns
    -------
    array
        Un-normalised growth function: 0-d for scalar ``z``, otherwise one
        value per entry of ``z``.
    """
    # Decide scalar-ness from the *input*: after atleast_1d the working array
    # is never 0-d, so a shape test on it could never detect a scalar.
    scalar_input = xp.ndim(z) == 0
    z_arr = xp.atleast_1d(z).astype(xp.float64)

    u_max = xp.log(xp.asarray(1.0 + z_max))
    u_low = xp.log(1.0 + z_arr)
    # (n_z, nz) grid: row i runs linearly in u from u_low[i] to u_max.
    u_grid = (
        u_low[:, None]
        + (u_max - u_low)[:, None] * xp.linspace(0.0, 1.0, nz)[None, :]
    )
    zp = xp.exp(u_grid) - 1.0
    Ep = _E_lcdm(zp, Om0, Ode0, xp=xp)
    # Extra factor (1+z') is the Jacobian dz' = (1+z') du.
    integrand = (1.0 + zp) ** 2 / Ep ** 3
    integral = _trapezoid_last_axis(integrand, u_grid, xp=xp)

    D = _E_lcdm(z_arr, Om0, Ode0, xp=xp) * integral
    if scalar_input:
        return D[0]
    return D
255+
256+
257+
def growth_factor(z, Om0, Ode0, xp=np, nz=256, z_max=1.0e4):
    """D(z) / D(0), normalised growth factor for flat LCDM.

    Parameters
    ----------
    z : float or 1-D array
        Redshift(s).
    Om0, Ode0 : float
        Matter and dark-energy density parameters today.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    nz, z_max
        Integration settings, forwarded to ``_growth_unnormalised``.

    Returns
    -------
    array
        Growth factor relative to its z = 0 value.
    """
    D_z = _growth_unnormalised(z, Om0, Ode0, xp=xp, nz=nz, z_max=z_max)
    # Same integrator at z = 0 provides the normalisation.
    D_0 = _growth_unnormalised(xp.asarray(0.0), Om0, Ode0, xp=xp, nz=nz, z_max=z_max)
    return D_z / D_0
262+
263+
264+
# ---------------------------------------------------------------------------
# Einasto enclosed-mass ratio. For alpha = 0.18 (the value colossus uses
# internally in modelLudlow16), M(<r_s) / M(<c r_s) = P(3/alpha, 2/alpha) /
# P(3/alpha, (2/alpha) c^alpha), where P is the regularised lower incomplete
# gamma function. Independent of cosmology and halo mass.
# ---------------------------------------------------------------------------
270+
271+
272+
# Einasto shape parameter used as the default in ``einasto_mass_ratio``.
_EINASTO_ALPHA = 0.18


def einasto_mass_ratio(c, xp=np, alpha=_EINASTO_ALPHA):
    """M(<r_s) / M(<c r_s) for an Einasto profile, dimensionless.

    Computed as ``P(3/alpha, 2/alpha) / P(3/alpha, (2/alpha) c^alpha)``, where
    P is the regularised lower incomplete gamma function.

    Parameters
    ----------
    c : float or array
        Concentration(s); the outer radius is ``c`` scale radii.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    alpha : float
        Einasto shape parameter.

    Returns
    -------
    array
        Enclosed-mass ratio, same shape as ``c``.
    """
    s = 3.0 / alpha
    x_inner = 2.0 / alpha
    x_outer = 2.0 / alpha * c ** alpha
    return _gammainc(s, x_inner, xp=xp) / _gammainc(s, x_outer, xp=xp)
281+
282+
283+
# ---------------------------------------------------------------------------
# Concentration solver — vectorised port of modelLudlow16.
# ---------------------------------------------------------------------------
286+
287+
288+
# Divisor in the formation-density expression 200 c^3 M_ratio / _C_LUDLOW.
_C_LUDLOW = 650.0
# Mass fraction whose variance sigma(f M) enters the collapse-fraction argument.
_F_LUDLOW = 0.02
_DELTA_COLLAPSE = 1.68647019984  # matches colossus.utils.constants.DELTA_COLLAPSE
291+
292+
293+
def _lagrangian_R(M, Om0, h, xp=np):
    """Lagrangian radius for mass M (Msun/h) → R (Mpc/h).

    This is the radius of the sphere that encloses mass ``M`` at the mean
    matter density today, ``R = (3 M / (4 pi rho_m0))^(1/3)``.

    Parameters
    ----------
    M : float or array
        Mass in Msun/h.
    Om0 : float
        Matter density parameter today.
    h : float
        Accepted for signature symmetry but not used: in h-scaled units the
        critical density carries no explicit ``h``.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    """
    # Critical density today: 2.77536627e11 Msun h^2 / Mpc^3.
    rho_crit_0 = 2.77536627e11
    rho_m_0 = Om0 * rho_crit_0
    return (3.0 * M / (4.0 * xp.pi * rho_m_0)) ** (1.0 / 3.0)
299+
300+
301+
def ludlow16_concentration(
    M200c_Msun_per_h,
    z,
    h,
    Om0,
    Ob0,
    Tcmb0,
    sigma8,
    ns,
    xp=np,
    Ode0=None,
    c_array_size=200,
    sigma_nk=256,
    growth_nz=256,
):
    """
    JAX-native port of ``colossus.halo.concentration.modelLudlow16``.

    Assumes flat LCDM (``Ode0 = 1 - Om0`` if not supplied) and ignores
    relativistic species, matching the analytic LCDM branch in colossus.

    The concentration is found by tabulating both sides of the Ludlow16
    condition on a log-spaced grid of trial concentrations and interpolating
    their difference to zero.

    Parameters
    ----------
    M200c_Msun_per_h : float or scalar xp array
        Halo mass in Msun/h.
    z : float or scalar xp array
        Redshift.
    h, Om0, Ob0, Tcmb0, sigma8, ns : float
        Cosmology parameters. See ``PLANCK15_COSMOLOGY`` for the values
        matching colossus' built-in ``planck15`` preset.
    xp : module
        Numerical backend — ``numpy`` or ``jax.numpy``.
    Ode0 : float, optional
        Dark-energy density parameter. Defaults to ``1 - Om0``.
    c_array_size : int
        Number of trial concentrations, log-spaced between 1 and 100.
    sigma_nk : int
        Number of ``k`` samples for the sigma(R) integrals.
    growth_nz : int
        Number of samples for the growth-factor integrals.

    Returns
    -------
    c200c : scalar xp array
    """
    if Ode0 is None:
        Ode0 = 1.0 - Om0

    M = xp.asarray(M200c_Msun_per_h, dtype=xp.float64)
    z = xp.asarray(z, dtype=xp.float64)

    # Trial concentrations; the answer is interpolated on this grid, so the
    # result is confined to [1, 100].
    c_array = xp.logspace(0.0, 2.0, c_array_size)

    M_ratio = einasto_mass_ratio(c_array, xp=xp)
    rho_f_rho_c = 200.0 * c_array ** 3 * M_ratio / _C_LUDLOW

    # Formation redshift (closed-form LCDM); entries with t1 <= 0 are invalid
    # (low-c, where the formation redshift becomes < -1) and are masked below.
    t1 = (rho_f_rho_c * (Om0 * (1.0 + z) ** 3 + Ode0) - Ode0) / Om0
    valid_c = t1 > 0.0
    # Replace invalid entries before the cube root so no NaNs are produced
    # (they would otherwise leak through ``where`` into JAX gradients).
    t1_safe = xp.where(valid_c, t1, 1.0)
    zf = t1_safe ** (1.0 / 3.0) - 1.0

    R_fM = _lagrangian_R(_F_LUDLOW * M, Om0, h, xp=xp)
    R_M = _lagrangian_R(M, Om0, h, xp=xp)

    sigma_fM = sigma_R(R_fM, h, Om0, Ob0, Tcmb0, sigma8, ns, xp=xp, nk=sigma_nk)
    sigma_M = sigma_R(R_M, h, Om0, Ob0, Tcmb0, sigma8, ns, xp=xp, nk=sigma_nk)
    sigma2_fM = sigma_fM ** 2
    sigma2_M = sigma_M ** 2

    # Collapse thresholds at the observation and formation redshifts.
    D_z = growth_factor(z, Om0, Ode0, xp=xp, nz=growth_nz)
    delta_z = _DELTA_COLLAPSE / D_z
    D_zf = growth_factor(zf, Om0, Ode0, xp=xp, nz=growth_nz)
    delta_zf = _DELTA_COLLAPSE / D_zf

    arg = (delta_zf - delta_z) / xp.sqrt(2.0 * (sigma2_fM - sigma2_M))
    rhs = _erfc(arg, xp=xp)

    # Solve M_ratio - rhs == 0 along c. Colossus trims c_array to the entries
    # with t1 > 0 and then calls np.interp. Here the array is left un-trimmed
    # (its length stays static), so it must still be monotonically increasing
    # in lhs_rhs for xp.interp. Invalid (low-c) entries are therefore pinned to
    # -10, below any valid lhs_rhs (which lies in [-1, 1]), so they sit at the
    # bottom of the interpolation abscissa.
    lhs_rhs = M_ratio - rhs
    lhs_rhs = xp.where(valid_c, lhs_rhs, -10.0)

    return xp.interp(0.0, lhs_rhs, c_array)

0 commit comments

Comments
 (0)