-
-
Notifications
You must be signed in to change notification settings - Fork 435
Add loo_expectation #2301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add loo_expectation #2301
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -44,6 +44,7 @@ | |||||
"compare", | ||||||
"hdi", | ||||||
"loo", | ||||||
"loo_expectation", | ||||||
"loo_pit", | ||||||
"psislw", | ||||||
"r2_samples", | ||||||
|
@@ -865,6 +866,60 @@ def loo(data, pointwise=None, var_name=None, reff=None, scale=None): | |||||
], | ||||||
) | ||||||
|
||||||
def loo_expectation(data, values, pointwise=None, reff=None, **kwargs): | ||||||
""" | ||||||
Computes the expectation of values with respect to the leave-one-out posteriors using PSIS. | ||||||
Parameters | ||||||
---------- | ||||||
data: obj | ||||||
Any object that can be converted to an :class:`arviz.InferenceData` object. | ||||||
Refer to documentation of :func:`arviz.convert_to_dataset` for details. | ||||||
values: ndarray | ||||||
A vector of quantities to compute expectations for. | ||||||
pointwise: bool, optional | ||||||
If True the pointwise predictive accuracy will be returned. Defaults to | ||||||
``stats.ic_pointwise`` rcParam. | ||||||
Comment on lines
+879
to
+881
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The `pointwise` parameter can be removed; it is not used anywhere. |
||||||
reff: float, optional | ||||||
Relative MCMC efficiency, ``ess / n`` i.e. number of effective samples | ||||||
divided by the number of actual samples. Computed from trace by default. | ||||||
**kwargs: | ||||||
Additional keyword arguments to pass to the `psislw` function. | ||||||
Comment on lines
+885
to
+886
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `psislw` only takes two arguments, both of which are already provided explicitly, so passing any kwargs here would raise an unrecognized-keyword error when calling `psislw`. |
||||||
Returns | ||||||
------- | ||||||
expectation: float | ||||||
The computed expectation of `values` across LOO posteriors. | ||||||
Comment on lines
+889
to
+890
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should also have an indication of the expected output shape. From the examples in https://mc-stan.org/loo/reference/E_loo.html it looks like it should have the shape of the pointwise log likelihood values minus the `chain` and `draw` dimensions. |
||||||
""" | ||||||
inference_data = convert_to_inference_data(data) | ||||||
log_likelihood = _get_log_likelihood(inference_data) | ||||||
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise | ||||||
log_likelihood = log_likelihood.stack(__sample__=("chain", "draw")) | ||||||
shape = log_likelihood.shape | ||||||
n_samples = shape[-1] | ||||||
|
||||||
if reff is None: | ||||||
if not hasattr(inference_data, "posterior"): | ||||||
raise TypeError("Must be able to extract a posterior group from data.") | ||||||
posterior = inference_data.posterior | ||||||
n_chains = len(posterior.chain) | ||||||
if n_chains == 1: | ||||||
reff = 1.0 | ||||||
else: | ||||||
ess_p = ess(posterior, method="mean") | ||||||
# this mean is over all data variables | ||||||
reff = ( | ||||||
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() | ||||||
/ n_samples | ||||||
) | ||||||
|
||||||
log_weights, _ = psislw(-log_likelihood, reff=reff, **kwargs) | ||||||
|
||||||
# Numerically stable Weighted sum | ||||||
# Do computations in the log-space for numerical stability | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Right before that I would add a check for DataArrays (the preferred input type) to see if they have `chain` and `draw` dimensions and, if so, stack them into `__sample__`. |
||||||
w_exp = log_weights + np.log(np.abs(values)) | ||||||
_expectation = (np.sign(values) * np.exp(w_exp)).sum() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change: rename `_expectation` to `expectation`.
The variable is only defined within the scope of the function, so there is no need to add an underscore to the name. |
||||||
|
||||||
return _expectation | ||||||
|
||||||
|
||||||
def psislw(log_weights, reff=1.0): | ||||||
""" | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
waic, | ||
weight_predictions, | ||
_calculate_ics, | ||
loo_expectation, | ||
) | ||
from ...stats.stats import _gpinv | ||
from ...stats.stats_utils import get_log_likelihood | ||
|
@@ -538,6 +539,13 @@ def test_loo_warning(centered_eight): | |
assert loo(centered_eight, pointwise=True) is not None | ||
assert any("Estimated shape parameter" in str(record.message) for record in records) | ||
|
||
@pytest.mark.parametrize("reff", [None, 0.5, 1.0]) | ||
def test_loo_expectation(centered_eight, reff): | ||
log_likelihood = get_log_likelihood(centered_eight) | ||
log_likelihood = log_likelihood.stack(__sample__=("chain", "draw")) | ||
values = np.arange(1, log_likelihood.shape[-1] + 1) | ||
Comment on lines
+544
to
+546
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think `values` should be the log likelihood directly, with no extra processing and without using its shape to create different objects. |
||
expectation = loo_expectation(centered_eight, values, pointwise=None, reff=reff) | ||
assert expectation is not None | ||
|
||
@pytest.mark.parametrize("scale", ["log", "negative_log", "deviance"]) | ||
def test_loo_print(centered_eight, scale): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add information about the shape of `values` here. My understanding is that it should be an array/DataArray with the same shape as the pointwise log likelihood (e.g. `chain, draw, obs_id`), which in general won't work here.
I think there should be an extra check for when the input is a DataArray so that the `chain, draw` dimensions get stacked into a single `__sample__` one; otherwise a call that passes such a DataArray as `values` (the original example snippet is missing from this capture) would not work as is.