Skip to content

Commit c174bc4

Browse files
Jammy2211Jammy2211
authored andcommitted
hack to determine xp on the fly, until I make higher level decisions architectuallu
1 parent b9a2e50 commit c174bc4

1 file changed

Lines changed: 159 additions & 31 deletions

File tree

autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py

Lines changed: 159 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,53 +5,181 @@
55
import numpy as np
66
from autogalaxy import cosmology as cosmo
77

8+
def is_jax(x):
9+
try:
10+
import jax
11+
from jax import Array
12+
from jax.core import Tracer
13+
return isinstance(x, (Array, Tracer))
14+
except Exception:
15+
return False
16+
17+
def _hyp2f1_jax(xp, *, max_terms: int = 256):
18+
"""
19+
Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp.
20+
21+
- NumPy: scipy.special.hyp2f1
22+
- JAX (if available): jax.scipy.special.hyp2f1
23+
- JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case)
24+
"""
25+
import jax
26+
import jax.numpy as jnp
27+
28+
# Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z)
29+
# We implement general 2F1 series:
30+
# 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n!
31+
#
32+
# Recurrence for terms:
33+
# t_0 = 1
34+
# t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z
35+
#
36+
# This is JIT-safe with static max_terms.
37+
def hyp2f1_series(a, b, c, z):
38+
a = jnp.asarray(a)
39+
b = jnp.asarray(b)
40+
c = jnp.asarray(c)
41+
z = jnp.asarray(z)
42+
43+
def body_fun(n, carry):
44+
t, s = carry
45+
n_f = jnp.asarray(n, dtype=t.dtype)
46+
t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z
47+
s = s + t
48+
return (t, s)
49+
50+
# Start: t0 = 1, s0 = 1
51+
t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z))
52+
s0 = t0
53+
54+
# fori_loop has static iteration count => good under jit/vmap
55+
tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0))
56+
return sN
57+
58+
return hyp2f1_series
859

960
def kappa_s_and_scale_radius(
10-
cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope
61+
cosmology,
62+
virial_mass,
63+
c_2,
64+
overdens,
65+
redshift_object,
66+
redshift_source,
67+
inner_slope,
1168
):
12-
from scipy.integrate import quad
69+
"""
70+
Compute the characteristic convergence and scale radius of a spherical gNFW halo
71+
parameterised by virial mass and concentration.
1372
14-
concentration = (2.0 - inner_slope) * c_2 # gNFW concentration
73+
This routine converts a halo defined by its virial mass and concentration into
74+
the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing
75+
calculations. The normalization is computed analytically using the closed-form
76+
hypergeometric expression for the enclosed mass integral, ensuring compatibility
77+
with both NumPy and JAX backends (e.g. within `jax.jit`).
1578
16-
critical_density = cosmology.critical_density(
17-
redshift_object, xp=np
18-
) # Msun / kpc^3
79+
The virial radius is defined via:
1980
20-
critical_surface_density = (
21-
cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
22-
redshift_0=redshift_object,
23-
redshift_1=redshift_source,
24-
xp=np,
25-
)
81+
M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3
82+
83+
where Δ is the overdensity with respect to the critical density. If `overdens`
84+
is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used.
85+
86+
The gNFW normalization constant is computed as:
87+
88+
d_e = (Δ / 3) (3 − γ) c^γ /
89+
₂F₁(3 − γ, 3 − γ; 4 − γ; −c)
90+
91+
where γ is the inner slope and c is the gNFW concentration.
92+
93+
Parameters
94+
----------
95+
cosmology
96+
Cosmology object providing critical density, angular diameter distance
97+
conversions, and surface mass density calculations. Must support an `xp`
98+
argument for NumPy/JAX interoperability.
99+
virial_mass
100+
Virial mass of the halo in units of solar masses.
101+
c_2
102+
Concentration-like parameter, converted internally to the gNFW
103+
concentration via `(2 - inner_slope) * c_2`.
104+
overdens
105+
Overdensity with respect to the critical density. If zero, the
106+
Bryan & Norman (1998) redshift-dependent overdensity is used.
107+
redshift_object
108+
Redshift of the lens (halo).
109+
redshift_source
110+
Redshift of the background source.
111+
inner_slope
112+
Inner logarithmic density slope γ of the gNFW profile.
113+
xp
114+
Array backend module (`numpy` or `jax.numpy`). All array operations
115+
are dispatched through this module to ensure compatibility with
116+
both standard NumPy execution and JAX tracing / JIT compilation.
117+
118+
Returns
119+
-------
120+
kappa_s
121+
Dimensionless characteristic convergence of the gNFW profile.
122+
scale_radius
123+
Angular scale radius in arcseconds.
124+
virial_radius
125+
Virial radius in kiloparsecs.
126+
overdens
127+
Final overdensity value used in the calculation.
128+
129+
Notes
130+
-----
131+
- This implementation is fully JIT-compatible when `xp=jax.numpy`.
132+
- No Python-side branching depends on traced values; conditional logic
133+
is implemented via backend array operations.
134+
- The analytic normalization avoids numerical quadrature, improving
135+
both performance and differentiability.
136+
"""
137+
is_jax_bool = is_jax(virial_mass)
138+
139+
if not is_jax_bool:
140+
xp = np
141+
else:
142+
from jax import numpy as jnp
143+
xp = jnp
144+
145+
if xp is np:
146+
from scipy.special import hyp2f1
147+
else:
148+
try:
149+
from jax.scipy.special import hyp2f1
150+
except ImportError:
151+
hyp2f1 = _hyp2f1_jax(xp)
152+
153+
gamma = inner_slope
154+
concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition)
155+
156+
critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3
157+
158+
critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
159+
redshift_0=redshift_object,
160+
redshift_1=redshift_source,
161+
xp=xp,
26162
) # Msun / kpc^2
27163

28-
kpc_per_arcsec = cosmology.kpc_per_arcsec_from(
29-
redshift=redshift_object, xp=np
30-
) # kpc / arcsec
164+
kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec
31165

32-
if overdens == 0:
33-
x = cosmology.Om(redshift_object, xp=np) - 1.0
34-
overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998)
166+
# Bryan & Norman (1998) overdensity if overdens == 0
167+
x = cosmology.Om(redshift_object, xp=xp) - 1.0
168+
overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2
169+
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
35170

36171
# r_vir in kpc
37-
virial_radius = (
38-
virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0))
39-
) ** (1.0 / 3.0)
172+
virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0)
40173

41174
# scale radius in kpc
42175
scale_radius_kpc = virial_radius / concentration
43176

44-
# Normalization integral for gNFW
45-
def integrand(r):
46-
return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** (
47-
inner_slope - 3.0
48-
)
177+
# c = rvir/rs is exactly "concentration" by definition
178+
c = concentration
49179

50-
de_c = (
51-
(overdens / 3.0)
52-
* (virial_radius**3 / scale_radius_kpc**inner_slope)
53-
/ quad(integrand, 0.0, virial_radius)[0]
54-
)
180+
# Analytic normalization
181+
a = 3.0 - gamma
182+
de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c)
55183

56184
rho_s = critical_density * de_c # Msun / kpc^3
57185
kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless

0 commit comments

Comments
 (0)