diff --git a/src/arviz_plots/plots/convergencedistplot.py b/src/arviz_plots/plots/convergencedistplot.py index cbc07c4..f07734a 100644 --- a/src/arviz_plots/plots/convergencedistplot.py +++ b/src/arviz_plots/plots/convergencedistplot.py @@ -15,6 +15,7 @@ def plot_convergence_dist( dt, diagnostics=None, + grouped=True, ref_line=True, var_names=None, filter_vars=None, @@ -35,6 +36,10 @@ def plot_convergence_dist( ): """Plot the distribution of convergence diagnostics (ESS and/or R-hat). + By default all variables are grouped together and one plot per diagnostic is created. + If you are interested in representing individual (multidimensional variables) pass them + in `var_names`. + Information on how the diagnostics are computed can be found in [1]_. Parameters @@ -46,6 +51,10 @@ def plot_convergence_dist( Valid diagnostics are "rhat_rank", "rhat_folded", "rhat_z_scale", "rhat_split", "rhat_identity", "ess_bulk", "ess_tail", "ess_mean", "ess_sd", "ess_quantile", "ess_local", "ess_median", "ess_mad", "ess_z_scale", "ess_folded" and "ess_identity". + grouped: bool, optional + Whether to plot all variables listed in ``var_names`` together (True) + or separately (False). Defaults to True. + If False, all variables listed in ``var_names`` must be multidimensional. ref_line : bool, default True Whether to plot a reference line for the recommended value of each diagnostic. var_names : str or list of str, optional @@ -128,6 +137,17 @@ def plot_convergence_dist( >>> ] >>> ) + Select two variables and plot them separately + + .. plot:: + :context: close-figs + + >>> plot_convergence_dist( + >>> radon, + >>> var_names=["za_county", "a"], + >>> grouped=False, + >>> ) + .. 
minigallery:: plot_convergence_dist @@ -178,8 +198,21 @@ def plot_convergence_dist( dt = process_group_variables_coords( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords ) + distribution = _compute_diagnostics(dt, diagnostics, sample_dims, grouped) - distribution = _compute_diagnostics(dt, diagnostics, sample_dims) + if plot_collection is None: + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() + if grouped is False: + pc_kwargs.setdefault("col_wrap", len(diagnostics)) + + if grouped: + plot_dist_sample_dims = ["label"] + plot_dist_var_names = None + else: + plot_dist_sample_dims = [dim for dim in dt.dims if dim not in sample_dims] + plot_dist_var_names = var_names plot_kwargs.setdefault("credible_interval", False) plot_kwargs.setdefault("point_estimate", False) @@ -187,13 +220,13 @@ def plot_convergence_dist( plot_collection = plot_dist( distribution, - var_names=None, + var_names=plot_dist_var_names, filter_vars=None, group=None, coords=None, # _compute_diagnostics returns output with only this dimension # it is the one we want reduced in plot_dist to show the distributions - sample_dims=["label"], + sample_dims=plot_dist_sample_dims, kind=kind, point_estimate=point_estimate, ci_kind=ci_kind, @@ -220,12 +253,25 @@ def plot_convergence_dist( ess_ref = dt.sizes["chain"] * 100 # is this valid for all r_hat methods? Do we want to correct for multiple comparisons? 
def _compute_diagnostics(dt, diagnostics, sample_dims, grouped=True):
    """Compute the requested ESS and/or R-hat diagnostics.

    Parameters
    ----------
    dt : object exposing the ``azstats`` accessor
        Only ``dt.azstats.ess`` and ``dt.azstats.rhat`` are used here.
        Presumably a DataTree/Dataset with the posterior samples -- confirm
        against the caller.
    diagnostics : iterable of str
        Diagnostic names. Names containing ``"ess"`` are parsed as
        ``"ess_<method>"`` with an optional ``"(p1, p2, ...)"`` suffix for the
        ``tail``, ``quantile`` and ``local`` methods. Names containing
        ``"rhat"`` are parsed as ``"rhat"`` or ``"rhat_<method>"``.
        Anything else triggers a warning and is skipped.
    sample_dims : sequence of hashable
        Forwarded as ``dims`` to the ``azstats`` functions.
    grouped : bool, default True
        If True, each diagnostic result is stacked along a single ``"label"``
        dimension and the results are returned as variables of one Dataset.
        If False, the unstacked results are concatenated along a new
        ``"diagnostic"`` dimension labelled with the diagnostic names.

    Returns
    -------
    xarray object
        ``xr.Dataset`` keyed by diagnostic name when `grouped` is True,
        otherwise the output of ``xr.concat`` over the per-diagnostic results.

    Raises
    ------
    ValueError
        If `grouped` is False and no diagnostic was recognized, since there
        is nothing to concatenate.
    """
    diagnostic_dict = {}
    for diagnostic in diagnostics:
        if "ess" in diagnostic:
            method = diagnostic.split("_", 1)[1].split("(", 1)[0]
            prob = None
            if method in {"tail", "quantile", "local"} and "(" in diagnostic:
                prob_text = diagnostic.split("(", 1)[1].rstrip(")")
                # Split on "," only: float() ignores surrounding whitespace, so
                # both "(0.1, 0.9)" and "(0.1,0.9)" are accepted.
                prob = [float(p) for p in prob_text.split(",")]
            result = dt.azstats.ess(method=method, prob=prob, dims=sample_dims)
        elif "rhat" in diagnostic:
            kwargs = {"dims": sample_dims}
            if diagnostic != "rhat":
                # Plain "rhat" keeps the accessor's default method.
                kwargs["method"] = diagnostic.split("_", 1)[1]
            result = dt.azstats.rhat(**kwargs)
        else:
            warnings.warn(f"{diagnostic} is not recognized as a valid diagnostic")
            continue

        if grouped:
            # Collapse every variable and dimension into one "label" dimension
            # so all values of a diagnostic form a single distribution.
            result = result.to_stacked_array("label", sample_dims=[])
        diagnostic_dict[diagnostic] = result

    if grouped:
        return xr.Dataset(diagnostic_dict)

    if not diagnostic_dict:
        # xr.concat cannot handle an empty sequence; fail with a message that
        # points at the actual cause instead.
        raise ValueError(
            "No valid diagnostics to compute; "
            f"got diagnostics={list(diagnostics)!r}"
        )

    return xr.concat(list(diagnostic_dict.values()), dim="diagnostic").assign_coords(
        {"diagnostic": list(diagnostic_dict.keys())}
    )