Skip to content

Refactored stats_refitting.py for better robustness and readability #2424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions arviz/stats/stats_refitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@
_log = logging.getLogger(__name__)


def _get_scale_value(scale):
"""Convert a scale string into its corresponding numeric factor."""
scale = scale.lower()
mapping = {
"deviance": -2,
"log": 1,
"negative_log": -1,
}
try:
return mapping[scale]
except KeyError:
raise ValueError(f"Unsupported scale '{scale}'. Valid options are: {list(mapping.keys())}")


def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):
"""Recalculate exact Leave-One-Out cross validation refitting where the approximation fails.

Expand Down Expand Up @@ -72,46 +86,52 @@ def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True):
Sampling wrappers are an experimental feature in a very early stage. Please use them
with caution.
"""
# Ensure the wrapper implements all required methods
required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i")
not_implemented = wrapper.check_implemented_methods(required_methods)
if not_implemented:
missing_methods = wrapper.check_implemented_methods(required_methods)
if missing_methods:
raise TypeError(
"Passed wrapper instance does not implement all methods required for reloo "
f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be "
"implemented and were not found."
f"to work. Missing implementations: {missing_methods}"
)


if loo_orig is None:
loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale)

loo_refitted = loo_orig.copy()
khats = loo_refitted.pareto_k
loo_i = loo_refitted.loo_i
scale = loo_orig.scale

if scale.lower() == "deviance":
scale_value = -2
elif scale.lower() == "log":
scale_value = 1
elif scale.lower() == "negative_log":
scale_value = -1
scale = loo_orig.scale
scale_value = _get_scale_value(scale)


lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value
n_data_points = loo_orig.n_data_points

if verbose:
warnings.warn("reloo is an experimental and untested feature", UserWarning)

if np.any(khats > k_thresh):
for idx in np.argwhere(khats.values > k_thresh):
#find the indices where the Pareto k exceeds the threshold
problematic_indices = np.flatnonzero(khats.values > k_thresh)
if problematic_indices.size:
for idx in problematic_indices:
if verbose:
_log.info("Refitting model excluding observation %d", idx)
#exclude the problematic observation and sample with the new dataset
new_obs, excluded_obs = wrapper.sel_observations(idx)
fit = wrapper.sample(new_obs)
idata_idx = wrapper.get_inference_data(fit)
log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx))
#compute exact log likelihood for the excluded observation
log_like = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten()
loo_exact = scale_value * _logsumexp(log_like, b_inv=len(log_like))
#update the LOO results: set the Pareto k to 0 and replace loo_i with exact value
khats[idx] = 0
loo_i[idx] = loo_lppd_idx
loo_i[idx] = loo_exact

#update the overall loo_refitted object with refitted values
loo_refitted.elpd_loo = loo_i.values.sum()
loo_refitted.se = (n_data_points * np.var(loo_i.values)) ** 0.5
loo_refitted.se = np.sqrt(n_data_points * np.var(loo_i.values))
loo_refitted.p_loo = lppd_orig - loo_refitted.elpd_loo / scale_value
return loo_refitted
else:
Expand Down
71 changes: 27 additions & 44 deletions arviz/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""Base class for sampling wrappers."""
from xarray import apply_ufunc

# from ..data import InferenceData
from ..stats import wrap_xarray_ufunc as _wrap_xarray_ufunc


Expand All @@ -13,7 +12,7 @@ class SamplingWrapper:
functions requiring refitting like Leave Future Out or Simulation Based Calibration can be
performed from ArviZ.

For usage examples see user guide pages on :ref:`wrapper_guide`.See other
For usage examples see user guide pages on :ref:`wrapper_guide`. See other
SamplingWrapper classes at :ref:`wrappers api section <wrappers_api>`.

Parameters
Expand All @@ -25,7 +24,7 @@ class SamplingWrapper:
log_lik_fun : callable, optional
For simple cases where the pointwise log likelihood is a Python function, this
function will be used to calculate the log likelihood. Otherwise,
``point_log_likelihood`` method must be implemented. It's callback must be
``point_log_likelihood`` method must be implemented. Its callback must be
``log_lik_fun(*args, **log_lik_kwargs)`` and will be called using
:func:`wrap_xarray_ufunc` or :func:`xarray:xarray.apply_ufunc` depending
on the value of `is_ufunc`.
Expand All @@ -49,7 +48,6 @@ class SamplingWrapper:
apply_ufunc_kwargs : dict, optional
Passed to :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc`.


Warnings
--------
Sampling wrappers are an experimental feature in a very early stage. Please use them
Expand All @@ -73,9 +71,6 @@ def __init__(
apply_ufunc_kwargs=None,
):
self.model = model

# if not isinstance(idata_orig, InferenceData) or idata_orig is not None:
# raise TypeError("idata_orig must be of InferenceData type or None")
self.idata_orig = idata_orig

if log_lik_fun is None or callable(log_lik_fun):
Expand All @@ -85,16 +80,15 @@ def __init__(
else:
raise TypeError("log_like_fun must be a callable object or None")

self.sample_kwargs = {} if sample_kwargs is None else sample_kwargs
self.idata_kwargs = {} if idata_kwargs is None else idata_kwargs
self.log_lik_kwargs = {} if log_lik_kwargs is None else log_lik_kwargs
self.apply_ufunc_kwargs = {} if apply_ufunc_kwargs is None else apply_ufunc_kwargs
self.sample_kwargs = sample_kwargs or {}
self.idata_kwargs = idata_kwargs or {}
self.log_lik_kwargs = log_lik_kwargs or {}
self.apply_ufunc_kwargs = apply_ufunc_kwargs or {}

def sel_observations(self, idx):
"""Select a subset of the observations in idata_orig.

**Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
It is documented here to show its format and call signature.

Parameters
----------
Expand All @@ -114,7 +108,6 @@ def sample(self, modified_observed_data):
"""Sample ``self.model`` on the ``modified_observed_data`` subset.

**Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
It is documented here to show its format and call signature.

Parameters
----------
Expand All @@ -132,7 +125,6 @@ def get_inference_data(self, fitted_model):
"""Convert the ``fitted_model`` to an InferenceData object.

**Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
It is documented here to show its format and call signature.

Parameters
----------
Expand All @@ -147,7 +139,7 @@ def get_inference_data(self, fitted_model):
raise NotImplementedError("get_inference_data method must be implemented for each subclass")

def log_likelihood__i(self, excluded_obs, idata__i):
r"""Get the log likelilhood samples :math:`\log p_{post(-i)}(y_i)`.
r"""Get the log likelihood samples :math:`\log p_{post(-i)}(y_i)`.

Calculate the log likelihood of the data contained in excluded_obs using the
model fitted with this data excluded, the results of which are stored in ``idata__i``.
Expand All @@ -163,74 +155,65 @@ def log_likelihood__i(self, excluded_obs, idata__i):

Returns
-------
log_likelihood: xr.Dataarray
log_likelihood: xr.DataArray
Log likelihood of ``excluded_obs`` evaluated at each of the posterior samples
stored in ``idata__i``.
"""
if self.log_lik_fun is None:
raise NotImplementedError(
"When `log_like_fun` is not set during class initialization "
"When `log_like_fun` is not set during class initialization, "
"log_likelihood__i method must be overwritten"
)
posterior = idata__i.posterior
arys = (*excluded_obs, *[posterior[var_name] for var_name in self.posterior_vars])
args = (*excluded_obs, *[posterior[var_name] for var_name in self.posterior_vars])
ufunc_applier = apply_ufunc if self.is_ufunc else _wrap_xarray_ufunc
log_lik_idx = ufunc_applier(
return ufunc_applier(
self.log_lik_fun,
*arys,
*args,
kwargs=self.log_lik_kwargs,
**self.apply_ufunc_kwargs,
)
return log_lik_idx

def _check_method_is_implemented(self, method, *args):
"""Check a given method is implemented."""
"""Check if a given method is implemented."""
try:
getattr(self, method)(*args)
except NotImplementedError:
return False
except: # pylint: disable=bare-except
except Exception:
return True
return True

def check_implemented_methods(self, methods):
"""Check that all methods listed are implemented.

Not all functions that require refitting need to have all the methods implemented in
order to work properly. This function should be used before using the SamplingWrapper and
its subclasses to get informative error messages.

Parameters
----------
methods: list
Check all elements in methods are implemented.
List of method names to check.

Returns
-------
List with all non implemented methods
list
List with all non-implemented methods.
"""
supported_methods_1arg = (
"sel_observations",
"sample",
"get_inference_data",
)
supported_methods_1arg = ("sel_observations", "sample", "get_inference_data")
supported_methods_2args = ("log_likelihood__i",)
supported_methods = [*supported_methods_1arg, *supported_methods_2args]
bad_methods = [method for method in methods if method not in supported_methods]
if bad_methods:

invalid_methods = [method for method in methods if method not in supported_methods]
if invalid_methods:
raise ValueError(
f"Not all method(s) in {bad_methods} supported. "
f"Supported methods in SamplingWrapper subclasses are:{supported_methods}"
f"Not all method(s) in {invalid_methods} supported. "
f"Supported methods in SamplingWrapper subclasses are: {supported_methods}"
)

not_implemented = []
for method in methods:
if method in supported_methods_1arg:
if self._check_method_is_implemented(method, 1):
continue
not_implemented.append(method)
if not self._check_method_is_implemented(method, 1):
not_implemented.append(method)
elif method in supported_methods_2args:
if self._check_method_is_implemented(method, 1, 1):
continue
not_implemented.append(method)
if not self._check_method_is_implemented(method, 1, 1):
not_implemented.append(method)
return not_implemented
Loading