Skip to content

Commit 9d1b9d8

Browse files
Jammy2211Jammy2211
authored andcommitted
require JAX version update for hypf1
1 parent c174bc4 commit 9d1b9d8

1 file changed

Lines changed: 6 additions & 46 deletions

File tree

autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -14,49 +14,6 @@ def is_jax(x):
1414
except Exception:
1515
return False
1616

17-
def _hyp2f1_jax(xp, *, max_terms: int = 256):
18-
"""
19-
Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp.
20-
21-
- NumPy: scipy.special.hyp2f1
22-
- JAX (if available): jax.scipy.special.hyp2f1
23-
- JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case)
24-
"""
25-
import jax
26-
import jax.numpy as jnp
27-
28-
# Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z)
29-
# We implement general 2F1 series:
30-
# 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n!
31-
#
32-
# Recurrence for terms:
33-
# t_0 = 1
34-
# t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z
35-
#
36-
# This is JIT-safe with static max_terms.
37-
def hyp2f1_series(a, b, c, z):
38-
a = jnp.asarray(a)
39-
b = jnp.asarray(b)
40-
c = jnp.asarray(c)
41-
z = jnp.asarray(z)
42-
43-
def body_fun(n, carry):
44-
t, s = carry
45-
n_f = jnp.asarray(n, dtype=t.dtype)
46-
t = t * (a + n_f) * (b + n_f) / ((c + n_f) * (n_f + 1.0)) * z
47-
s = s + t
48-
return (t, s)
49-
50-
# Start: t0 = 1, s0 = 1
51-
t0 = jnp.ones_like(z, dtype=jnp.result_type(a, b, c, z))
52-
s0 = t0
53-
54-
# fori_loop has static iteration count => good under jit/vmap
55-
tN, sN = jax.lax.fori_loop(0, max_terms - 1, body_fun, (t0, s0))
56-
return sN
57-
58-
return hyp2f1_series
59-
6017
def kappa_s_and_scale_radius(
6118
cosmology,
6219
virial_mass,
@@ -146,9 +103,12 @@ def kappa_s_and_scale_radius(
146103
from scipy.special import hyp2f1
147104
else:
148105
try:
149-
from jax.scipy.special import hyp2f1
150-
except ImportError:
151-
hyp2f1 = _hyp2f1_jax(xp)
106+
from jax.scipy.special import hyp2f1 # noqa: F401
107+
except Exception as e:
108+
raise RuntimeError(
109+
"This feature requires jax.scipy.special.hyp2f1, which is available in "
110+
"JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`."
111+
) from e
152112

153113
gamma = inner_slope
154114
concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition)

0 commit comments

Comments
 (0)