Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added plot_bf() function for bayes_factor in arviz-plots #158

Merged
merged 1 commit into from
Mar 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
.. autosummary::
:toctree: generated/

plot_bf
plot_compare
plot_convergence_dist
plot_dist
Expand Down
23 changes: 23 additions & 0 deletions docs/source/gallery/model_comparison/plot_bf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
# Bayes_factor
Compute Bayes factor using Savage–Dickey ratio
---
:::{seealso}
API Documentation: {func}`~arviz_plots.plot_bf`
:::
"""
from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

data = load_arviz_data("centered_eight")

pc = azp.plot_bf(
data,
backend="none",
var_name="mu"
)

pc.show()
27 changes: 24 additions & 3 deletions src/arviz_plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,16 @@ def map(
aux_artist = np.squeeze(aux_artist)
self.viz[var_name][fun_label].loc[sel] = aux_artist

def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=None, **kwargs):
def add_legend(
self,
dim,
var_name=None,
aes=None,
artist_kwargs=None,
title=None,
text_only=False,
**kwargs,
):
"""Add a legend for the given artist/aesthetic to the plot.

Warnings
Expand All @@ -1073,6 +1082,8 @@ def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=Non
generate the miniatures in the legend.
title : str, optional
Legend title. Defaults to `dim`.
text_only : bool, optional
If True, creates a text-only legend without graphical markers.
**kwargs : mapping, optional
Keyword arguments passed to the backend function that generates the legend.

Expand All @@ -1098,18 +1109,28 @@ def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=Non
if isinstance(aes, str):
aes = [aes]
aes_ds = aes_ds[aes]

label_list = aes_ds[dim].values
kwarg_list = [
{k: v.item() for k, v in aes_ds.sel({dim: coord}).items()} for coord in label_list
]

for kwarg_dict in kwarg_list:
kwarg_dict.pop("overlay", None)
if text_only:
kwarg_dict.pop("color", None)

plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")

legend_title = None if text_only else title

return plot_bknd.legend(
self.viz["chart"].item(),
kwarg_list,
label_list,
title=title,
artist_kwargs=artist_kwargs,
title=legend_title,
artist_kwargs={"linestyle": "none", "linewidth": 0, "color": "none"}
if text_only
else artist_kwargs,
**kwargs,
)
2 changes: 2 additions & 0 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Batteries-included ArviZ plots."""

from .bfplot import plot_bf
from .compareplot import plot_compare
from .convergencedistplot import plot_convergence_dist
from .distplot import plot_dist
Expand All @@ -17,6 +18,7 @@
from .traceplot import plot_trace

__all__ = [
"plot_bf",
"plot_compare",
"plot_convergence_dist",
"plot_dist",
Expand Down
210 changes: 210 additions & 0 deletions src/arviz_plots/plots/bfplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Contain functions for Bayes Factor plotting."""

from copy import copy
from importlib import import_module

import xarray as xr
from arviz_base import extract, rcParams
from arviz_stats.bayes_factor import bayes_factor

from arviz_plots.plot_collection import PlotCollection
from arviz_plots.plots.distplot import plot_dist
from arviz_plots.plots.utils import filter_aes
from arviz_plots.visuals import vline


