Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dynesty: new method to save only the raw dynesty sample results #1331

Merged
merged 14 commits into from
Mar 21, 2024
4 changes: 3 additions & 1 deletion doc/example/sampler_study.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The internal `dynesty` sampler can be saved and restored, for post-sampling analysis. For example, pyPESTO stores resampled MCMC-like samples from the `dynesty` sampler by default. The following code shows how to save and load the internal dynesty sampler, to facilitate post-sampling analysis of both the resampled and original chains. First, we save the internal sampler."
"The internal `dynesty` sampler can be saved and restored, for post-sampling analysis. For example, pyPESTO stores resampled MCMC-like samples from the `dynesty` sampler by default. The following code shows how to save and load the internal dynesty sampler, to facilitate post-sampling analysis of both the resampled and original chains. N.B.: when working across different computers, you might prefer to work with the raw sample results via `pypesto.sample.dynesty.save_raw_results` and `load_raw_results`.",
"\n",
"First, we save the internal sampler."
]
},
{
Expand Down
181 changes: 128 additions & 53 deletions pypesto/sample/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,28 @@

from __future__ import annotations

import importlib
import logging
import warnings
from typing import List, Union

import cloudpickle # noqa: S403
import numpy as np

from ..C import OBJECTIVE_NEGLOGLIKE, OBJECTIVE_NEGLOGPOST
from ..problem import Problem
from ..result import McmcPtResult
from .sampler import Sampler, SamplerImportError

dynesty_pickle = cloudpickle
if importlib.util.find_spec("dynesty") is None:
dynesty = type("", (), {})()
dynesty.results = type("", (), {})()
dynesty.results.Results = None
else:
import dynesty

dynesty.utils.pickle_module = dynesty_pickle
logger = logging.getLogger(__name__)


Expand All @@ -47,7 +59,7 @@ class DynestySampler(Sampler):
To work with the original samples, modify the results object with
`pypesto_result.sample_result = sampler.get_original_samples()`, where
`sampler` is an instance of `pypesto.sample.DynestySampler`. The original
dynesty results object is available at `sampler.results`.
dynesty results object is available at `sampler.raw_results`.

NB: the dynesty samplers can be customized significantly, by providing
`sampler_args` and `run_args` to your `pypesto.sample.DynestySampler()`
Expand Down Expand Up @@ -79,13 +91,14 @@ def __init__(
dynamic:
Whether to use dynamic or static nested sampling.
objective_type:
The objective to optimize (as defined in pypesto.problem). Either "neglogpost" or
"negloglike". If "neglogpost", x_priors have to be defined in the problem.
The objective to optimize (as defined in `pypesto.problem`). Either
`pypesto.C.OBJECTIVE_NEGLOGLIKE` or
`pypesto.C.OBJECTIVE_NEGLOGPOST`. If
`pypesto.C.OBJECTIVE_NEGLOGPOST`, then `x_priors` have to
be defined in the problem.
"""
# check dependencies
import dynesty

setup_dynesty()
if importlib.util.find_spec("dynesty") is None:
raise SamplerImportError("dynesty")

super().__init__()

Expand Down Expand Up @@ -152,13 +165,9 @@ def loglikelihood(self, x):
def initialize(
self,
problem: Problem,
x0: Union[np.ndarray, List[np.ndarray]],
x0: Union[np.ndarray, List[np.ndarray]] = None,
) -> None:
"""Initialize the sampler."""
import dynesty

setup_dynesty()

self.problem = problem

sampler_class = dynesty.NestedSampler
Expand Down Expand Up @@ -212,7 +221,22 @@ def sample(self, n_samples: int, beta: float = None) -> None:
)

self.sampler.run_nested(**self.run_args)
self.results = self.sampler.results

@property
def results(self):
"""Deprecated in favor of `raw_results`."""
warnings.warn(
"Accessing dynesty results via `sampler.results` is "
"deprecated. Please use `sampler.raw_results` instead.",
DeprecationWarning,
stacklevel=1,
)
return self.raw_results

@property
def raw_results(self):
"""Get the raw dynesty results."""
return self.sampler.results

def save_internal_sampler(self, filename: str) -> None:
"""Save the state of the internal dynesty sampler.
Expand All @@ -225,10 +249,6 @@ def save_internal_sampler(self, filename: str) -> None:
filename:
The internal sampler will be saved here.
"""
import dynesty

setup_dynesty()

dynesty.utils.save_sampler(
sampler=self.sampler,
fname=filename,
Expand All @@ -242,11 +262,11 @@ def restore_internal_sampler(self, filename: str) -> None:
filename:
The internal sampler will be saved here.
"""
import dynesty

setup_dynesty()

self.sampler = dynesty.utils.restore_sampler(fname=filename)
pool = self.sampler_args.get("pool", None)
self.sampler = dynesty.utils.restore_sampler(
fname=filename,
pool=pool,
)

def get_original_samples(self) -> McmcPtResult:
"""Get the samples into the fitting pypesto format.
Expand All @@ -255,7 +275,7 @@ def get_original_samples(self) -> McmcPtResult:
-------
The pyPESTO sample result.
"""
return get_original_dynesty_samples(sampler=self.sampler)
return get_original_dynesty_samples(sampler=self)

def get_samples(self) -> McmcPtResult:
"""Get MCMC-like samples into the fitting pypesto format.
Expand All @@ -264,25 +284,87 @@ def get_samples(self) -> McmcPtResult:
-------
The pyPESTO sample result.
"""
return get_mcmc_like_dynesty_samples(sampler=self.sampler)
return get_mcmc_like_dynesty_samples(sampler=self)


def _get_raw_results(
sampler: DynestySampler,
raw_results: dynesty.result.Results,
) -> dynesty.results.Results:
if (sampler is None) == (raw_results is None):
raise ValueError(
"Please supply exactly one of `sampler` or `raw_results`."
)

if raw_results is not None:
return raw_results

if not isinstance(sampler, DynestySampler):
raise ValueError(
"Please provide a pyPESTO `DynestySampler` if using "
"the `sampler` argument of this method."
)

return sampler.raw_results


def get_original_dynesty_samples(sampler) -> McmcPtResult:
def save_raw_results(sampler: DynestySampler, filename: str) -> None:
"""Save dynesty sampler results to file.

Restoring the dynesty sampler on a different computer than the one that
samples were generated is problematic (e.g. an AMICI model might get
compiled automatically). This method should avoid that, by only saving
the results.

Parameters
----------
sampler:
The pyPESTO `DynestySampler` object used during sampling.
filename:
The file where the results will be saved.
"""
raw_results = _get_raw_results(sampler=sampler, raw_results=None)
with open(filename, "wb") as f:
dynesty_pickle.dump(raw_results, f)


def load_raw_results(filename: str) -> dynesty.results.Results:
"""Load dynesty sample results from file.

Parameters
----------
filename:
The file where the results will be loaded from.
"""
with open(filename, "rb") as f:
raw_results = dynesty_pickle.load(f)
return raw_results


def get_original_dynesty_samples(
sampler: DynestySampler = None,
raw_results: dynesty.results.Results = None,
) -> McmcPtResult:
"""Get original dynesty samples.

Only one of `sampler` or `raw_results` should be provided.

Parameters
----------
sampler:
The (internal!) dynesty sampler. See
`pypesto.sample.DynestySampler.__init__`, specifically the
`save_internal` argument, for more details.
The pyPESTO `DynestySampler` object with sampling results.
raw_results:
The raw results. See :func:`save_raw_results` and
:func:`load_raw_results`.

Returns
-------
The sample result.
"""
trace_x = np.array([sampler.results.samples])
trace_neglogpost = -np.array([sampler.results.logl])
raw_results = _get_raw_results(sampler=sampler, raw_results=raw_results)

trace_x = np.array([raw_results.samples])
trace_neglogpost = -np.array([raw_results.logl])

# the sampler uses custom adaptive priors
trace_neglogprior = np.full(trace_neglogpost.shape, np.nan)
Expand All @@ -299,36 +381,40 @@ def get_original_dynesty_samples(sampler) -> McmcPtResult:
return result


def get_mcmc_like_dynesty_samples(sampler) -> McmcPtResult:
def get_mcmc_like_dynesty_samples(
sampler: DynestySampler = None,
raw_results: dynesty.results.Results = None,
) -> McmcPtResult:
"""Get MCMC-like samples.

Only one of `sampler` or `raw_results` should be provided.

Parameters
----------
sampler:
The (internal!) dynesty sampler. See
`pypesto.sample.DynestySampler.__init__`, specifically the
`save_internal` argument, for more details.
The pyPESTO `DynestySampler` object with sampling results.
raw_results:
The raw results. See :func:`save_raw_results` and
:func:`load_raw_results`.

Returns
-------
The sample result.
"""
import dynesty

setup_dynesty()
raw_results = _get_raw_results(sampler=sampler, raw_results=raw_results)

if len(sampler.results.importance_weights().shape) != 1:
if len(raw_results.importance_weights().shape) != 1:
raise ValueError(
"Unknown error. The dynesty importance weights are not a 1D array."
)
# resample according to importance weights
indices = dynesty.utils.resample_equal(
np.arange(sampler.results.importance_weights().shape[0]),
sampler.results.importance_weights(),
np.arange(raw_results.importance_weights().shape[0]),
raw_results.importance_weights(),
)

trace_x = np.array([sampler.results.samples[indices]])
trace_neglogpost = -np.array([sampler.results.logl[indices]])
trace_x = np.array([raw_results.samples[indices]])
trace_neglogpost = -np.array([raw_results.logl[indices]])

trace_neglogprior = np.array([np.full((len(indices),), np.nan)])
betas = np.array([1.0])
Expand All @@ -340,14 +426,3 @@ def get_mcmc_like_dynesty_samples(sampler) -> McmcPtResult:
betas=betas,
)
return result


def setup_dynesty() -> None:
"""Import dynesty."""
try:
import cloudpickle # noqa: S403
import dynesty.utils

dynesty.utils.pickle_module = cloudpickle
except ImportError:
raise SamplerImportError("dynesty") from None
Loading