diff --git a/docs/source/api/index.md b/docs/source/api/index.md index 829b6eb..0e9191f 100644 --- a/docs/source/api/index.md +++ b/docs/source/api/index.md @@ -29,6 +29,7 @@ arviz_base.from_cmdstanpy arviz_base.from_emcee + arviz_base.from_numpyro ``` More coming soon... diff --git a/external_tests/helpers.py b/external_tests/helpers.py index 2af48c2..936f3e5 100644 --- a/external_tests/helpers.py +++ b/external_tests/helpers.py @@ -191,7 +191,7 @@ def load_cached_models(eight_schools_data, draws, chains, libs=None): # ("pystan", pystan_noncentered_schools), ("emcee", emcee_schools_model), # ("pyro", pyro_noncentered_schools), - # ("numpyro", numpyro_schools_model), + ("numpyro", numpyro_schools_model), ) data_directory = os.path.join(here, "saved_models") if not os.path.isdir(data_directory): diff --git a/external_tests/test_numpyro.py b/external_tests/test_numpyro.py new file mode 100644 index 0000000..b166c2b --- /dev/null +++ b/external_tests/test_numpyro.py @@ -0,0 +1,275 @@ +# pylint: disable=no-member, invalid-name, redefined-outer-name +from collections import namedtuple + +import numpy as np +import pytest + +from arviz_base.io_numpyro import from_numpyro +from arviz_base.testing import check_multiple_attrs + +from .helpers import importorskip, load_cached_models + +# Skip all tests if jax or numpyro not installed +jax = importorskip("jax") +PRNGKey = jax.random.PRNGKey +numpyro = importorskip("numpyro") +Predictive = numpyro.infer.Predictive +numpyro.set_host_device_count(2) + + +class TestDataNumPyro: + @pytest.fixture(scope="class") + def data(self, eight_schools_params, draws, chains): + class Data: + obj = load_cached_models(eight_schools_params, draws, chains, "numpyro")["numpyro"] + + return Data + + @pytest.fixture(scope="class") + def predictions_params(self): + """Predictions data for eight schools.""" + return { + "J": 8, + "sigma": np.array([5.0, 7.0, 12.0, 4.0, 6.0, 10.0, 3.0, 9.0]), + } + + @pytest.fixture(scope="class") + def predictions_data(self, data, predictions_params): + """Generate predictions for predictions_params""" + posterior_samples = data.obj.get_samples() + model = data.obj.sampler.model + predictions = Predictive(model, posterior_samples)( + PRNGKey(2), predictions_params["J"], predictions_params["sigma"] + ) + return predictions + + def get_inference_data(self, data, eight_schools_params, predictions_data, predictions_params): + posterior_samples = data.obj.get_samples() + model = data.obj.sampler.model + posterior_predictive = Predictive(model, posterior_samples)( + PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"] + ) + prior = Predictive(model, num_samples=500)( + PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"] + ) + predictions = predictions_data + return from_numpyro( + posterior=data.obj, + prior=prior, + posterior_predictive=posterior_predictive, + predictions=predictions, + coords={ + "school": np.arange(eight_schools_params["J"]), + "school_pred": np.arange(predictions_params["J"]), + }, + dims={"theta": ["school"], "eta": ["school"], "obs": ["school"]}, + pred_dims={"theta": ["school_pred"], "eta": ["school_pred"], "obs": ["school_pred"]}, + ) + + def test_inference_data_namedtuple(self, data): + samples = data.obj.get_samples() + Samples = namedtuple("Samples", samples) + data_namedtuple = Samples(**samples) + _old_fn = data.obj.get_samples + data.obj.get_samples = lambda *args, **kwargs: data_namedtuple + inference_data = from_numpyro( + posterior=data.obj, + ) + assert isinstance(data.obj.get_samples(), Samples) + data.obj.get_samples = _old_fn + for key in samples: + assert key in inference_data.posterior + + def test_inference_data(self, data, eight_schools_params, predictions_data, predictions_params): + inference_data = self.get_inference_data( + data, eight_schools_params, predictions_data, predictions_params + ) + test_dict = { + "posterior": ["mu", "tau", "eta"], + "sample_stats": ["diverging"], + "posterior_predictive": ["obs"], + "predictions": ["obs"], + "prior": ["mu", "tau", "eta"], + "prior_predictive": ["obs"], + "observed_data": ["obs"], + } + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + + ## test dims + dims = inference_data.posterior_predictive.sizes["school"] + pred_dims = inference_data.predictions.sizes["school_pred"] + assert dims == 8 + assert pred_dims == 8 + + def test_inference_data_no_posterior( + self, data, eight_schools_params, predictions_data, predictions_params + ): + posterior_samples = data.obj.get_samples() + model = data.obj.sampler.model + posterior_predictive = Predictive(model, posterior_samples)( + PRNGKey(1), eight_schools_params["J"], eight_schools_params["sigma"] + ) + prior = Predictive(model, num_samples=500)( + PRNGKey(2), eight_schools_params["J"], eight_schools_params["sigma"] + ) + predictions = predictions_data + constant_data = {"J": 8, "sigma": eight_schools_params["sigma"]} + predictions_constant_data = predictions_params + ## only prior + inference_data = from_numpyro(prior=prior) + test_dict = {"prior": ["mu", "tau", "eta"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, f"only prior: {fails}" + ## only posterior_predictive + inference_data = from_numpyro(posterior_predictive=posterior_predictive) + test_dict = {"posterior_predictive": ["obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, f"only posterior_predictive: {fails}" + ## only predictions + inference_data = from_numpyro(predictions=predictions) + test_dict = {"predictions": ["obs"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, f"only predictions: {fails}" + ## only constant_data + inference_data = from_numpyro(constant_data=constant_data) + test_dict = {"constant_data": ["J", "sigma"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, f"only constant_data: {fails}" + ## only predictions_constant_data + inference_data = from_numpyro(predictions_constant_data=predictions_constant_data) + test_dict = {"predictions_constant_data": ["J", "sigma"]} + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails, f"only predictions_constant_data: {fails}" + prior and posterior_predictive + idata = from_numpyro( + prior=prior, + posterior_predictive=posterior_predictive, + coords={"school": np.arange(eight_schools_params["J"])}, + dims={"theta": ["school"], "eta": ["school"]}, + ) + test_dict = {"posterior_predictive": ["obs"], "prior": ["mu", "tau", "eta", "obs"]} + fails = check_multiple_attrs(test_dict, idata) + assert not fails, f"prior and posterior_predictive: {fails}" + + def test_inference_data_only_posterior(self, data): + idata = from_numpyro(data.obj) + test_dict = { + "posterior": ["mu", "tau", "eta"], + "sample_stats": ["diverging"], + } + fails = check_multiple_attrs(test_dict, idata) + assert not fails + + def test_multiple_observed_rv(self): + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC, NUTS + + rng = np.random.default_rng() + y1 = rng.normal(size=10) + y2 = rng.normal(size=100) + + def model_example_multiple_obs(y1=None, y2=None): + x = numpyro.sample("x", dist.Normal(1, 3)) + numpyro.sample("y1", dist.Normal(x, 1), obs=y1) + numpyro.sample("y2", dist.Normal(x, 1), obs=y2) + + nuts_kernel = NUTS(model_example_multiple_obs) + mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2) + mcmc.run(PRNGKey(0), y1=y1, y2=y2) + inference_data = from_numpyro(mcmc) + test_dict = { + "posterior": ["x"], + "sample_stats": ["diverging"], + "observed_data": ["y1", "y2"], + } + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + assert not hasattr(inference_data.sample_stats, "log_likelihood") + + def test_inference_data_constant_data(self): + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC, NUTS + + x1 = 10 + x2 = 12 + rng = np.random.default_rng() + y1 = rng.normal(size=10) + + def model_constant_data(x, y1=None): + _x = numpyro.sample("x", dist.Normal(1, 3)) + numpyro.sample("y1", dist.Normal(x * _x, 1), obs=y1) + + nuts_kernel = NUTS(model_constant_data) + mcmc = MCMC(nuts_kernel, num_samples=10, num_warmup=2) + mcmc.run(PRNGKey(0), x=x1, y1=y1) + posterior = mcmc.get_samples() + posterior_predictive = Predictive(model_constant_data, posterior)(PRNGKey(1), x1) + predictions = Predictive(model_constant_data, posterior)(PRNGKey(2), x2) + inference_data = from_numpyro( + mcmc, + posterior_predictive=posterior_predictive, + predictions=predictions, + constant_data={"x1": x1}, + predictions_constant_data={"x2": x2}, + ) + test_dict = { + "posterior": ["x"], + "posterior_predictive": ["y1"], + "sample_stats": ["diverging"], + "predictions": ["y1"], + "observed_data": ["y1"], + "constant_data": ["x1"], + "predictions_constant_data": ["x2"], + } + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + + def test_inference_data_num_chains(self, predictions_data, chains): + predictions = predictions_data + inference_data = from_numpyro(predictions=predictions, num_chains=chains) + nchains = inference_data.predictions.sizes["chain"] + assert nchains == chains + + @pytest.mark.parametrize("nchains", [1, 2]) + @pytest.mark.parametrize("thin", [1, 2, 3, 5, 10]) + def test_mcmc_with_thinning(self, nchains, thin): + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC, NUTS + + rng = np.random.default_rng() + x = rng.normal(10, 3, size=100) + + def model(x): + numpyro.sample( + "x", + dist.Normal( + numpyro.sample("loc", dist.Uniform(0, 20)), + numpyro.sample("scale", dist.Uniform(0, 20)), + ), + obs=x, + ) + + nuts_kernel = NUTS(model) + mcmc = MCMC(nuts_kernel, num_warmup=100, num_samples=400, num_chains=nchains, thinning=thin) + mcmc.run(PRNGKey(0), x=x) + + inference_data = from_numpyro(mcmc) + assert inference_data.posterior["loc"].shape == (nchains, 400 // thin) + + def test_mcmc_improper_uniform(self): + import numpyro + import numpyro.distributions as dist + from numpyro.infer import MCMC, NUTS + + def model(): + x = numpyro.sample("x", dist.ImproperUniform(dist.constraints.positive, (), ())) + return numpyro.sample("y", dist.Normal(x, 1), obs=1.0) + + mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10) + mcmc.run(PRNGKey(0)) + inference_data = from_numpyro(mcmc) + assert inference_data.observed_data diff --git a/src/arviz_base/__init__.py b/src/arviz_base/__init__.py index 6e93b44..902719c 100644 --- a/src/arviz_base/__init__.py +++ b/src/arviz_base/__init__.py @@ -12,6 +12,7 @@ from arviz_base.io_cmdstanpy import from_cmdstanpy from arviz_base.io_dict import from_dict from arviz_base.io_emcee import from_emcee +from arviz_base.io_numpyro import from_numpyro from arviz_base.rcparams import rc_context, rcParams from arviz_base.reorg import extract, dataset_to_dataarray, dataset_to_dataframe from arviz_base.sel_utils import * diff --git a/src/arviz_base/io_numpyro.py b/src/arviz_base/io_numpyro.py new file mode 100644 index 0000000..bd152e4 --- /dev/null +++ b/src/arviz_base/io_numpyro.py @@ -0,0 +1,385 @@ +"""NumPyro-specific conversion code.""" + +import warnings + +import numpy as np +from xarray import DataTree + +from arviz_base.base import dict_to_dataset, requires +from arviz_base.rcparams import rc_context, rcParams +from arviz_base.utils import expand_dims + + +class NumPyroConverter: + """Encapsulate NumPyro specific logic.""" + + # pylint: disable=too-many-instance-attributes + + model = None + nchains = None + ndraws = None + + def __init__( + self, + *, + posterior=None, + prior=None, + posterior_predictive=None, + predictions=None, + constant_data=None, + predictions_constant_data=None, + log_likelihood=None, + index_origin=None, + coords=None, + dims=None, + pred_dims=None, + num_chains=1, + ): + """Convert NumPyro data into an InferenceData object. + + Parameters + ---------- + posterior : numpyro.mcmc.MCMC + Fitted MCMC object from NumPyro + prior: dict + Prior samples from a NumPyro model + posterior_predictive : dict + Posterior predictive samples for the posterior + predictions: dict + Out of sample predictions + constant_data: dict + Dictionary containing constant data variables mapped to their values. + predictions_constant_data: dict + Constant data used for out-of-sample predictions. + index_origin : int, optional + coords : dict[str] -> list[str] + Map of dimensions to coordinates + dims : dict[str] -> list[str] + Map variable names to their coordinates + pred_dims: dict + Dims for predictions data. Map variable names to their coordinates. + num_chains: int + Number of chains used for sampling. Ignored if posterior is present. + """ + import jax + import numpyro + + self.posterior = posterior + self.prior = jax.device_get(prior) + self.posterior_predictive = jax.device_get(posterior_predictive) + self.predictions = predictions + self.constant_data = constant_data + self.predictions_constant_data = predictions_constant_data + self.log_likelihood = ( + rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood + ) + self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin + self.coords = coords + self.dims = dims + self.pred_dims = pred_dims + self.numpyro = numpyro + + def arbitrary_element(dct): + return next(iter(dct.values())) + + if posterior is not None: + samples = jax.device_get(self.posterior.get_samples(group_by_chain=True)) + if hasattr(samples, "_asdict"): + # In case it is easy to convert to a dictionary, as in the case of namedtuples + samples = {k: expand_dims(v) for k, v in samples._asdict().items()} + if not isinstance(samples, dict): + # handle the case we run MCMC with a general potential_fn + # (instead of a NumPyro model) whose args is not a dictionary + # (e.g. f(x) = x ** 2) + tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0] + samples = { + f"Param:{i}": jax.device_get(v) for i, v in enumerate(tree_flatten_samples) + } + self._samples = samples + self.nchains, self.ndraws = ( + posterior.num_chains, + posterior.num_samples // posterior.thinning, + ) + self.model = self.posterior.sampler.model + # model arguments and keyword arguments + self._args = self.posterior._args # pylint: disable=protected-access + self._kwargs = self.posterior._kwargs # pylint: disable=protected-access + else: + self.nchains = num_chains + get_from = None + if predictions is not None: + get_from = predictions + elif posterior_predictive is not None: + get_from = posterior_predictive + elif prior is not None: + get_from = prior + if get_from is None and constant_data is None and predictions_constant_data is None: + raise ValueError( + "When constructing InferenceData must have at least" + " one of posterior, prior, posterior_predictive or predictions." + ) + if get_from is not None: + aelem = arbitrary_element(get_from) + self.ndraws = aelem.shape[0] // self.nchains + + observations = {} + if self.model is not None: + # we need to use an init strategy to generate random samples for ImproperUniform sites + seeded_model = numpyro.handlers.substitute( + numpyro.handlers.seed(self.model, jax.random.PRNGKey(0)), + substitute_fn=numpyro.infer.init_to_sample, + ) + trace = numpyro.handlers.trace(seeded_model).get_trace(*self._args, **self._kwargs) + observations = { + name: site["value"] + for name, site in trace.items() + if site["type"] == "sample" and site["is_observed"] + } + self.observations = observations if observations else None + + @requires("posterior") + def posterior_to_xarray(self): + """Convert the posterior to an xarray dataset.""" + data = self._samples + return dict_to_dataset( + data, + inference_library=self.numpyro, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ) + + @requires("posterior") + def sample_stats_to_xarray(self): + """Extract sample_stats from NumPyro posterior.""" + rename_key = { + "potential_energy": "lp", + "adapt_state.step_size": "step_size", + "num_steps": "n_steps", + "accept_prob": "acceptance_rate", + } + data = {} + for stat, value in self.posterior.get_extra_fields(group_by_chain=True).items(): + if isinstance(value, dict | tuple): + continue + name = rename_key.get(stat, stat) + value_cp = value.copy() + data[name] = value_cp + if stat == "num_steps": + data["tree_depth"] = np.log2(value_cp).astype(int) + 1 + + return dict_to_dataset( + data, + inference_library=self.numpyro, + dims=None, + coords=self.coords, + index_origin=self.index_origin, + ) + + @requires("posterior") + @requires("model") + def log_likelihood_to_xarray(self): + """Extract log likelihood from NumPyro posterior.""" + if not self.log_likelihood: + return None + data = {} + if self.observations is not None: + samples = self.posterior.get_samples(group_by_chain=False) + if hasattr(samples, "_asdict"): + samples = samples._asdict() + log_likelihood_dict = self.numpyro.infer.log_likelihood( + self.model, samples, *self._args, **self._kwargs + ) + for obs_name, log_like in log_likelihood_dict.items(): + shape = (self.nchains, self.ndraws) + log_like.shape[1:] + data[obs_name] = np.reshape(np.asarray(log_like), shape) + return dict_to_dataset( + data, + inference_library=self.numpyro, + dims=self.dims, + coords=self.coords, + index_origin=self.index_origin, + skip_event_dims=True, + ) + + def translate_posterior_predictive_dict_to_xarray(self, dct, dims): + """Convert posterior_predictive or prediction samples to xarray.""" + data = {} + for k, ary in dct.items(): + shape = ary.shape + if shape[0] == self.nchains and shape[1] == self.ndraws: + data[k] = ary + elif shape[0] == self.nchains * self.ndraws: + data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:])) + else: + data[k] = expand_dims(ary) + warnings.warn( + "posterior predictive shape not compatible with number of chains and draws. " + "This can mean that some draws or even whole chains are not represented." + ) + return dict_to_dataset( + data, + inference_library=self.numpyro, + coords=self.coords, + dims=dims, + index_origin=self.index_origin, + ) + + @requires("posterior_predictive") + def posterior_predictive_to_xarray(self): + """Convert posterior_predictive samples to xarray.""" + return self.translate_posterior_predictive_dict_to_xarray( + self.posterior_predictive, self.dims + ) + + @requires("predictions") + def predictions_to_xarray(self): + """Convert predictions to xarray.""" + return self.translate_posterior_predictive_dict_to_xarray(self.predictions, self.pred_dims) + + def priors_to_xarray(self): + """Convert prior samples (and if possible prior predictive too) to xarray.""" + if self.prior is None: + return {"prior": None, "prior_predictive": None} + if self.posterior is not None: + prior_vars = list(self._samples.keys()) + prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars] + else: + prior_vars = self.prior.keys() + prior_predictive_vars = None + priors_dict = { + group: ( + None + if var_names is None + else dict_to_dataset( + {k: expand_dims(self.prior[k]) for k in var_names}, + inference_library=self.numpyro, + coords=self.coords, + dims=self.dims, + index_origin=self.index_origin, + ) + ) + for group, var_names in zip( + ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars) + ) + } + return priors_dict + + @requires("observations") + @requires("model") + def observed_data_to_xarray(self): + """Convert observed data to xarray.""" + return dict_to_dataset( + self.observations, + inference_library=self.numpyro, + dims=self.dims, + coords=self.coords, + sample_dims=[], + index_origin=self.index_origin, + ) + + @requires("constant_data") + def constant_data_to_xarray(self): + """Convert constant_data to xarray.""" + return dict_to_dataset( + self.constant_data, + inference_library=self.numpyro, + dims=self.dims, + coords=self.coords, + sample_dims=[], + index_origin=self.index_origin, + ) + + @requires("predictions_constant_data") + def predictions_constant_data_to_xarray(self): + """Convert predictions_constant_data to xarray.""" + return dict_to_dataset( + self.predictions_constant_data, + inference_library=self.numpyro, + dims=self.pred_dims, + coords=self.coords, + sample_dims=[], + index_origin=self.index_origin, + ) + + def to_datatree(self): + """Convert all available data to an InferenceData object. + + Note that if groups can not be created (i.e., there is no `trace`, so + the `posterior` and `sample_stats` can not be extracted), then the InferenceData + will not have those groups. + """ + dicto = { + "posterior": self.posterior_to_xarray(), + "sample_stats": self.sample_stats_to_xarray(), + "log_likelihood": self.log_likelihood_to_xarray(), + "posterior_predictive": self.posterior_predictive_to_xarray(), + "predictions": self.predictions_to_xarray(), + **self.priors_to_xarray(), + "observed_data": self.observed_data_to_xarray(), + "constant_data": self.constant_data_to_xarray(), + "predictions_constant_data": self.predictions_constant_data_to_xarray(), + } + + return DataTree.from_dict({group: ds for group, ds in dicto.items() if ds is not None}) + + +def from_numpyro( + posterior=None, + *, + prior=None, + posterior_predictive=None, + predictions=None, + constant_data=None, + predictions_constant_data=None, + log_likelihood=None, + index_origin=None, + coords=None, + dims=None, + pred_dims=None, + num_chains=1, +): + """Convert NumPyro data into an InferenceData object. + + For a usage example read the + :ref:`Creating InferenceData section on from_numpyro ` + + Parameters + ---------- + posterior : numpyro.mcmc.MCMC + Fitted MCMC object from NumPyro + prior: dict + Prior samples from a NumPyro model + posterior_predictive : dict + Posterior predictive samples for the posterior + predictions: dict + Out of sample predictions + constant_data: dict + Dictionary containing constant data variables mapped to their values. + predictions_constant_data: dict + Constant data used for out-of-sample predictions. + index_origin : int, optional + coords : dict[str] -> list[str] + Map of dimensions to coordinates + dims : dict[str] -> list[str] + Map variable names to their coordinates + pred_dims: dict + Dims for predictions data. Map variable names to their coordinates. + num_chains: int + Number of chains used for sampling. Ignored if posterior is present. + """ + with rc_context(rc={"data.sample_dims": ["chain", "draw"]}): + return NumPyroConverter( + posterior=posterior, + prior=prior, + posterior_predictive=posterior_predictive, + predictions=predictions, + constant_data=constant_data, + predictions_constant_data=predictions_constant_data, + log_likelihood=log_likelihood, + index_origin=index_origin, + coords=coords, + dims=dims, + pred_dims=pred_dims, + num_chains=num_chains, + ).to_datatree() diff --git a/src/arviz_base/utils.py b/src/arviz_base/utils.py index 65cd211..9890a16 100644 --- a/src/arviz_base/utils.py +++ b/src/arviz_base/utils.py @@ -182,3 +182,11 @@ def _get_coords(data, coords): except KeyError as err: raise KeyError(f"Error in data[{idx}]: {err}") from err return data_subset + + +def expand_dims(x): + """Jitting numpy expand_dims.""" + if not isinstance(x, np.ndarray): + return np.expand_dims(x, 0) + shape = x.shape + return x.reshape(shape[:0] + (1,) + shape[0:])