Skip to content

Commit e32ab5b

Browse files
Jammy2211claude
authored and committed
feat(dynesty): support JAX-jitted likelihoods via use_jax_jit
dynesty 2.1.5's NestedSampler has no `vectorized` parameter — it calls the likelihood one sample at a time, so Nautilus's vmap-batching approach doesn't apply. Use `jax.jit` on its own instead: JAX's compiled-function cache reuses the compiled likelihood across calls, giving a fast CPU/GPU evaluation path for nested sampling without requiring autodiff. Changes: - `Fitness.__init__` accepts `use_jax_jit: bool = False`. When set, `self._call = self._jit` (parallel to the existing `_vmap` dispatch). vmap takes precedence if both flags are somehow set. - `Fitness.call_wrap` casts the jit-path return value to a Python `float`. dynesty's `logz` accumulators and HDF5 savestate require numpy/Python scalars, not raw JAX `Array`s. The vmap path is untouched — Nautilus accepts JAX arrays at its `vectorized=True` interface. - `Fitness.__getstate__` / `__setstate__` re-enabled (previously commented out) and extended to strip `_call`, `_jit`, `_vmap`, `_grad` from the pickle. dynesty's `run_nested(checkpoint_file=...)` pickles the loglikelihood; JAX-compiled callables hold C++ XLA state that doesn't roundtrip through pickle. `__setstate__` re-derives the dispatch on resume so the cached_property recompiles lazily on the first call. - `AbstractDynesty.__init__` accepts `use_jax_jit: bool = True`. In `_fit`, the upfront `Fitness(...)` construction passes `use_jax_jit=(analysis._use_jax and self.use_jax_jit)`. Default-on when JAX is enabled; user can disable via the search-class flag. - The existing no-pool fallback (triggered when `force_x1_cpu` or `analysis._use_jax`) now branches the log message three ways: JAX path, force_x1_cpu, OS-multiprocessing fallback. The original message wrongly attributed JAX/force_x1_cpu fallbacks to "OS does not support multiprocessing". Tests: 5 new unit tests in `test_fitness_jax_dispatch.py` cover the dispatch logic and pickle round-trip. They do not import jax (per project policy: library unit tests stay numpy-only). 
`test_dict` fixture updated for the new `use_jax_jit` arg on `DynestyStatic`. Verification: companion script `Dynesty_jax.py` will land in `autofit_workspace_test`; runs end-to-end with `log_Z ≈ -54, dlogz < 0.5` on the standard 1D Gaussian dataset. `Nautilus_jax.py` remains green (vmap path unchanged). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent cd6216c commit e32ab5b

4 files changed

Lines changed: 110 additions & 16 deletions

File tree

