diff --git a/arviz/stats/stats_refitting.py b/arviz/stats/stats_refitting.py index dbe3a5d49f..8475adfff5 100644 --- a/arviz/stats/stats_refitting.py +++ b/arviz/stats/stats_refitting.py @@ -13,6 +13,20 @@ _log = logging.getLogger(__name__) +def _get_scale_value(scale): + """Convert a scale string into its corresponding numeric factor.""" + scale = scale.lower() + mapping = { + "deviance": -2, + "log": 1, + "negative_log": -1, + } + try: + return mapping[scale] + except KeyError: + raise ValueError(f"Unsupported scale '{scale}'. Valid options are: {list(mapping.keys())}") + + def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True): """Recalculate exact Leave-One-Out cross validation refitting where the approximation fails. @@ -72,46 +86,52 @@ def reloo(wrapper, loo_orig=None, k_thresh=0.7, scale=None, verbose=True): Sampling wrappers are an experimental feature in a very early stage. Please use them with caution. """ + # Ensure the wrapper implements all required methods required_methods = ("sel_observations", "sample", "get_inference_data", "log_likelihood__i") - not_implemented = wrapper.check_implemented_methods(required_methods) - if not_implemented: + missing_methods = wrapper.check_implemented_methods(required_methods) + if missing_methods: raise TypeError( "Passed wrapper instance does not implement all methods required for reloo " - f"to work. Check the documentation of SamplingWrapper. {not_implemented} must be " - "implemented and were not found." + f"to work. Missing implementations: {missing_methods}" ) + + if loo_orig is None: loo_orig = loo(wrapper.idata_orig, pointwise=True, scale=scale) + loo_refitted = loo_orig.copy() khats = loo_refitted.pareto_k loo_i = loo_refitted.loo_i - scale = loo_orig.scale - - if scale.lower() == "deviance": - scale_value = -2 - elif scale.lower() == "log": - scale_value = 1 - elif scale.lower() == "negative_log": - scale_value = -1 + scale = loo_orig.scale + scale_value = _get_scale_value(scale) + + lppd_orig = loo_orig.p_loo + loo_orig.elpd_loo / scale_value n_data_points = loo_orig.n_data_points if verbose: warnings.warn("reloo is an experimental and untested feature", UserWarning) - if np.any(khats > k_thresh): - for idx in np.argwhere(khats.values > k_thresh): + #find the indices where the Pareto k exceeds the threshold + problematic_indices = np.flatnonzero(khats.values > k_thresh) + if problematic_indices.size: + for idx in problematic_indices: if verbose: _log.info("Refitting model excluding observation %d", idx) + #exclude the problematic observation and sample with the new dataset new_obs, excluded_obs = wrapper.sel_observations(idx) fit = wrapper.sample(new_obs) idata_idx = wrapper.get_inference_data(fit) - log_like_idx = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten() - loo_lppd_idx = scale_value * _logsumexp(log_like_idx, b_inv=len(log_like_idx)) + #compute exact log likelihood for the excluded observation + log_like = wrapper.log_likelihood__i(excluded_obs, idata_idx).values.flatten() + loo_exact = scale_value * _logsumexp(log_like, b_inv=len(log_like)) + #update the LOO results: set the Pareto k to 0 and replace loo_i with exact value khats[idx] = 0 - loo_i[idx] = loo_lppd_idx + loo_i[idx] = loo_exact + + #update the overall loo_refitted object with refitted values loo_refitted.elpd_loo = loo_i.values.sum() - loo_refitted.se = (n_data_points * np.var(loo_i.values)) ** 0.5 + loo_refitted.se = np.sqrt(n_data_points * np.var(loo_i.values)) loo_refitted.p_loo = lppd_orig - loo_refitted.elpd_loo / scale_value return loo_refitted else: diff --git a/arviz/wrappers/base.py b/arviz/wrappers/base.py index ad20ce12af..2d7e7f7c82 100644 --- a/arviz/wrappers/base.py +++ b/arviz/wrappers/base.py @@ -2,7 +2,6 @@ """Base class for sampling wrappers.""" from xarray import apply_ufunc -# from ..data import InferenceData from ..stats import wrap_xarray_ufunc as _wrap_xarray_ufunc @@ -13,7 +12,7 @@ class SamplingWrapper: functions requiring refitting like Leave Future Out or Simulation Based Calibration can be performed from ArviZ. - For usage examples see user guide pages on :ref:`wrapper_guide`.See other + For usage examples see user guide pages on :ref:`wrapper_guide`. See other SamplingWrapper classes at :ref:`wrappers api section `. Parameters @@ -25,7 +24,7 @@ class SamplingWrapper: log_lik_fun : callable, optional For simple cases where the pointwise log likelihood is a Python function, this function will be used to calculate the log likelihood. Otherwise, - ``point_log_likelihood`` method must be implemented. It's callback must be + ``point_log_likelihood`` method must be implemented. Its callback must be ``log_lik_fun(*args, **log_lik_kwargs)`` and will be called using :func:`wrap_xarray_ufunc` or :func:`xarray:xarray.apply_ufunc` depending on the value of `is_ufunc`. @@ -49,7 +48,6 @@ class SamplingWrapper: apply_ufunc_kwargs : dict, optional Passed to :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc`. - Warnings -------- Sampling wrappers are an experimental feature in a very early stage. Please use them @@ -73,9 +71,6 @@ def __init__( apply_ufunc_kwargs=None, ): self.model = model - - # if not isinstance(idata_orig, InferenceData) or idata_orig is not None: - # raise TypeError("idata_orig must be of InferenceData type or None") self.idata_orig = idata_orig if log_lik_fun is None or callable(log_lik_fun): @@ -85,16 +80,15 @@ def __init__( else: raise TypeError("log_like_fun must be a callable object or None") - self.sample_kwargs = {} if sample_kwargs is None else sample_kwargs - self.idata_kwargs = {} if idata_kwargs is None else idata_kwargs - self.log_lik_kwargs = {} if log_lik_kwargs is None else log_lik_kwargs - self.apply_ufunc_kwargs = {} if apply_ufunc_kwargs is None else apply_ufunc_kwargs + self.sample_kwargs = sample_kwargs or {} + self.idata_kwargs = idata_kwargs or {} + self.log_lik_kwargs = log_lik_kwargs or {} + self.apply_ufunc_kwargs = apply_ufunc_kwargs or {} def sel_observations(self, idx): """Select a subset of the observations in idata_orig. **Not implemented**: This method must be implemented by the SamplingWrapper subclasses. - It is documented here to show its format and call signature. Parameters ---------- @@ -114,7 +108,6 @@ def sample(self, modified_observed_data): """Sample ``self.model`` on the ``modified_observed_data`` subset. **Not implemented**: This method must be implemented by the SamplingWrapper subclasses. - It is documented here to show its format and call signature. Parameters ---------- @@ -132,7 +125,6 @@ def get_inference_data(self, fitted_model): """Convert the ``fitted_model`` to an InferenceData object. **Not implemented**: This method must be implemented by the SamplingWrapper subclasses. - It is documented here to show its format and call signature. Parameters ---------- @@ -147,7 +139,7 @@ def get_inference_data(self, fitted_model): raise NotImplementedError("get_inference_data method must be implemented for each subclass") def log_likelihood__i(self, excluded_obs, idata__i): - r"""Get the log likelilhood samples :math:`\log p_{post(-i)}(y_i)`. + r"""Get the log likelihood samples :math:`\log p_{post(-i)}(y_i)`. Calculate the log likelihood of the data contained in excluded_obs using the model fitted with this data excluded, the results of which are stored in ``idata__i``. @@ -163,74 +155,65 @@ def log_likelihood__i(self, excluded_obs, idata__i): Returns ------- - log_likelihood: xr.Dataarray + log_likelihood: xr.DataArray Log likelihood of ``excluded_obs`` evaluated at each of the posterior samples stored in ``idata__i``. """ if self.log_lik_fun is None: raise NotImplementedError( - "When `log_like_fun` is not set during class initialization " + "When `log_like_fun` is not set during class initialization, " "log_likelihood__i method must be overwritten" ) posterior = idata__i.posterior - arys = (*excluded_obs, *[posterior[var_name] for var_name in self.posterior_vars]) + args = (*excluded_obs, *[posterior[var_name] for var_name in self.posterior_vars]) ufunc_applier = apply_ufunc if self.is_ufunc else _wrap_xarray_ufunc - log_lik_idx = ufunc_applier( + return ufunc_applier( self.log_lik_fun, - *arys, + *args, kwargs=self.log_lik_kwargs, **self.apply_ufunc_kwargs, ) - return log_lik_idx def _check_method_is_implemented(self, method, *args): - """Check a given method is implemented.""" + """Check if a given method is implemented.""" try: getattr(self, method)(*args) except NotImplementedError: return False - except: # pylint: disable=bare-except + except Exception: return True return True def check_implemented_methods(self, methods): """Check that all methods listed are implemented. - Not all functions that require refitting need to have all the methods implemented in - order to work properly. This function shoulg be used before using the SamplingWrapper and - its subclasses to get informative error messages. - Parameters ---------- methods: list - Check all elements in methods are implemented. + List of method names to check. Returns ------- - List with all non implemented methods + list + List with all non-implemented methods. """ - supported_methods_1arg = ( - "sel_observations", - "sample", - "get_inference_data", - ) + supported_methods_1arg = ("sel_observations", "sample", "get_inference_data") supported_methods_2args = ("log_likelihood__i",) supported_methods = [*supported_methods_1arg, *supported_methods_2args] - bad_methods = [method for method in methods if method not in supported_methods] - if bad_methods: + + invalid_methods = [method for method in methods if method not in supported_methods] + if invalid_methods: raise ValueError( - f"Not all method(s) in {bad_methods} supported. " - f"Supported methods in SamplingWrapper subclasses are:{supported_methods}" + f"Not all method(s) in {invalid_methods} supported. " + f"Supported methods in SamplingWrapper subclasses are: {supported_methods}" ) not_implemented = [] for method in methods: if method in supported_methods_1arg: - if self._check_method_is_implemented(method, 1): - continue - not_implemented.append(method) + if not self._check_method_is_implemented(method, 1): + not_implemented.append(method) elif method in supported_methods_2args: - if self._check_method_is_implemented(method, 1, 1): - continue - not_implemented.append(method) + if not self._check_method_is_implemented(method, 1, 1): + not_implemented.append(method) return not_implemented