From cc00f598ed6d3fc9ee304da2763c8e5d4a9005fc Mon Sep 17 00:00:00 2001
From: aloctavodia
Date: Tue, 11 Mar 2025 13:16:53 +0200
Subject: [PATCH] add plot ecdf pit plot

---
 docs/source/gallery/sbc/plot_ecdf_pit.py |  23 +++
 docs/sphinxext/gallery_generator.py      |   1 +
 src/arviz_plots/plots/__init__.py        |   2 +
 src/arviz_plots/plots/ecdfplot.py        | 242 +++++++++++++++++++++++
 4 files changed, 268 insertions(+)
 create mode 100644 docs/source/gallery/sbc/plot_ecdf_pit.py
 create mode 100644 src/arviz_plots/plots/ecdfplot.py

diff --git a/docs/source/gallery/sbc/plot_ecdf_pit.py b/docs/source/gallery/sbc/plot_ecdf_pit.py
new file mode 100644
index 00000000..89ec88d2
--- /dev/null
+++ b/docs/source/gallery/sbc/plot_ecdf_pit.py
@@ -0,0 +1,23 @@
+"""
+# PIT-ECDF difference
+
+faceted plot with PIT Δ-ECDF values for each variable
+
+---
+
+:::{seealso}
+API Documentation: {func}`~arviz_plots.plot_ecdf_pit`
+:::
+"""
+from arviz_base import load_arviz_data
+
+import arviz_plots as azp
+
+azp.style.use("arviz-variat")
+
+data = load_arviz_data("sbc")
+pc = azp.plot_ecdf_pit(
+    data,
+    backend="none"  # change to preferred backend
+)
+pc.show()
diff --git a/docs/sphinxext/gallery_generator.py b/docs/sphinxext/gallery_generator.py
index 85b10d9a..cb883c1c 100644
--- a/docs/sphinxext/gallery_generator.py
+++ b/docs/sphinxext/gallery_generator.py
@@ -20,6 +20,7 @@
     "posterior_predictive_checks": "Posterior predictive checks",
     "prior_and_likelihood_sensitivity_checks": "Prior and likelihood sensitivity checks",
     "model_comparison": "Model Comparison",
+    "sbc": "Simulation Based Calibration",
     "mixed": "Mixed plots",
 }
 
diff --git a/src/arviz_plots/plots/__init__.py b/src/arviz_plots/plots/__init__.py
index f56e45ad..cf1262f5 100644
--- a/src/arviz_plots/plots/__init__.py
+++ b/src/arviz_plots/plots/__init__.py
@@ -4,6 +4,7 @@
 from .compareplot import plot_compare
 from .convergencedistplot import plot_convergence_dist
 from .distplot import plot_dist
+from .ecdfplot import plot_ecdf_pit
 from .energyplot import plot_energy
 from .essplot import plot_ess
 from .evolutionplot import plot_ess_evolution
@@ -27,6 +28,7 @@
     "plot_forest",
     "plot_trace",
     "plot_trace_dist",
+    "plot_ecdf_pit",
     "plot_energy",
     "plot_ess",
     "plot_ess_evolution",
