|
5 | 5 | import numpy as np |
6 | 6 | from autogalaxy import cosmology as cosmo |
7 | 7 |
|
| 8 | +def is_jax(x): |
| 9 | + try: |
| 10 | + import jax |
| 11 | + from jax import Array |
| 12 | + from jax.core import Tracer |
| 13 | + return isinstance(x, (Array, Tracer)) |
| 14 | + except Exception: |
| 15 | + return False |
| 16 | + |
| 17 | +def _hyp2f1_jax(xp, *, max_terms: int = 256): |
| 18 | + """ |
| 19 | + Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp. |
| 20 | +
|
| 21 | + - NumPy: scipy.special.hyp2f1 |
| 22 | + - JAX (if available): jax.scipy.special.hyp2f1 |
| 23 | + - JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case) |
| 24 | + """ |
| 25 | + import jax |
| 26 | + import jax.numpy as jnp |
| 27 | + |
| 28 | + # Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z) |
| 29 | + # We implement general 2F1 series: |
| 30 | + # 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n! |
| 31 | + # |
| 32 | + # Recurrence for terms: |
| 33 | + # t_0 = 1 |
| 34 | + # t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z |
| 35 | + # |
| 36 | + # This is JIT-safe with static max_terms. |
| 37 | + def hyp2f1_series(a, b, c, z): |
| 38 | + a = jnp.asarray(a) |
| 39 | + b = jnp.asarray(b) |
| 40 | + c = jnp.asarray(c) |
| 41 | + z = jnp.asarray(z) |
| 42 | + |
| 43 | + def body_fun(n, carry): |
| 44 | + t, s = carry |
| 45 | + n_f = jnp.asarray(n, dtype=t.dtype) |
| 46 | + t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z |
| 47 | + s = s + t |
| 48 | + return (t, s) |
| 49 | + |
| 50 | + # Start: t0 = 1, s0 = 1 |
| 51 | + t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z)) |
| 52 | + s0 = t0 |
| 53 | + |
| 54 | + # fori_loop has static iteration count => good under jit/vmap |
| 55 | + tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0)) |
| 56 | + return sN |
| 57 | + |
| 58 | + return hyp2f1_series |
8 | 59 |
|
9 | 60 | def kappa_s_and_scale_radius( |
10 | | - cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope |
| 61 | + cosmology, |
| 62 | + virial_mass, |
| 63 | + c_2, |
| 64 | + overdens, |
| 65 | + redshift_object, |
| 66 | + redshift_source, |
| 67 | + inner_slope, |
11 | 68 | ): |
12 | | - from scipy.integrate import quad |
| 69 | + """ |
| 70 | + Compute the characteristic convergence and scale radius of a spherical gNFW halo |
| 71 | + parameterised by virial mass and concentration. |
13 | 72 |
|
14 | | - concentration = (2.0 - inner_slope) * c_2 # gNFW concentration |
| 73 | + This routine converts a halo defined by its virial mass and concentration into |
| 74 | + the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing |
| 75 | + calculations. The normalization is computed analytically using the closed-form |
| 76 | + hypergeometric expression for the enclosed mass integral, ensuring compatibility |
| 77 | + with both NumPy and JAX backends (e.g. within `jax.jit`). |
15 | 78 |
|
16 | | - critical_density = cosmology.critical_density( |
17 | | - redshift_object, xp=np |
18 | | - ) # Msun / kpc^3 |
| 79 | + The virial radius is defined via: |
19 | 80 |
|
20 | | - critical_surface_density = ( |
21 | | - cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( |
22 | | - redshift_0=redshift_object, |
23 | | - redshift_1=redshift_source, |
24 | | - xp=np, |
25 | | - ) |
| 81 | + M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3 |
| 82 | +
|
| 83 | + where Δ is the overdensity with respect to the critical density. If `overdens` |
| 84 | + is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used. |
| 85 | +
|
| 86 | + The gNFW normalization constant is computed as: |
| 87 | +
|
| 88 | + d_e = (Δ / 3) (3 − γ) c^γ / |
| 89 | + ₂F₁(3 − γ, 3 − γ; 4 − γ; −c) |
| 90 | +
|
| 91 | + where γ is the inner slope and c is the gNFW concentration. |
| 92 | +
|
| 93 | + Parameters |
| 94 | + ---------- |
| 95 | + cosmology |
| 96 | + Cosmology object providing critical density, angular diameter distance |
| 97 | + conversions, and surface mass density calculations. Must support an `xp` |
| 98 | + argument for NumPy/JAX interoperability. |
| 99 | + virial_mass |
| 100 | + Virial mass of the halo in units of solar masses. |
| 101 | + c_2 |
| 102 | + Concentration-like parameter, converted internally to the gNFW |
| 103 | + concentration via `(2 - inner_slope) * c_2`. |
| 104 | + overdens |
| 105 | + Overdensity with respect to the critical density. If zero, the |
| 106 | + Bryan & Norman (1998) redshift-dependent overdensity is used. |
| 107 | + redshift_object |
| 108 | + Redshift of the lens (halo). |
| 109 | + redshift_source |
| 110 | + Redshift of the background source. |
| 111 | + inner_slope |
| 112 | + Inner logarithmic density slope γ of the gNFW profile. |
| 113 | + xp |
| 114 | + Array backend module (`numpy` or `jax.numpy`). All array operations |
| 115 | + are dispatched through this module to ensure compatibility with |
| 116 | + both standard NumPy execution and JAX tracing / JIT compilation. |
| 117 | +
|
| 118 | + Returns |
| 119 | + ------- |
| 120 | + kappa_s |
| 121 | + Dimensionless characteristic convergence of the gNFW profile. |
| 122 | + scale_radius |
| 123 | + Angular scale radius in arcseconds. |
| 124 | + virial_radius |
| 125 | + Virial radius in kiloparsecs. |
| 126 | + overdens |
| 127 | + Final overdensity value used in the calculation. |
| 128 | +
|
| 129 | + Notes |
| 130 | + ----- |
| 131 | + - This implementation is fully JIT-compatible when `xp=jax.numpy`. |
| 132 | + - No Python-side branching depends on traced values; conditional logic |
| 133 | + is implemented via backend array operations. |
| 134 | + - The analytic normalization avoids numerical quadrature, improving |
| 135 | + both performance and differentiability. |
| 136 | + """ |
| 137 | + is_jax_bool = is_jax(virial_mass) |
| 138 | + |
| 139 | + if not is_jax_bool: |
| 140 | + xp = np |
| 141 | + else: |
| 142 | + from jax import numpy as jnp |
| 143 | + xp = jnp |
| 144 | + |
| 145 | + if xp is np: |
| 146 | + from scipy.special import hyp2f1 |
| 147 | + else: |
| 148 | + try: |
| 149 | + from jax.scipy.special import hyp2f1 |
| 150 | + except ImportError: |
| 151 | + hyp2f1 = _hyp2f1_jax(xp) |
| 152 | + |
| 153 | + gamma = inner_slope |
| 154 | + concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition) |
| 155 | + |
| 156 | + critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3 |
| 157 | + |
| 158 | + critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( |
| 159 | + redshift_0=redshift_object, |
| 160 | + redshift_1=redshift_source, |
| 161 | + xp=xp, |
26 | 162 | ) # Msun / kpc^2 |
27 | 163 |
|
28 | | - kpc_per_arcsec = cosmology.kpc_per_arcsec_from( |
29 | | - redshift=redshift_object, xp=np |
30 | | - ) # kpc / arcsec |
| 164 | + kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec |
31 | 165 |
|
32 | | - if overdens == 0: |
33 | | - x = cosmology.Om(redshift_object, xp=np) - 1.0 |
34 | | - overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998) |
| 166 | + # Bryan & Norman (1998) overdensity if overdens == 0 |
| 167 | + x = cosmology.Om(redshift_object, xp=xp) - 1.0 |
| 168 | + overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2 |
| 169 | + overdens = xp.where(overdens == 0, overdens_bn98, overdens) |
35 | 170 |
|
36 | 171 | # r_vir in kpc |
37 | | - virial_radius = ( |
38 | | - virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0)) |
39 | | - ) ** (1.0 / 3.0) |
| 172 | + virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0) |
40 | 173 |
|
41 | 174 | # scale radius in kpc |
42 | 175 | scale_radius_kpc = virial_radius / concentration |
43 | 176 |
|
44 | | - # Normalization integral for gNFW |
45 | | - def integrand(r): |
46 | | - return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** ( |
47 | | - inner_slope - 3.0 |
48 | | - ) |
| 177 | + # c = rvir/rs is exactly "concentration" by definition |
| 178 | + c = concentration |
49 | 179 |
|
50 | | - de_c = ( |
51 | | - (overdens / 3.0) |
52 | | - * (virial_radius**3 / scale_radius_kpc**inner_slope) |
53 | | - / quad(integrand, 0.0, virial_radius)[0] |
54 | | - ) |
| 180 | + # Analytic normalization |
| 181 | + a = 3.0 - gamma |
| 182 | + de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c) |
55 | 183 |
|
56 | 184 | rho_s = critical_density * de_c # Msun / kpc^3 |
57 | 185 | kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless |
|
0 commit comments