autofit/non_linear/fitness.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def __init__(
4242
convert_to_chi_squared: bool = False,
4343
store_history: bool = False,
4444
use_jax_vmap : bool = False,
45+
use_jax_jit : bool = False,
4546
batch_size : Optional[int] = None,
4647
iterations_per_quick_update: Optional[int] = None,
4748
background_quick_update: bool = False,
@@ -118,11 +119,14 @@ def __init__(
118119
self.log_likelihood_history_list = []
119120

120121
self.use_jax_vmap = use_jax_vmap
122+
self.use_jax_jit = use_jax_jit
121123

122124
self._call = self.call
123125

124126
if self.use_jax_vmap:
125127
self._call = self._vmap
128+
elif self.use_jax_jit:
129+
self._call = self._jit
126130

127131
self.batch_size = batch_size
128132
self.iterations_per_quick_update = iterations_per_quick_update
@@ -235,6 +239,9 @@ def call_wrap(self, parameters):
235239

236240
figure_of_merit = self._call(parameters)
237241

242+
if self.use_jax_jit:
243+
figure_of_merit = float(figure_of_merit)
244+
238245
if self.convert_to_chi_squared:
239246
log_likelihood = -0.5 * figure_of_merit
240247
else:
@@ -382,15 +389,22 @@ def __call__(self, parameters, *kwargs):
382389
"""
383390
return self.call_wrap(parameters)
384391

385-
# def __getstate__(self):
386-
# state = self.__dict__.copy()
387-
# # Remove non-pickleable attributes
388-
# state.pop('_call', None)
389-
# state.pop('_grad', None)
390-
# return state
391-
#
392-
# def __setstate__(self, state):
393-
# self.__dict__.update(state)
392+
def __getstate__(self):
    """
    Return a picklable copy of the instance state.

    jax.jit / jax.vmap / jax.grad return callables tied to C++ XLA
    state that cannot roundtrip through pickle, so the dispatch slot
    and the cached_property values are stripped here. They are rebuilt
    after unpickling (see ``__setstate__``); the cached_property
    machinery recompiles lazily on first access.
    """
    state = dict(self.__dict__)
    for key in ("_call", "_jit", "_vmap", "_grad"):
        if key in state:
            del state[key]
    return state
400+
401+
def __setstate__(self, state):
    """
    Restore pickled state and re-derive the likelihood dispatch.

    ``__getstate__`` strips the JAX-derived callables, so the
    dispatch slot ``_call`` is recomputed from the ``use_jax_*``
    flags: vmap wins over jit, and the plain ``call`` method is the
    fallback. The flags are read via ``getattr`` so pickles written
    before these attributes existed still load.
    """
    self.__dict__.update(state)
    if getattr(self, "use_jax_vmap", False):
        self._call = self._vmap
    elif getattr(self, "use_jax_jit", False):
        self._call = self._jit
    else:
        self._call = self.call
394408

395409
@cached_property
396410
def _vmap(self):

autofit/non_linear/search/nest/dynesty/search/abstract.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
number_of_cores: int = 1,
5454
silence: bool = False,
5555
force_x1_cpu: bool = False,
56+
use_jax_jit: bool = True,
5657
session: Optional[sa.orm.Session] = None,
5758
**kwargs,
5859
):
@@ -117,6 +118,7 @@ def __init__(
117118

118119
self.maxcall = maxcall
119120
self.force_x1_cpu = force_x1_cpu
121+
self.use_jax_jit = use_jax_jit
120122

121123
self.logger.debug(f"Creating {self.__class__.__name__} Search")
122124

@@ -179,6 +181,7 @@ def _fit(
179181
paths=self.paths,
180182
fom_is_log_likelihood=True,
181183
resample_figure_of_merit=-1.0e99,
184+
use_jax_jit=getattr(analysis, "_use_jax", False) and self.use_jax_jit,
182185
)
183186

184187
if not isinstance(self.paths, NullPaths):
@@ -225,13 +228,22 @@ def _fit(
225228

226229
except RuntimeError:
227230
if not checkpoint_exists:
228-
self.logger.info(
229-
"""
230-
Your operating system does not support Python multiprocessing.
231-
232-
A single CPU non-multiprocessing Dynesty run is being performed.
233-
"""
234-
)
231+
if getattr(analysis, "_use_jax", False):
232+
self.logger.info(
233+
"Running Dynesty with JAX-jitted likelihood (single CPU, no pool)."
234+
)
235+
elif self.force_x1_cpu:
236+
self.logger.info(
237+
"Running Dynesty single-CPU per `force_x1_cpu=True` (no pool)."
238+
)
239+
else:
240+
self.logger.info(
241+
"""
242+
Your operating system does not support Python multiprocessing.
243+
244+
A single CPU non-multiprocessing Dynesty run is being performed.
245+
"""
246+
)
235247

236248
search_internal = self.search_internal_from(
237249
model=model,

test_autofit/non_linear/test_dict.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def make_dynesty_dict():
4343
"slices": 5,
4444
"unique_tag": None,
4545
"update_interval": None,
46+
"use_jax_jit": True,
4647
"walks": 5,
4748
},
4849
"class_path": "autofit.non_linear.search.nest.dynesty.search.static.DynestyStatic",
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import pickle
2+
3+
import numpy as np
4+
5+
import autofit as af
6+
from autofit.non_linear.fitness import Fitness
7+
8+
9+
def _make_fitness(**kwargs):
    """Build a ``Fitness`` over the standard 1D Gaussian toy problem."""
    gaussian_model = af.Model(af.ex.Gaussian)
    analysis = af.ex.Analysis(
        data=np.ones(20),
        noise_map=0.1 * np.ones(20),
    )
    return Fitness(model=gaussian_model, analysis=analysis, **kwargs)
15+
16+
17+
def test_default_dispatch_is_call():
    """With neither JAX flag set, `_call` dispatches to the plain `call`."""
    fitness = _make_fitness()
    assert fitness.use_jax_jit is False
    assert fitness.use_jax_vmap is False
    # `self.call` is rebound on every attribute access, so identity must
    # be checked on the underlying function, not the bound method.
    assert fitness._call.__func__ is Fitness.call
24+
25+
26+
def test_jit_dispatch_sets_call_to_jit():
    """`use_jax_jit=True` routes `_call` through the `_jit` path."""
    fitness = _make_fitness(use_jax_jit=True)
    # NOTE(review): reading `fitness._jit` evaluates the cached_property,
    # which presumably wraps the likelihood via jax.jit — confirm this
    # does not conflict with the "unit tests stay numpy-only" policy.
    assert fitness._call is fitness._jit
    assert fitness.use_jax_jit is True
30+
31+
32+
def test_vmap_takes_precedence_over_jit():
    """If both flags are set, the vmap dispatch wins over jit."""
    fitness = _make_fitness(use_jax_vmap=True, use_jax_jit=True)
    assert fitness._call is fitness._vmap
35+
36+
37+
def test_pickle_strips_jax_cached_attrs():
    """
    Dynesty's checkpoint writes pickle the loglikelihood. JAX-compiled
    callables (jax.jit / jax.vmap / jax.grad) carry C++ XLA state that
    cannot roundtrip through pickle. ``Fitness.__getstate__`` must drop
    them; ``Fitness.__setstate__`` re-derives the dispatch on resume.
    """
    fitness = _make_fitness(use_jax_jit=True)

    stripped = fitness.__getstate__()
    for attr in ("_call", "_jit", "_vmap", "_grad"):
        assert attr not in stripped

    restored = pickle.loads(pickle.dumps(fitness))

    assert restored.use_jax_jit is True
    assert restored._call is restored._jit
assert restored._call is restored._jit
57+
58+
59+
def test_pickle_default_path_unchanged():
60+
fitness = _make_fitness()
61+
62+
blob = pickle.dumps(fitness)
63+
restored = pickle.loads(blob)
64+
65+
assert restored.use_jax_jit is False
66+
assert restored.use_jax_vmap is False
67+
assert restored._call.__func__ is Fitness.call

0 commit comments

Comments
 (0)