Skip to content

Commit 7f5dc8d

Browse files
authored
Merge pull request #402 from PyAutoLabs/feature/nfw-jax-port
docs(research): Ludlow16 JAX concentration feasibility study (#397)
2 parents ee7c989 + 71b0425 commit 7f5dc8d

6 files changed

Lines changed: 1458 additions & 0 deletions

File tree

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
"""
2+
Benchmark the JAX prototype against colossus, and exercise grad + vmap
3+
to confirm the JAX path supports differentiation and vectorisation —
4+
the two capabilities the pure_callback rules out.
5+
6+
Run with:
7+
NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib \
8+
JAX_ENABLE_X64=1 python docs/research/nfw_ludlow16_jax/bench.py
9+
"""
10+
11+
import os
12+
import time
13+
import numpy as np
14+
15+
os.environ.setdefault("JAX_ENABLE_X64", "1")
16+
17+
import jax
18+
jax.config.update("jax_enable_x64", True)
19+
import jax.numpy as jnp
20+
21+
from colossus.cosmology import cosmology as col_cosmology
22+
from colossus.halo.concentration import concentration as col_concentration
23+
24+
import sys
25+
sys.path.insert(0, os.path.dirname(__file__))
26+
from ludlow16_jax import ludlow16_concentration_jax # noqa: E402
27+
28+
29+
PLANCK15 = dict(
30+
h=0.6774,
31+
Om0=0.3089,
32+
Ob0=0.0486,
33+
Tcmb0=2.7255,
34+
sigma8=0.8159,
35+
ns=0.9667,
36+
)
37+
38+
39+
def main():
40+
col_cosmology.setCosmology("planck15")
41+
42+
# ----------------------------------------------------------------
43+
# 1) grad — does jax.grad work through the concentration solver?
44+
# ----------------------------------------------------------------
45+
46+
@jax.jit
47+
def c_of_M(M):
48+
return ludlow16_concentration_jax(M, 0.5, **PLANCK15)
49+
50+
print("grad test:")
51+
M0 = jnp.float64(1.0e12)
52+
c0 = c_of_M(M0)
53+
g = jax.grad(c_of_M)(M0)
54+
print(f" c(1e12, z=0.5) = {float(c0):.4f}")
55+
print(f" dc/dM = {float(g):.4e}")
56+
57+
# Finite-difference check using colossus
58+
eps = 1.0e9
59+
c_low = float(col_concentration(1.0e12 - eps, "200c", 0.5, model="ludlow16"))
60+
c_high = float(col_concentration(1.0e12 + eps, "200c", 0.5, model="ludlow16"))
61+
g_fd = (c_high - c_low) / (2 * eps)
62+
print(f" dc/dM (fd) = {g_fd:.4e}")
63+
print(f" agreement: rel err = {abs(g - g_fd) / abs(g_fd):.2e}")
64+
print()
65+
66+
# ----------------------------------------------------------------
67+
# 2) vmap — batch many (M, z) pairs in one JIT trace
68+
# ----------------------------------------------------------------
69+
70+
@jax.jit
71+
def c_single(M, z):
72+
return ludlow16_concentration_jax(M, z, **PLANCK15)
73+
74+
batch_size = 32
75+
M_batch = jnp.geomspace(1.0e10, 1.0e14, batch_size)
76+
z_batch = jnp.linspace(0.1, 2.5, batch_size)
77+
78+
c_batched = jax.jit(jax.vmap(c_single))
79+
# warm up
80+
_ = c_batched(M_batch, z_batch).block_until_ready()
81+
82+
n_iter = 20
83+
t0 = time.perf_counter()
84+
for _ in range(n_iter):
85+
_ = c_batched(M_batch, z_batch).block_until_ready()
86+
t_jax_batch = (time.perf_counter() - t0) / n_iter
87+
88+
t0 = time.perf_counter()
89+
for _ in range(n_iter):
90+
for M, z in zip(np.array(M_batch), np.array(z_batch)):
91+
_ = col_concentration(float(M), "200c", float(z), model="ludlow16")
92+
t_col_batch = (time.perf_counter() - t0) / n_iter
93+
94+
print(f"Batch of {batch_size} concentration calls:")
95+
print(f" colossus (serial) : {t_col_batch*1e3:.2f} ms")
96+
print(f" jax vmap (post-jit) : {t_jax_batch*1e3:.2f} ms")
97+
print(f" speedup : {t_col_batch / t_jax_batch:.2f}x")
98+
print()
99+
100+
# ----------------------------------------------------------------
101+
# 3) Single-call wall time, including JIT compile vs post-compile
102+
# ----------------------------------------------------------------
103+
104+
fresh_jit = jax.jit(
105+
lambda M, z: ludlow16_concentration_jax(M, z, **PLANCK15)
106+
)
107+
t0 = time.perf_counter()
108+
_ = fresh_jit(1.0e12, 0.5).block_until_ready()
109+
t_compile = time.perf_counter() - t0
110+
print(f"JIT compile + first call: {t_compile*1e3:.1f} ms")
111+
112+
n_iter = 100
113+
t0 = time.perf_counter()
114+
for _ in range(n_iter):
115+
_ = fresh_jit(1.0e12, 0.5).block_until_ready()
116+
t_post = (time.perf_counter() - t0) / n_iter
117+
print(f"Post-compile single call: {t_post*1e3:.3f} ms")
118+
119+
120+
if __name__ == "__main__":
121+
main()

0 commit comments

Comments
 (0)