Commit e32ab5b
feat(dynesty): support JAX-jitted likelihoods via use_jax_jit
dynesty 2.1.5's NestedSampler has no `vectorized` parameter — it calls
the likelihood one sample at a time, so Nautilus's vmap-batching
approach doesn't apply. Use `jax.jit` on its own instead: JAX's
compiled-function cache reuses the compiled likelihood across calls,
giving a fast CPU/GPU evaluation path for nested sampling without
requiring autodiff.
Changes:
- `Fitness.__init__` accepts `use_jax_jit: bool = False`. When set,
`self._call = self._jit` (parallel to the existing `_vmap` dispatch).
vmap takes precedence if both flags are somehow set.
- `Fitness.call_wrap` casts the jit-path return value to a Python
`float`. dynesty's `logz` accumulators and HDF5 savestate require
numpy/Python scalars, not raw JAX `Array`s. The vmap path is
untouched — Nautilus accepts JAX arrays at its `vectorized=True`
interface.
- `Fitness.__getstate__` / `__setstate__` re-enabled (previously
commented out) and extended to strip `_call`, `_jit`, `_vmap`,
`_grad` from the pickle. dynesty's `run_nested(checkpoint_file=...)`
pickles the loglikelihood; JAX-compiled callables hold C++ XLA state
that doesn't roundtrip through pickle. `__setstate__` re-derives the
dispatch on resume so the cached_property recompiles lazily on the
first call.
- `AbstractDynesty.__init__` accepts `use_jax_jit: bool = True`. In
`_fit`, the upfront `Fitness(...)` construction passes
`use_jax_jit=(analysis._use_jax and self.use_jax_jit)`. Default-on
when JAX is enabled; user can disable via the search-class flag.
- The existing no-pool fallback (triggered when `force_x1_cpu` or
`analysis._use_jax`) now branches the log message three ways: JAX
path, force_x1_cpu, OS-multiprocessing fallback. The original
message wrongly attributed JAX/force_x1_cpu fallbacks to "OS does
not support multiprocessing".
Tests: 5 new unit tests in `test_fitness_jax_dispatch.py` cover the
dispatch logic and pickle round-trip. They do not import jax (per
project policy: library unit tests stay numpy-only). `test_dict`
fixture updated for the new `use_jax_jit` arg on `DynestyStatic`.
Verification: companion script `Dynesty_jax.py` will land in
`autofit_workspace_test`; runs end-to-end with
`log_Z ≈ -54, dlogz < 0.5` on the standard 1D Gaussian dataset.
`Nautilus_jax.py` remains green (vmap path unchanged).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>1 parent cd6216c commit e32ab5b
4 files changed
Lines changed: 110 additions & 16 deletions
File tree
- autofit/non_linear
- search/nest/dynesty/search
- test_autofit/non_linear
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
42 | 42 | | |
43 | 43 | | |
44 | 44 | | |
| 45 | + | |
45 | 46 | | |
46 | 47 | | |
47 | 48 | | |
| |||
118 | 119 | | |
119 | 120 | | |
120 | 121 | | |
| 122 | + | |
121 | 123 | | |
122 | 124 | | |
123 | 125 | | |
124 | 126 | | |
125 | 127 | | |
| 128 | + | |
| 129 | + | |
126 | 130 | | |
127 | 131 | | |
128 | 132 | | |
| |||
235 | 239 | | |
236 | 240 | | |
237 | 241 | | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
238 | 245 | | |
239 | 246 | | |
240 | 247 | | |
| |||
382 | 389 | | |
383 | 390 | | |
384 | 391 | | |
385 | | - | |
386 | | - | |
387 | | - | |
388 | | - | |
389 | | - | |
390 | | - | |
391 | | - | |
392 | | - | |
393 | | - | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
394 | 408 | | |
395 | 409 | | |
396 | 410 | | |
| |||
Lines changed: 19 additions & 7 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
53 | 53 | | |
54 | 54 | | |
55 | 55 | | |
| 56 | + | |
56 | 57 | | |
57 | 58 | | |
58 | 59 | | |
| |||
117 | 118 | | |
118 | 119 | | |
119 | 120 | | |
| 121 | + | |
120 | 122 | | |
121 | 123 | | |
122 | 124 | | |
| |||
179 | 181 | | |
180 | 182 | | |
181 | 183 | | |
| 184 | + | |
182 | 185 | | |
183 | 186 | | |
184 | 187 | | |
| |||
225 | 228 | | |
226 | 229 | | |
227 | 230 | | |
228 | | - | |
229 | | - | |
230 | | - | |
231 | | - | |
232 | | - | |
233 | | - | |
234 | | - | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
235 | 247 | | |
236 | 248 | | |
237 | 249 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| 46 | + | |
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
0 commit comments