diff --git a/src/arviz_plots/plots/ecdfplot.py b/src/arviz_plots/plots/ecdfplot.py
new file mode 100644
index 00000000..5c121123
--- /dev/null
+++ b/src/arviz_plots/plots/ecdfplot.py
@@ -0,0 +1,242 @@
+"""Plot PIT Δ-ECDF."""
+from copy import copy
+from importlib import import_module
+
+import numpy as np
+from arviz_base import rcParams
+from arviz_base.labels import BaseLabeller
+from arviz_stats.ecdf_utils import ecdf_pit
+
+from arviz_plots.plot_collection import PlotCollection
+from arviz_plots.plots.utils import filter_aes, process_group_variables_coords, set_figure_layout
+from arviz_plots.visuals import ecdf_line, fill_between_y, labelled_title, labelled_x, remove_axis
+
+
+def plot_ecdf_pit(
+    dt,
+    var_names=None,
+    filter_vars=None,
+    group="prior_sbc",
+    coords=None,
+    sample_dims=None,
+    ci_prob=None,
+    n_simulations=1000,
+    method="simulation",
+    plot_collection=None,
+    backend=None,
+    labeller=None,
+    aes_map=None,
+    plot_kwargs=None,
+    pc_kwargs=None,
+):
+    """Plot Δ-ECDF.
+
+    Plots the Δ-ECDF, that is the difference between the expected empirical CDF and
+    the observed empirical CDF.
+    This plot is useful to assess the goodness of fit of a model for example in SBC analysis.
+    It assumes the values in the DataTree are already transformed to the unit interval.
+    Simultaneous confidence bands are computed using the method described in [1]_.
+
+    Parameters
+    ----------
+    dt : DataTree
+        Input data
+    var_names : str or list of str, optional
+        One or more variables to be plotted. Currently only one variable is supported.
+        Prefix the variables by ~ when you want to exclude them from the plot.
+    filter_vars : {None, "like", "regex"}, optional, default=None
+        If None (default), interpret var_names as the real variables names.
+        If "like", interpret var_names as substrings of the real variables names.
+        If "regex", interpret var_names as regular expressions on the real variables names.
+    group : str, optional
+        Which group to use. Defaults to "prior_sbc".
+    coords : dict, optional
+        Coordinates to plot.
+    sample_dims : str or sequence of hashable, optional
+        Dimensions to reduce unless mapped to an aesthetic.
+        Defaults to ``rcParams["data.sample_dims"]``
+    ci_prob : float, optional
+        Indicates the probability that should be contained within the plotted credible interval.
+        Defaults to ``rcParams["stats.ci_prob"]``
+    n_simulations : int, optional
+        Number of simulations to use to compute simultaneous confidence intervals when using
+        `method="simulation"`. Ignored if method is "optimized". Defaults to 1000.
+    method : str, optional
+        Method to compute the confidence intervals. Either "simulation" or "optimized".
+        Defaults to "simulation".
+    plot_collection : PlotCollection, optional
+    backend : {"matplotlib", "bokeh", "plotly"}, optional
+    labeller : labeller, optional
+        Labeller used to generate the titles of each plot. Defaults to ``BaseLabeller()``.
+    aes_map : mapping of {str : sequence of str}, optional
+        Mapping of artists to aesthetics that should use their mapping in `plot_collection`
+        when plotted. Valid keys are the same as for `plot_kwargs`.
+
+    plot_kwargs : mapping of {str : mapping or False}, optional
+        Valid keys are:
+
+        * ecdf_lines -> passed to :func:`~arviz_plots.visuals.ecdf_line`
+        * ci -> passed to :func:`~arviz_plots.visuals.fill_between_y`
+        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
+        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
+        * remove_axis -> not passed anywhere, can only be ``False`` to skip calling this function
+
+    pc_kwargs : mapping
+        Passed to :class:`arviz_plots.PlotCollection.grid`
+
+    Returns
+    -------
+    PlotCollection
+
+    Examples
+    --------
+    Plot the PIT Δ-ECDF for the ``sbc`` example dataset.
+
+    .. plot::
+        :context: close-figs
+
+        >>> from arviz_plots import plot_ecdf_pit, style
+        >>> style.use("arviz-variat")
+        >>> from arviz_base import load_arviz_data
+        >>> dt = load_arviz_data('sbc')
+        >>> plot_ecdf_pit(dt)
+
+
+    .. minigallery:: plot_ecdf_pit
+
+    .. [1] Säilynoja T, Bürkner PC. and Vehtari A. *Graphical test for discrete uniformity and
+       its applications in goodness-of-fit evaluation and multiple sample comparison*.
+       Statistics and Computing 32(32). (2022) https://doi.org/10.1007/s11222-022-10090-6
+    """
+    if ci_prob is None:
+        ci_prob = rcParams["stats.ci_prob"]
+    if sample_dims is None:
+        sample_dims = rcParams["data.sample_dims"]
+    if isinstance(sample_dims, str):
+        sample_dims = [sample_dims]
+    sample_dims = list(sample_dims)
+    if plot_kwargs is None:
+        plot_kwargs = {}
+    else:
+        plot_kwargs = plot_kwargs.copy()
+    plot_kwargs.setdefault("remove_axis", True)
+    if pc_kwargs is None:
+        pc_kwargs = {}
+    else:
+        pc_kwargs = pc_kwargs.copy()
+
+    if backend is None:
+        if plot_collection is None:
+            backend = rcParams["plot.backend"]
+        else:
+            backend = plot_collection.backend
+
+    if labeller is None:
+        labeller = BaseLabeller()
+
+    distribution = process_group_variables_coords(
+        dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
+    )
+
+    dt_ecdf = distribution.azstats.ecdf(dims=sample_dims, pit=True)
+
+    # Compute envelope: evaluate ecdf_pit on evenly spaced dummy values (one per sample)
+    # and subtract x_ci so the band is expressed as a difference, like the plotted Δ-ECDF
+    dummy_vals_size = np.prod([len(distribution[dims]) for dims in sample_dims])
+    dummy_vals = np.linspace(0, 1, dummy_vals_size)
+    x_ci, _, lower_ci, upper_ci = ecdf_pit(dummy_vals, ci_prob, method, n_simulations)
+    lower_ci = lower_ci - x_ci
+    upper_ci = upper_ci - x_ci
+
+    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
+
+    if plot_collection is None:
+        pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()
+        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
+        pc_kwargs.setdefault("col_wrap", 5)
+        pc_kwargs.setdefault(
+            "cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
+        )
+
+        pc_kwargs = set_figure_layout(pc_kwargs, plot_bknd, distribution)
+
+        plot_collection = PlotCollection.wrap(
+            distribution,
+            backend=backend,
+            **pc_kwargs,
+        )
+
+    if aes_map is None:
+        aes_map = {}
+    else:
+        aes_map = aes_map.copy()
+
+    ## ecdf_line
+    ecdf_ls_kwargs = copy(plot_kwargs.get("ecdf_lines", {}))
+
+    if ecdf_ls_kwargs is not False:
+        _, _, ecdf_ls_ignore = filter_aes(plot_collection, aes_map, "ecdf_lines", sample_dims)
+
+        plot_collection.map(
+            ecdf_line,
+            "ecdf_lines",
+            data=dt_ecdf,
+            ignore_aes=ecdf_ls_ignore,
+            **ecdf_ls_kwargs,
+        )
+
+    # simultaneous confidence band
+    ci_kwargs = copy(plot_kwargs.get("ci", {}))
+    _, _, ci_ignore = filter_aes(plot_collection, aes_map, "ci", sample_dims)
+    if ci_kwargs is not False:
+        ci_kwargs.setdefault("color", "black")
+        ci_kwargs.setdefault("alpha", 0.1)
+
+        plot_collection.map(
+            fill_between_y,
+            "ci",
+            data=dt_ecdf,
+            x=x_ci,
+            y_bottom=lower_ci,
+            y_top=upper_ci,
+            ignore_aes=ci_ignore,
+            **ci_kwargs,
+        )
+
+    # set xlabel
+    _, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
+    xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
+    if xlabel_kwargs is not False:
+        if "color" not in xlabels_aes:
+            xlabel_kwargs.setdefault("color", "black")
+
+        xlabel_kwargs.setdefault("text", "PIT")
+
+        plot_collection.map(
+            labelled_x,
+            "xlabel",
+            ignore_aes=xlabels_ignore,
+            subset_info=True,
+            **xlabel_kwargs,
+        )
+
+    # title
+    title_kwargs = copy(plot_kwargs.get("title", {}))
+    _, _, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims)
+
+    if title_kwargs is not False:
+        plot_collection.map(
+            labelled_title,
+            "title",
+            ignore_aes=title_ignore,
+            subset_info=True,
+            labeller=labeller,
+            **title_kwargs,
+        )
+
+    if plot_kwargs.get("remove_axis", True) is not False:
+        plot_collection.map(
+            remove_axis, store_artist=False, axis="y", ignore_aes=plot_collection.aes_set
+        )
+
+    return plot_collection