def plot_bf(
    dt,
    var_name,
    ref_val=0,
    kind=None,
    sample_dims=None,
    plot_collection=None,
    backend=None,
    labeller=None,
    aes_map=None,
    plot_kwargs=None,
    stats_kwargs=None,
    pc_kwargs=None,
):
    r"""Approximated Bayes Factor for comparing hypothesis of two nested models.

    The Bayes factor is estimated by comparing a model (H1) against a model
    in which the parameter of interest has been restricted to be a point-null (H0).
    This computation assumes the models are nested and thus H0 is a special case of H1.

    The prior and posterior marginal densities of `var_name` are drawn with
    :func:`~arviz_plots.plot_dist`, a vertical line marks `ref_val` and, for the
    matplotlib backend, the "BF10" value is shown as a text-only legend.

    Parameters
    ----------
    dt : DataTree
        Input data. Both its "prior" and "posterior" groups are read.
    var_name : str
        Name of the variable for which the Bayes factor is computed and plotted.
    ref_val : int or float, default 0
        Reference (point-null) value for Bayes factor estimation.
        Passing ``False`` skips drawing the reference line; the value is still
        forwarded as is to ``bayes_factor``.
    kind : {"kde", "hist", "dot", "ecdf"}, optional
        How to represent the marginal density.
        Defaults to ``rcParams["plot.density_kind"]``
    sample_dims : str or sequence of hashable, optional
        Dimensions to reduce unless mapped to an aesthetic.
        Defaults to ``rcParams["data.sample_dims"]``
    plot_collection : PlotCollection, optional
        Existing collection to draw on. When omitted, a new one is created with
        :meth:`arviz_plots.PlotCollection.grid` and "color" is mapped to the
        "Groups" dimension (prior/posterior) unless `pc_kwargs` says otherwise.
    backend : str, optional
        Plotting backend. Defaults to ``plot_collection.backend`` when a
        `plot_collection` is given and to ``rcParams["plot.backend"]`` otherwise.
        Legends are only added when the backend is "matplotlib".
    labeller : labeller, optional
        Forwarded to :func:`~arviz_plots.plot_dist`.
    aes_map : mapping of {str : sequence of str}, optional
        Mapping of artists to aesthetics that should use their mapping in
        `plot_collection` when plotted. In this function only the "ref_line"
        key is consulted; the mapping is not forwarded to ``plot_dist``.
    plot_kwargs : mapping of {str : mapping or False}, optional
        Forwarded to :func:`~arviz_plots.plot_dist`. Valid keys are:

        * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument.

          * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy`
          * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line`
          * "hist" -> passed to :func:`~arviz_plots.visuals.hist`

        * credible_interval -> passed to :func:`~arviz_plots.visuals.line_x`.
          Defaults to ``False`` here.
        * point_estimate -> passed to :func:`~arviz_plots.visuals.scatter_x`.
          Defaults to ``False`` here.
        * point_estimate_text -> passed to
          :func:`~arviz_plots.visuals.point_estimate_text`. Defaults to ``False`` here.
        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function
        * ref_line -> passed to :func:`~arviz_plots.visuals.vline`. Can't be ``False``.

    stats_kwargs : mapping, optional
        Forwarded to :func:`~arviz_plots.plot_dist`. Valid keys are:

        * density -> passed to kde, ecdf, ...
        * credible_interval -> passed to eti or hdi
        * point_estimate -> passed to mean, median or mode

    pc_kwargs : mapping, optional
        Passed to :meth:`arviz_plots.PlotCollection.grid`.
        Only used when `plot_collection` is not provided.

    Returns
    -------
    PlotCollection

    Raises
    ------
    ValueError
        If ``plot_kwargs["ref_line"]`` is ``False``.
    """
    if kind is None:
        kind = rcParams["plot.density_kind"]

    # Work on shallow copies so the caller's mappings are never mutated by the
    # ``setdefault`` calls and key assignments below.
    plot_kwargs = {} if plot_kwargs is None else plot_kwargs.copy()
    pc_kwargs = {} if pc_kwargs is None else pc_kwargs.copy()
    aes_map = {} if aes_map is None else aes_map.copy()

    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, str):
        sample_dims = [sample_dims]
    sample_dims = list(sample_dims)

    ref_line_kwargs = copy(plot_kwargs.get("ref_line", {}))
    if ref_line_kwargs is False:
        raise ValueError(
            "plot_kwargs['ref_line'] can't be False, use ref_val=False to remove this element"
        )

    if backend is None:
        if plot_collection is None:
            backend = rcParams["plot.backend"]
        else:
            backend = plot_collection.backend
    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

    # NOTE(review): `sample_dims` is not forwarded to `extract`, so the extraction
    # relies on its defaults -- confirm this is intended for custom sample dims.
    ds_prior = extract(dt, group="prior", var_names=var_name, keep_dataset=True)
    ds_posterior = extract(dt, group="posterior", var_names=var_name, keep_dataset=True)

    # Stack both groups along a new "Groups" dimension so they can be mapped
    # to aesthetics (color by default) and labelled in the legend.
    distribution = xr.concat([ds_prior, ds_posterior], dim="Groups").assign_coords(
        {"Groups": ["prior", "posterior"]}
    )

    if len(sample_dims) > 1:
        # With several sample dims the extracted data is expected to carry them
        # stacked into a single "sample" dimension (presumably done by `extract`).
        sample_dims = ["sample"]

    # Compute Bayes Factor using the bayes_factor function;
    # the densities at the reference value are requested but not needed here.
    bf, _ = bayes_factor(dt, var_name, ref_val, return_ref_vals=True)

    # Single-element array whose value is the text displayed in the BF legend.
    bf_values = xr.DataArray(
        [f"BF10: {bf['BF10']:.2f}"],
        dims=["BF_type"],
        coords={"BF_type": ["BF10"]},
    )

    if plot_collection is None:
        pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()

        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
        pc_kwargs["aes"].setdefault("color", ["Groups"])

        figsize = pc_kwargs["plot_grid_kws"].get("figsize", None)
        figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches")
        if figsize is None:
            # No user figsize: let the backend pick one for a 1x1 grid.
            figsize = plot_bknd.scale_fig_size(
                figsize,
                rows=1,
                cols=1,
                figsize_units=figsize_units,
            )
            figsize_units = "dots"
        pc_kwargs["plot_grid_kws"]["figsize"] = figsize
        pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units

        plot_collection = PlotCollection.grid(
            distribution,
            backend=backend,
            **pc_kwargs,
        )

    # Registered as an aesthetic so `add_legend("BF_type", ...)` below can find it.
    plot_collection.aes["BF_type"] = bf_values

    # Only the densities are of interest by default; users can re-enable these
    # elements through `plot_kwargs`.
    plot_kwargs.setdefault("credible_interval", False)
    plot_kwargs.setdefault("point_estimate", False)
    plot_kwargs.setdefault("point_estimate_text", False)

    plot_collection = plot_dist(
        distribution,
        var_names=var_name,
        group=None,
        coords=None,
        sample_dims=sample_dims,
        kind=kind,
        point_estimate=None,
        ci_kind=None,
        ci_prob=None,
        plot_collection=plot_collection,
        backend=backend,
        labeller=labeller,
        plot_kwargs=plot_kwargs,
        stats_kwargs=stats_kwargs,
        pc_kwargs=pc_kwargs,
    )

    if ref_val is not False:
        ref_dt = xr.Dataset({var_name: xr.DataArray([ref_val])})
        _, ref_aes, ref_ignore = filter_aes(plot_collection, aes_map, "ref_line", sample_dims)
        # Fixed defaults apply only to properties not driven by an aesthetic mapping.
        if "color" not in ref_aes:
            ref_line_kwargs.setdefault("color", "black")
        if "linestyle" not in ref_aes:
            # Second entry of the backend's default linestyles.
            default_linestyle = plot_bknd.get_default_aes("linestyle", 2, {})[1]
            ref_line_kwargs.setdefault("linestyle", default_linestyle)
        if "alpha" not in ref_aes:
            ref_line_kwargs.setdefault("alpha", 0.5)

        plot_collection.map(
            vline, "ref_line", data=ref_dt, ignore_aes=ref_ignore, **ref_line_kwargs
        )

    if backend == "matplotlib":  # remove this when we have a better way to handle legends
        plot_collection.add_legend("Groups", loc="upper right", fontsize=10)
        plot_collection.add_legend("BF_type", loc="upper left", fontsize=10, text_only=True)

    return plot_collection
14 changes: 13 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import numpy as np
import pandas as pd
import pytest
from arviz_base import from_dict
from arviz_base import from_dict, load_arviz_data
from scipy.stats import halfnorm, norm

from arviz_plots import (
plot_bf,
plot_compare,
plot_dist,
plot_ess,
Expand Down Expand Up @@ -461,3 +462,14 @@ def test_plot_psense_dist_sample(self, datatree_sample, backend):
assert "component_group" in pc.viz["mu"]["credible_interval"].dims
assert "alpha" in pc.viz["mu"]["credible_interval"].dims
assert "hierarchy" in pc.viz["theta"].dims

def test_plot_bf(self, datatree, backend):
    """Check plot_bf builds the expected chart/plot entries and "Groups" coordinate."""
    # The data behind the `datatree` fixture currently has no "prior" group,
    # so an example dataset is loaded explicitly instead of using the fixture.
    idata = load_arviz_data("centered_eight")
    pc = plot_bf(idata, var_name="mu", backend=backend)
    for expected in ("chart", "plot", "row", "col"):
        assert expected in pc.viz.data_vars
    assert "Groups" in pc.viz["mu"].coords