|
| 1 | +""" |
| 2 | +Benchmark the JAX prototype against colossus, and exercise grad + vmap |
| 3 | +to confirm the JAX path supports differentiation and vectorisation — |
| 4 | +the two capabilities the pure_callback rules out. |
| 5 | +
|
| 6 | +Run with: |
| 7 | + NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib \ |
| 8 | + JAX_ENABLE_X64=1 python docs/research/nfw_ludlow16_jax/bench.py |
| 9 | +""" |
| 10 | + |
import os
import time
import numpy as np

# Request 64-bit mode through the environment before jax is imported; the
# explicit config update below enforces it even if the variable was
# already set to something else by the caller.
os.environ.setdefault("JAX_ENABLE_X64", "1")

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

from colossus.cosmology import cosmology as col_cosmology
from colossus.halo.concentration import concentration as col_concentration

# Make the sibling prototype module importable when this file is run
# directly as a script; this must happen before the import that follows,
# hence the late import and the E402 suppression.
import sys
sys.path.insert(0, os.path.dirname(__file__))
from ludlow16_jax import ludlow16_concentration_jax  # noqa: E402
| 27 | + |
| 28 | + |
# Cosmological parameters forwarded as keyword arguments to
# ludlow16_concentration_jax on every call in main().
# NOTE(review): presumably chosen to mirror the colossus "planck15" preset
# that main() activates for the reference calculation -- confirm the values
# against colossus before relying on the comparison.
PLANCK15 = dict(
    h=0.6774,
    Om0=0.3089,
    Ob0=0.0486,
    Tcmb0=2.7255,
    sigma8=0.8159,
    ns=0.9667,
)
| 37 | + |
| 38 | + |
def main():
    """Run the grad check, the batched benchmark and single-call timings.

    Takes no arguments and returns nothing; all results are printed to
    stdout. Three sections are executed in order:

    1. ``jax.grad`` of the JAX concentration at M = 1e12, z = 0.5, compared
       with a central finite difference of the colossus ``ludlow16``
       concentration at the same point.
    2. A ``jax.vmap`` batch of 32 (M, z) pairs timed against a serial loop
       of colossus calls over the same pairs.
    3. Wall time of a freshly jitted single call, first including
       compilation and then averaged over repeated post-compile calls.
    """
    col_cosmology.setCosmology("planck15")

    # ----------------------------------------------------------------
    # 1) grad — does jax.grad work through the concentration solver?
    # ----------------------------------------------------------------

    @jax.jit
    def c_of_M(M):
        return ludlow16_concentration_jax(M, 0.5, **PLANCK15)

    print("grad test:")
    M0 = jnp.float64(1.0e12)
    c0 = c_of_M(M0)
    g = jax.grad(c_of_M)(M0)
    print(f"  c(1e12, z=0.5)    = {float(c0):.4f}")
    print(f"  dc/dM             = {float(g):.4e}")

    # Finite-difference check using colossus (central difference, so the
    # reference derivative is independent of the JAX implementation).
    eps = 1.0e9
    c_low = float(col_concentration(1.0e12 - eps, "200c", 0.5, model="ludlow16"))
    c_high = float(col_concentration(1.0e12 + eps, "200c", 0.5, model="ludlow16"))
    g_fd = (c_high - c_low) / (2 * eps)
    print(f"  dc/dM (fd)        = {g_fd:.4e}")
    print(f"  agreement: rel err = {abs(g - g_fd) / abs(g_fd):.2e}")
    print()

    # ----------------------------------------------------------------
    # 2) vmap — batch many (M, z) pairs in one JIT trace
    # ----------------------------------------------------------------

    @jax.jit
    def c_single(M, z):
        return ludlow16_concentration_jax(M, z, **PLANCK15)

    batch_size = 32
    M_batch = jnp.geomspace(1.0e10, 1.0e14, batch_size)
    z_batch = jnp.linspace(0.1, 2.5, batch_size)

    c_batched = jax.jit(jax.vmap(c_single))
    # Warm up so that tracing/compilation is excluded from the timed loop.
    _ = c_batched(M_batch, z_batch).block_until_ready()

    n_iter = 20
    t0 = time.perf_counter()
    for _ in range(n_iter):
        _ = c_batched(M_batch, z_batch).block_until_ready()
    t_jax_batch = (time.perf_counter() - t0) / n_iter

    # Copy the batch to host Python floats once, outside the timed region,
    # so the colossus figure does not include array conversion overhead.
    host_pairs = list(
        zip(np.asarray(M_batch).tolist(), np.asarray(z_batch).tolist())
    )

    # Warm up colossus as well, mirroring the JAX warm-up above, so any
    # one-off work done on first evaluation is not charged to the timing.
    for M, z in host_pairs:
        col_concentration(M, "200c", z, model="ludlow16")

    t0 = time.perf_counter()
    for _ in range(n_iter):
        for M, z in host_pairs:
            _ = col_concentration(M, "200c", z, model="ludlow16")
    t_col_batch = (time.perf_counter() - t0) / n_iter

    print(f"Batch of {batch_size} concentration calls:")
    print(f"  colossus (serial)      : {t_col_batch*1e3:.2f} ms")
    print(f"  jax vmap (post-jit)    : {t_jax_batch*1e3:.2f} ms")
    print(f"  speedup                : {t_col_batch / t_jax_batch:.2f}x")
    print()

    # ----------------------------------------------------------------
    # 3) Single-call wall time, including JIT compile vs post-compile
    # ----------------------------------------------------------------

    # A new jitted callable, so its first call has to trace and compile.
    # NOTE(review): JAX has already been exercised earlier in this process,
    # so this may understate a true cold-start cost -- confirm if that
    # number matters.
    fresh_jit = jax.jit(
        lambda M, z: ludlow16_concentration_jax(M, z, **PLANCK15)
    )
    t0 = time.perf_counter()
    _ = fresh_jit(1.0e12, 0.5).block_until_ready()
    t_compile = time.perf_counter() - t0
    print(f"JIT compile + first call: {t_compile*1e3:.1f} ms")

    n_iter = 100
    t0 = time.perf_counter()
    for _ in range(n_iter):
        _ = fresh_jit(1.0e12, 0.5).block_until_ready()
    t_post = (time.perf_counter() - t0) / n_iter
    print(f"Post-compile single call: {t_post*1e3:.3f} ms")
| 118 | + |
| 119 | + |
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments