Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

plot_converge_dist: Add grouped argument #182

Merged
merged 2 commits into from
Mar 18, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 81 additions & 18 deletions src/arviz_plots/plots/convergencedistplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
def plot_convergence_dist(
dt,
diagnostics=None,
grouped=True,
ref_line=True,
var_names=None,
filter_vars=None,
Expand All @@ -35,6 +36,10 @@ def plot_convergence_dist(
):
"""Plot the distribution of convergence diagnostics (ESS and/or R-hat).

By default all variables are grouped together and one plot per diagnostic is created.
If you are interested in representing individual (multidimensional variables) pass them
in `var_names`.

Information on how the diagnostics are computed can be found in [1]_.

Parameters
Expand All @@ -46,6 +51,10 @@ def plot_convergence_dist(
Valid diagnostics are "rhat_rank", "rhat_folded", "rhat_z_scale", "rhat_split",
"rhat_identity", "ess_bulk", "ess_tail", "ess_mean", "ess_sd", "ess_quantile",
"ess_local", "ess_median", "ess_mad", "ess_z_scale", "ess_folded" and "ess_identity".
grouped: bool, optional
Whether to plot all variables listed in ``var_names`` together (True)
or separately (False). Defaults to True.
If False, all variables listed in ``var_names`` must be multidimensional.
ref_line : bool, default True
Whether to plot a reference line for the recommended value of each diagnostic.
var_names : str or list of str, optional
Expand Down Expand Up @@ -128,6 +137,17 @@ def plot_convergence_dist(
>>> ]
>>> )

Select two variables and plot them separately

.. plot::
:context: close-figs

>>> plot_convergence_dist(
>>> radon,
>>> var_names=["za_county", "a"],
>>> grouped=False,
>>> )

.. minigallery:: plot_convergence_dist


Expand Down Expand Up @@ -178,22 +198,35 @@ def plot_convergence_dist(
dt = process_group_variables_coords(
dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
)
distribution = _compute_diagnostics(dt, diagnostics, sample_dims, grouped)

distribution = _compute_diagnostics(dt, diagnostics, sample_dims)
if plot_collection is None:
pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()

pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
if grouped is False:
pc_kwargs.setdefault("col_wrap", len(diagnostics))

if grouped:
plot_dist_sample_dims = ["label"]
plot_dist_var_names = None
else:
plot_dist_sample_dims = [dim for dim in dt.dims if dim not in sample_dims]
plot_dist_var_names = var_names

plot_kwargs.setdefault("credible_interval", False)
plot_kwargs.setdefault("point_estimate", False)
plot_kwargs.setdefault("point_estimate_text", False)

plot_collection = plot_dist(
distribution,
var_names=None,
var_names=plot_dist_var_names,
filter_vars=None,
group=None,
coords=None,
# _compute_diagnostics returns output with only this dimension
# it is the one we want reduced in plot_dist to show the distributions
sample_dims=["label"],
sample_dims=plot_dist_sample_dims,
kind=kind,
point_estimate=point_estimate,
ci_kind=ci_kind,
Expand All @@ -220,38 +253,68 @@ def plot_convergence_dist(
ess_ref = dt.sizes["chain"] * 100
# is this valid for all r_hat methods? Do we want to correct for multiple comparisons?
r_hat_ref = 1.01
ref_ds = xr.Dataset(
{
diagnostic: ess_ref if "ess" in diagnostic else r_hat_ref
for diagnostic in distribution.data_vars
}
)
if grouped:
ref_ds = xr.Dataset(
{
diagnostic: ess_ref if "ess" in diagnostic else r_hat_ref
for diagnostic in diagnostics
}
)
else:
ref_values = [
ess_ref if "ess" in diagnostic else r_hat_ref for diagnostic in diagnostics
]
ref_ds = xr.Dataset(
{
var: xr.DataArray(ref_values, dims="diagnostic")
for var in distribution.data_vars
},
coords={"diagnostic": diagnostics},
)

plot_collection.map(
vline, "ref_line", data=ref_ds, ignore_aes=ref_ignore, **ref_line_kwargs
)

return plot_collection


def _compute_diagnostics(dt, diagnostics, sample_dims):
diagnostic_values = {}
def _compute_diagnostics(dt, diagnostics, sample_dims, grouped):
diagnostic_dict = {}
for diagnostic in diagnostics:
if "ess" in diagnostic:
prob = None
method = diagnostic.split("_", 1)[1].split("(", 1)[0]
if method in {"tail", "quantile", "local"} and "(" in diagnostic:
prob = [float(p) for p in diagnostic.split("(", 1)[1].rstrip(")").split(", ")]
diagnostic_values[diagnostic] = dt.azstats.ess(
method=method, prob=prob, dims=sample_dims
).to_stacked_array("label", sample_dims=[])

diagnostic_dt = dt.azstats.ess(method=method, prob=prob, dims=sample_dims)
if grouped:
diagnostic_dict[diagnostic] = diagnostic_dt.to_stacked_array(
"label", sample_dims=[]
)
else:
diagnostic_dict[diagnostic] = diagnostic_dt

elif "rhat" in diagnostic:
kwargs = {"dims": sample_dims}
if diagnostic != "rhat":
method = diagnostic.split("_", 1)[1]
kwargs.update({"method": method})
diagnostic_values[diagnostic] = dt.azstats.rhat(**kwargs).to_stacked_array(
"label", sample_dims=[]
)

diagnostic_dt = dt.azstats.rhat(**kwargs)
if grouped:
diagnostic_dict[diagnostic] = diagnostic_dt.to_stacked_array(
"label", sample_dims=[]
)
else:
diagnostic_dict[diagnostic] = diagnostic_dt
else:
warnings.warn(f"{diagnostic} is not recognized as a valid diagnostic")
return xr.Dataset(diagnostic_values)

if grouped:
return xr.Dataset(diagnostic_dict)

return xr.concat(list(diagnostic_dict.values()), dim="diagnostic").assign_coords(
{"diagnostic": list(diagnostic_dict.keys())}
)