diff --git a/autofit/non_linear/fitness.py b/autofit/non_linear/fitness.py index f113651fa..d270c7885 100644 --- a/autofit/non_linear/fitness.py +++ b/autofit/non_linear/fitness.py @@ -42,6 +42,7 @@ def __init__( convert_to_chi_squared: bool = False, store_history: bool = False, use_jax_vmap : bool = False, + use_jax_jit : bool = False, batch_size : Optional[int] = None, iterations_per_quick_update: Optional[int] = None, background_quick_update: bool = False, @@ -118,11 +119,14 @@ def __init__( self.log_likelihood_history_list = [] self.use_jax_vmap = use_jax_vmap + self.use_jax_jit = use_jax_jit self._call = self.call if self.use_jax_vmap: self._call = self._vmap + elif self.use_jax_jit: + self._call = self._jit self.batch_size = batch_size self.iterations_per_quick_update = iterations_per_quick_update @@ -235,6 +239,9 @@ def call_wrap(self, parameters): figure_of_merit = self._call(parameters) + if self.use_jax_jit: + figure_of_merit = float(figure_of_merit) + if self.convert_to_chi_squared: log_likelihood = -0.5 * figure_of_merit else: @@ -382,15 +389,22 @@ def __call__(self, parameters, *kwargs): """ return self.call_wrap(parameters) - # def __getstate__(self): - # state = self.__dict__.copy() - # # Remove non-pickleable attributes - # state.pop('_call', None) - # state.pop('_grad', None) - # return state - # - # def __setstate__(self, state): - # self.__dict__.update(state) + def __getstate__(self): + state = self.__dict__.copy() + # Strip JAX-compiled callables: jax.jit / jax.vmap / jax.grad return + # functions tied to C++ XLA state that can't roundtrip through pickle. + # cached_property values lazily recompile on first access after unpickle. + for attr in ("_call", "_jit", "_vmap", "_grad"): + state.pop(attr, None) + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._call = self.call + if getattr(self, "use_jax_vmap", False): + self._call = self._vmap + elif getattr(self, "use_jax_jit", False): + self._call = self._jit @cached_property def _vmap(self): diff --git a/autofit/non_linear/search/nest/dynesty/search/abstract.py b/autofit/non_linear/search/nest/dynesty/search/abstract.py index de8d27e5e..bc53febed 100644 --- a/autofit/non_linear/search/nest/dynesty/search/abstract.py +++ b/autofit/non_linear/search/nest/dynesty/search/abstract.py @@ -53,6 +53,7 @@ def __init__( number_of_cores: int = 1, silence: bool = False, force_x1_cpu: bool = False, + use_jax_jit: bool = True, session: Optional[sa.orm.Session] = None, **kwargs, ): @@ -117,6 +118,7 @@ def __init__( self.maxcall = maxcall self.force_x1_cpu = force_x1_cpu + self.use_jax_jit = use_jax_jit self.logger.debug(f"Creating {self.__class__.__name__} Search") @@ -179,6 +181,7 @@ def _fit( paths=self.paths, fom_is_log_likelihood=True, resample_figure_of_merit=-1.0e99, + use_jax_jit=getattr(analysis, "_use_jax", False) and self.use_jax_jit, ) if not isinstance(self.paths, NullPaths): @@ -225,13 +228,22 @@ def _fit( except RuntimeError: if not checkpoint_exists: - self.logger.info( - """ - Your operating system does not support Python multiprocessing. - - A single CPU non-multiprocessing Dynesty run is being performed. - """ - ) + if getattr(analysis, "_use_jax", False): + self.logger.info( + "Running Dynesty with JAX-jitted likelihood (single CPU, no pool)." + ) + elif self.force_x1_cpu: + self.logger.info( + "Running Dynesty single-CPU per `force_x1_cpu=True` (no pool)." + ) + else: + self.logger.info( + """ + Your operating system does not support Python multiprocessing. + + A single CPU non-multiprocessing Dynesty run is being performed. + """ + ) search_internal = self.search_internal_from( model=model, diff --git a/test_autofit/non_linear/test_dict.py b/test_autofit/non_linear/test_dict.py index 5a8034685..04eeb044c 100644 --- a/test_autofit/non_linear/test_dict.py +++ b/test_autofit/non_linear/test_dict.py @@ -43,6 +43,7 @@ def make_dynesty_dict(): "slices": 5, "unique_tag": None, "update_interval": None, + "use_jax_jit": True, "walks": 5, }, "class_path": "autofit.non_linear.search.nest.dynesty.search.static.DynestyStatic", diff --git a/test_autofit/non_linear/test_fitness_jax_dispatch.py b/test_autofit/non_linear/test_fitness_jax_dispatch.py new file mode 100644 index 000000000..23243c10c --- /dev/null +++ b/test_autofit/non_linear/test_fitness_jax_dispatch.py @@ -0,0 +1,67 @@ +import pickle + +import numpy as np + +import autofit as af +from autofit.non_linear.fitness import Fitness + + +def _make_fitness(**kwargs): + model = af.Model(af.ex.Gaussian) + data = np.ones(20) + noise_map = np.ones(20) * 0.1 + analysis = af.ex.Analysis(data=data, noise_map=noise_map) + return Fitness(model=model, analysis=analysis, **kwargs) + + +def test_default_dispatch_is_call(): + fitness = _make_fitness() + # `self.call` produces a fresh bound method each access — compare the + # underlying function instead of the bound-method instance. + assert fitness._call.__func__ is Fitness.call + assert fitness.use_jax_jit is False + assert fitness.use_jax_vmap is False + + +def test_jit_dispatch_sets_call_to_jit(): + fitness = _make_fitness(use_jax_jit=True) + assert fitness.use_jax_jit is True + assert fitness._call is fitness._jit + + +def test_vmap_takes_precedence_over_jit(): + fitness = _make_fitness(use_jax_jit=True, use_jax_vmap=True) + assert fitness._call is fitness._vmap + + +def test_pickle_strips_jax_cached_attrs(): + """ + Dynesty's checkpoint writes pickle the loglikelihood. JAX-compiled + callables (jax.jit / jax.vmap / jax.grad) carry C++ XLA state that + cannot roundtrip through pickle. ``Fitness.__getstate__`` must drop + them; ``Fitness.__setstate__`` re-derives the dispatch on resume. + """ + fitness = _make_fitness(use_jax_jit=True) + + state = fitness.__getstate__() + assert "_call" not in state + assert "_jit" not in state + assert "_vmap" not in state + assert "_grad" not in state + + blob = pickle.dumps(fitness) + restored = pickle.loads(blob) + + assert restored.use_jax_jit is True + assert restored._call is restored._jit + + +def test_pickle_default_path_unchanged(): + fitness = _make_fitness() + + blob = pickle.dumps(fitness) + restored = pickle.loads(blob) + + assert restored.use_jax_jit is False + assert restored.use_jax_vmap is False + assert restored._call.__func__ is Fitness.call