From 4a3ed6198c017213c02ff902190e4538310962ad Mon Sep 17 00:00:00 2001 From: PiyushPanwarFST Date: Sat, 1 Mar 2025 14:55:15 +0530 Subject: [PATCH] Add plot_bf() function with tests and documentation. Implemented the plot_bf() function in arviz-plots. Added test cases to validate the functionality of plot_bf(). Introduced a new text_only argument in the add_legend() function. Included plot_bf() in the documentation. Incorporated feedback recieve so far. Resolved all the linter(pylint) errors. Signed-off-by: PiyushPanwarFST --- docs/source/api/plots.rst | 1 + .../gallery/model_comparison/plot_bf.py | 23 ++ src/arviz_plots/plot_collection.py | 27 ++- src/arviz_plots/plots/__init__.py | 2 + src/arviz_plots/plots/bfplot.py | 210 ++++++++++++++++++ tests/test_plots.py | 14 +- 6 files changed, 273 insertions(+), 4 deletions(-) create mode 100644 docs/source/gallery/model_comparison/plot_bf.py create mode 100644 src/arviz_plots/plots/bfplot.py diff --git a/docs/source/api/plots.rst b/docs/source/api/plots.rst index a17609e3..b488ec48 100644 --- a/docs/source/api/plots.rst +++ b/docs/source/api/plots.rst @@ -17,6 +17,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at .. autosummary:: :toctree: generated/ + plot_bf plot_compare plot_convergence_dist plot_dist diff --git a/docs/source/gallery/model_comparison/plot_bf.py b/docs/source/gallery/model_comparison/plot_bf.py new file mode 100644 index 00000000..88f05a91 --- /dev/null +++ b/docs/source/gallery/model_comparison/plot_bf.py @@ -0,0 +1,23 @@ +""" +# Bayes_factor +Compute Bayes factor using Savage–Dickey ratio +--- +:::{seealso} +API Documentation: {func}`~arviz_plots.plot_bf` +::: +""" +from arviz_base import load_arviz_data + +import arviz_plots as azp + +azp.style.use("arviz-variat") + +data = load_arviz_data("centered_eight") + +pc = azp.plot_bf( + data, + backend="none", + var_name="mu" +) + +pc.show() \ No newline at end of file diff --git a/src/arviz_plots/plot_collection.py b/src/arviz_plots/plot_collection.py index 3ad6ff9e..66cf05e0 100644 --- a/src/arviz_plots/plot_collection.py +++ b/src/arviz_plots/plot_collection.py @@ -1047,7 +1047,16 @@ def map( aux_artist = np.squeeze(aux_artist) self.viz[var_name][fun_label].loc[sel] = aux_artist - def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=None, **kwargs): + def add_legend( + self, + dim, + var_name=None, + aes=None, + artist_kwargs=None, + title=None, + text_only=False, + **kwargs, + ): """Add a legend for the given artist/aesthetic to the plot. Warnings @@ -1073,6 +1082,8 @@ def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=Non generate the miniatures in the legend. title : str, optional Legend title. Defaults to `dim`. + text_only : bool, optional + If True, creates a text-only legend without graphical markers. **kwargs : mapping, optional Keyword arguments passed to the backend function that generates the legend. @@ -1098,18 +1109,28 @@ def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=Non if isinstance(aes, str): aes = [aes] aes_ds = aes_ds[aes] + label_list = aes_ds[dim].values kwarg_list = [ {k: v.item() for k, v in aes_ds.sel({dim: coord}).items()} for coord in label_list ] + for kwarg_dict in kwarg_list: kwarg_dict.pop("overlay", None) + if text_only: + kwarg_dict.pop("color", None) + plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots") + + legend_title = None if text_only else title + return plot_bknd.legend( self.viz["chart"].item(), kwarg_list, label_list, - title=title, - artist_kwargs=artist_kwargs, + title=legend_title, + artist_kwargs={"linestyle": "none", "linewidth": 0, "color": "none"} + if text_only + else artist_kwargs, **kwargs, ) diff --git a/src/arviz_plots/plots/__init__.py b/src/arviz_plots/plots/__init__.py index dfda262a..b2be5aac 100644 --- a/src/arviz_plots/plots/__init__.py +++ b/src/arviz_plots/plots/__init__.py @@ -1,5 +1,6 @@ """Batteries-included ArviZ plots.""" +from .bfplot import plot_bf from .compareplot import plot_compare from .convergencedistplot import plot_convergence_dist from .distplot import plot_dist @@ -17,6 +18,7 @@ from .traceplot import plot_trace __all__ = [ + "plot_bf", "plot_compare", "plot_convergence_dist", "plot_dist", diff --git a/src/arviz_plots/plots/bfplot.py b/src/arviz_plots/plots/bfplot.py new file mode 100644 index 00000000..12990bb8 --- /dev/null +++ b/src/arviz_plots/plots/bfplot.py @@ -0,0 +1,210 @@ +"""Contain functions for Bayes Factor plotting.""" + +from copy import copy +from importlib import import_module + +import xarray as xr +from arviz_base import extract, rcParams +from arviz_stats.bayes_factor import bayes_factor + +from arviz_plots.plot_collection import PlotCollection +from arviz_plots.plots.distplot import plot_dist +from arviz_plots.plots.utils import filter_aes +from arviz_plots.visuals import vline + + +def plot_bf( + dt, + var_name, + ref_val=0, + kind=None, + sample_dims=None, + plot_collection=None, + backend=None, + labeller=None, + aes_map=None, + plot_kwargs=None, + stats_kwargs=None, + pc_kwargs=None, +): + r"""Approximated Bayes Factor for comparing hypothesis of two nested models. + + The Bayes factor is estimated by comparing a model (H1) against a model + in which the parameter of interest has been restricted to be a point-null (H0) + This computation assumes the models are nested and thus H0 is a special case of H1. + + Parameters + ---------- + dt : DataTree or dict of {str : DataTree} + Input data. In case of dictionary input, the keys are taken to be model names. + In such cases, a dimension "model" is generated and can be used to map to aesthetics. + var_names : str or list of str, optional + One or more variables to be plotted. + Prefix the variables by ~ when you want to exclude them from the plot. + ref_val : int or float, default 0 + Reference (point-null) value for Bayes factor estimation. + kind : {"kde", "hist", "dot", "ecdf"}, optional + How to represent the marginal density. + Defaults to ``rcParams["plot.density_kind"]`` + sample_dims : str or sequence of hashable, optional + Dimensions to reduce unless mapped to an aesthetic. + Defaults to ``rcParams["data.sample_dims"]`` + plot_collection : PlotCollection, optional + backend : {"matplotlib", "bokeh"}, optional + labeller : labeller, optional + aes_map : mapping of {str : sequence of str}, optional + Mapping of artists to aesthetics that should use their mapping in `plot_collection` + when plotted. Valid keys are the same as for `plot_kwargs`. + With a single model, no aesthetic mappings are generated by default, + each variable+coord combination gets a :term:`plot` but they all look the same, + unless there are user provided aesthetic mappings. + With multiple models, ``plot_dist`` maps "color" and "y" to the "model" dimension. + By default, all aesthetics but "y" are mapped to the density representation, + and if multiple models are present, "color" and "y" are mapped to the + credible interval and the point estimate. + When "point_estimate" key is provided but "point_estimate_text" isn't, + the values assigned to the first are also used for the second. + plot_kwargs : mapping of {str : mapping or False}, optional + Valid keys are: + * One of "kde", "ecdf", "dot" or "hist", matching the `kind` argument. + * "kde" -> passed to :func:`~arviz_plots.visuals.line_xy` + * "ecdf" -> passed to :func:`~arviz_plots.visuals.ecdf_line` + * "hist" -> passed to :func: `~arviz_plots.visuals.hist` + * credible_interval -> passed to :func:`~arviz_plots.visuals.line_x` + * point_estimate -> passed to :func:`~arviz_plots.visuals.scatter_x` + * point_estimate_text -> passed to :func:`~arviz_plots.visuals.point_estimate_text` + * title -> passed to :func:`~arviz_plots.visuals.labelled_title` + * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function + stats_kwargs : mapping, optional + Valid keys are: + * density -> passed to kde, ecdf, ... + * credible_interval -> passed to eti or hdi + * point_estimate -> passed to mean, median or mode + pc_kwargs : mapping + Passed to :class:`arviz_plots.PlotCollection.wrap` + """ + if kind is None: + kind = rcParams["plot.density_kind"] + if plot_kwargs is None: + plot_kwargs = {} + else: + plot_kwargs = plot_kwargs.copy() + if pc_kwargs is None: + pc_kwargs = {} + else: + pc_kwargs = pc_kwargs.copy() + if aes_map is None: + aes_map = {} + else: + aes_map = aes_map.copy() + + if sample_dims is None: + sample_dims = rcParams["data.sample_dims"] + if isinstance(sample_dims, str): + sample_dims = [sample_dims] + sample_dims = list(sample_dims) + if not isinstance(plot_kwargs, dict): + plot_kwargs = {} + + ref_line_kwargs = copy(plot_kwargs.get("ref_line", {})) + if ref_line_kwargs is False: + raise ValueError( + "plot_kwargs['ref_line'] can't be False, use ref_val=False to remove this element" + ) + + if backend is None: + if plot_collection is None: + backend = rcParams["plot.backend"] + else: + backend = plot_collection.backend + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + + ds_prior = extract(dt, group="prior", var_names=var_name, keep_dataset=True) + ds_posterior = extract(dt, group="posterior", var_names=var_name, keep_dataset=True) + + distribution = xr.concat([ds_prior, ds_posterior], dim="Groups").assign_coords( + {"Groups": ["prior", "posterior"]} + ) + + if len(sample_dims) > 1: + # sample dims will have been stacked and renamed by `power_scale_dataset` + sample_dims = ["sample"] + + # Compute Bayes Factor using the bayes_factor function + bf, _ = bayes_factor(dt, var_name, ref_val, return_ref_vals=True) + + bf_values = xr.DataArray( + [f"BF10: {bf['BF10']:.2f}"], + dims=["BF_type"], + coords={"BF_type": ["BF10"]}, + ) + + if plot_collection is None: + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() + pc_kwargs["aes"].setdefault("color", ["Groups"]) + + figsize = pc_kwargs["plot_grid_kws"].get("figsize", None) + figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches") + if figsize is None: + figsize = plot_bknd.scale_fig_size( + figsize, + rows=1, + cols=1, + figsize_units=figsize_units, + ) + figsize_units = "dots" + pc_kwargs["plot_grid_kws"]["figsize"] = figsize + pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units + + plot_collection = PlotCollection.grid( + distribution, + backend=backend, + **pc_kwargs, + ) + + plot_collection.aes["BF_type"] = bf_values + + plot_kwargs.setdefault("credible_interval", False) + plot_kwargs.setdefault("point_estimate", False) + plot_kwargs.setdefault("point_estimate_text", False) + + plot_collection = plot_dist( + distribution, + var_names=var_name, + group=None, + coords=None, + sample_dims=sample_dims, + kind=kind, + point_estimate=None, + ci_kind=None, + ci_prob=None, + plot_collection=plot_collection, + backend=backend, + labeller=labeller, + plot_kwargs=plot_kwargs, + stats_kwargs=stats_kwargs, + pc_kwargs=pc_kwargs, + ) + + if ref_val is not False: + ref_dt = xr.Dataset({var_name: xr.DataArray([ref_val])}) + _, ref_aes, ref_ignore = filter_aes(plot_collection, aes_map, "ref_line", sample_dims) + if "color" not in ref_aes: + ref_line_kwargs.setdefault("color", "black") + if "linestyle" not in ref_aes: + default_linestyle = plot_bknd.get_default_aes("linestyle", 2, {})[1] + ref_line_kwargs.setdefault("linestyle", default_linestyle) + if "alpha" not in ref_aes: + ref_line_kwargs.setdefault("alpha", 0.5) + + plot_collection.map( + vline, "ref_line", data=ref_dt, ignore_aes=ref_ignore, **ref_line_kwargs + ) + + if backend == "matplotlib": ## remove this when we have a better way to handle legends + plot_collection.add_legend("Groups", loc="upper right", fontsize=10) + plot_collection.add_legend("BF_type", loc="upper left", fontsize=10, text_only=True) + + return plot_collection diff --git a/tests/test_plots.py b/tests/test_plots.py index e06a6bfd..15348eeb 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -3,10 +3,11 @@ import numpy as np import pandas as pd import pytest -from arviz_base import from_dict +from arviz_base import from_dict, load_arviz_data from scipy.stats import halfnorm, norm from arviz_plots import ( + plot_bf, plot_compare, plot_dist, plot_ess, @@ -461,3 +462,14 @@ def test_plot_psense_dist_sample(self, datatree_sample, backend): assert "component_group" in pc.viz["mu"]["credible_interval"].dims assert "alpha" in pc.viz["mu"]["credible_interval"].dims assert "hierarchy" in pc.viz["theta"].dims + + def test_plot_bf(self, datatree, backend): + # The current genrate_base_data() function lacks a "prior" group, + # so we're manually loading the data. + datatree = load_arviz_data("centered_eight") + pc = plot_bf(datatree, var_name="mu", backend=backend) + assert "chart" in pc.viz.data_vars + assert "plot" in pc.viz.data_vars + assert "row" in pc.viz.data_vars + assert "col" in pc.viz.data_vars + assert "Groups" in pc.viz["mu"].coords