Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions autofit/non_linear/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
convert_to_chi_squared: bool = False,
store_history: bool = False,
use_jax_vmap : bool = False,
use_jax_jit : bool = False,
batch_size : Optional[int] = None,
iterations_per_quick_update: Optional[int] = None,
background_quick_update: bool = False,
Expand Down Expand Up @@ -118,11 +119,14 @@ def __init__(
self.log_likelihood_history_list = []

self.use_jax_vmap = use_jax_vmap
self.use_jax_jit = use_jax_jit

self._call = self.call

if self.use_jax_vmap:
self._call = self._vmap
elif self.use_jax_jit:
self._call = self._jit

self.batch_size = batch_size
self.iterations_per_quick_update = iterations_per_quick_update
Expand Down Expand Up @@ -235,6 +239,9 @@ def call_wrap(self, parameters):

figure_of_merit = self._call(parameters)

if self.use_jax_jit:
figure_of_merit = float(figure_of_merit)

if self.convert_to_chi_squared:
log_likelihood = -0.5 * figure_of_merit
else:
Expand Down Expand Up @@ -382,15 +389,22 @@ def __call__(self, parameters, *kwargs):
"""
return self.call_wrap(parameters)

# def __getstate__(self):
# state = self.__dict__.copy()
# # Remove non-pickleable attributes
# state.pop('_call', None)
# state.pop('_grad', None)
# return state
#
# def __setstate__(self, state):
# self.__dict__.update(state)
def __getstate__(self):
    """
    Build a pickle-safe snapshot of this instance's attributes.

    jax.jit / jax.vmap / jax.grad return callables bound to C++ XLA
    state that cannot round-trip through pickle, so every
    compiled-callable attribute is excluded here. The cached_property
    machinery recompiles them lazily on first access after unpickling.
    """
    unpicklable = ("_call", "_jit", "_vmap", "_grad")
    return {
        key: value
        for key, value in self.__dict__.items()
        if key not in unpicklable
    }

def __setstate__(self, state):
    """
    Restore attributes and re-derive the likelihood dispatch.

    The compiled callables were stripped by ``__getstate__``; rebuild
    ``_call`` with the same precedence used at construction time:
    vmap wins over jit, and plain ``call`` is the fallback.
    """
    self.__dict__.update(state)
    if getattr(self, "use_jax_vmap", False):
        dispatch = self._vmap
    elif getattr(self, "use_jax_jit", False):
        dispatch = self._jit
    else:
        dispatch = self.call
    self._call = dispatch

@cached_property
def _vmap(self):
Expand Down
26 changes: 19 additions & 7 deletions autofit/non_linear/search/nest/dynesty/search/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
number_of_cores: int = 1,
silence: bool = False,
force_x1_cpu: bool = False,
use_jax_jit: bool = True,
session: Optional[sa.orm.Session] = None,
**kwargs,
):
Expand Down Expand Up @@ -117,6 +118,7 @@ def __init__(

self.maxcall = maxcall
self.force_x1_cpu = force_x1_cpu
self.use_jax_jit = use_jax_jit

self.logger.debug(f"Creating {self.__class__.__name__} Search")

Expand Down Expand Up @@ -179,6 +181,7 @@ def _fit(
paths=self.paths,
fom_is_log_likelihood=True,
resample_figure_of_merit=-1.0e99,
use_jax_jit=getattr(analysis, "_use_jax", False) and self.use_jax_jit,
)

if not isinstance(self.paths, NullPaths):
Expand Down Expand Up @@ -225,13 +228,22 @@ def _fit(

except RuntimeError:
if not checkpoint_exists:
self.logger.info(
"""
Your operating system does not support Python multiprocessing.

A single CPU non-multiprocessing Dynesty run is being performed.
"""
)
if getattr(analysis, "_use_jax", False):
self.logger.info(
"Running Dynesty with JAX-jitted likelihood (single CPU, no pool)."
)
elif self.force_x1_cpu:
self.logger.info(
"Running Dynesty single-CPU per `force_x1_cpu=True` (no pool)."
)
else:
self.logger.info(
"""
Your operating system does not support Python multiprocessing.

A single CPU non-multiprocessing Dynesty run is being performed.
"""
)

search_internal = self.search_internal_from(
model=model,
Expand Down
1 change: 1 addition & 0 deletions test_autofit/non_linear/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def make_dynesty_dict():
"slices": 5,
"unique_tag": None,
"update_interval": None,
"use_jax_jit": True,
"walks": 5,
},
"class_path": "autofit.non_linear.search.nest.dynesty.search.static.DynestyStatic",
Expand Down
67 changes: 67 additions & 0 deletions test_autofit/non_linear/test_fitness_jax_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pickle

import numpy as np

import autofit as af
from autofit.non_linear.fitness import Fitness


def _make_fitness(**kwargs):
    """Construct a Fitness wired to a toy Gaussian model and analysis."""
    gaussian_model = af.Model(af.ex.Gaussian)
    observed = np.ones(20)
    noise = 0.1 * np.ones(20)
    analysis = af.ex.Analysis(data=observed, noise_map=noise)
    return Fitness(model=gaussian_model, analysis=analysis, **kwargs)


def test_default_dispatch_is_call():
    """With neither JAX flag set, dispatch falls through to plain call."""
    fitness = _make_fitness()
    assert fitness.use_jax_jit is False
    assert fitness.use_jax_vmap is False
    # A bound method is a fresh object on every attribute access, so
    # compare the underlying function rather than the bound method.
    assert fitness._call.__func__ is Fitness.call


def test_jit_dispatch_sets_call_to_jit():
    """use_jax_jit=True routes _call through the cached jit wrapper."""
    jit_fitness = _make_fitness(use_jax_jit=True)
    assert jit_fitness._call is jit_fitness._jit
    assert jit_fitness.use_jax_jit is True


def test_vmap_takes_precedence_over_jit():
    """When both flags are set, the vmap dispatch wins over jit."""
    both_flags = _make_fitness(use_jax_vmap=True, use_jax_jit=True)
    assert both_flags._call is both_flags._vmap


def test_pickle_strips_jax_cached_attrs():
    """
    Dynesty checkpointing pickles the loglikelihood, and JAX-compiled
    callables (jax.jit / jax.vmap / jax.grad) hold C++ XLA state that
    pickle cannot serialize. ``Fitness.__getstate__`` must drop them
    and ``Fitness.__setstate__`` must rebuild the dispatch on resume.
    """
    fitness = _make_fitness(use_jax_jit=True)

    stripped = fitness.__getstate__()
    for attr in ("_call", "_jit", "_vmap", "_grad"):
        assert attr not in stripped

    round_tripped = pickle.loads(pickle.dumps(fitness))

    assert round_tripped.use_jax_jit is True
    assert round_tripped._call is round_tripped._jit


def test_pickle_default_path_unchanged():
    """A Fitness with no JAX flags pickles and restores to plain call."""
    restored = pickle.loads(pickle.dumps(_make_fitness()))

    assert restored.use_jax_jit is False
    assert restored.use_jax_vmap is False
    assert restored._call.__func__ is Fitness.call
Loading