@@ -14,49 +14,6 @@ def is_jax(x):
1414 except Exception :
1515 return False
1616
17- def _hyp2f1_jax (xp , * , max_terms : int = 256 ):
18- """
19- Returns a callable hyp2f1(a,b,c,z) compatible with the backend xp.
20-
21- - NumPy: scipy.special.hyp2f1
22- - JAX (if available): jax.scipy.special.hyp2f1
23- - JAX (fallback): series approximation for 2F1 (sufficient for this gNFW use-case)
24- """
25- import jax
26- import jax .numpy as jnp
27-
28- # Fallback: truncated series for 2F1(a,a;a+1;z) and general 2F1(a,b;c;z)
29- # We implement general 2F1 series:
30- # 2F1(a,b;c;z) = sum_{n=0}^{∞} (a)_n (b)_n / (c)_n * z^n / n!
31- #
32- # Recurrence for terms:
33- # t_0 = 1
34- # t_{n+1} = t_n * (a+n)(b+n)/((c+n)(n+1)) * z
35- #
36- # This is JIT-safe with static max_terms.
37- def hyp2f1_series (a , b , c , z ):
38- a = jnp .asarray (a )
39- b = jnp .asarray (b )
40- c = jnp .asarray (c )
41- z = jnp .asarray (z )
42-
43- def body_fun (n , carry ):
44- t , s = carry
45- n_f = jnp .asarray (n , dtype = t .dtype )
46- t = t * (a + n_f ) * (b + n_f ) / ((c + n_f ) * (n_f + 1.0 )) * z
47- s = s + t
48- return (t , s )
49-
50- # Start: t0 = 1, s0 = 1
51- t0 = jnp .ones_like (z , dtype = jnp .result_type (a , b , c , z ))
52- s0 = t0
53-
54- # fori_loop has static iteration count => good under jit/vmap
55- tN , sN = jax .lax .fori_loop (0 , max_terms - 1 , body_fun , (t0 , s0 ))
56- return sN
57-
58- return hyp2f1_series
59-
6017def kappa_s_and_scale_radius (
6118 cosmology ,
6219 virial_mass ,
@@ -146,9 +103,12 @@ def kappa_s_and_scale_radius(
146103 from scipy .special import hyp2f1
147104 else :
148105 try :
149- from jax .scipy .special import hyp2f1
150- except ImportError :
151- hyp2f1 = _hyp2f1_jax (xp )
106+ from jax .scipy .special import hyp2f1 # noqa: F401
107+ except Exception as e :
108+ raise RuntimeError (
109+ "This feature requires jax.scipy.special.hyp2f1, which is available in "
110+ "JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`."
111+ ) from e
152112
153113 gamma = inner_slope
154114 concentration = (2.0 - gamma ) * c_2 # gNFW concentration (your definition)
0 commit comments