Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consistence use of data_pairs, remove default markers from pava #152

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/arviz_plots/plots/pavacalibrationplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def plot_ppc_pava(
defaults to 1000.
ci_prob : float, optional
Probability for the credible interval. Defaults to ``rcParams["stats.ci_prob"]``.
data_pairs : tuple, optional
Tuple of posterior predictive data and observed data variable names.
If None, it will assume that the observed data and the posterior
predictive data have the same variable name.
data_pairs : dict, optional
Dictionary of keys prior/posterior predictive data and values observed data variable names.
If None, it will assume that the observed data and the predictive data have
the same variable name.
num_samples : int, optional
Number of samples to use for the plot. Defaults to 100.
var_names : str or list of str, optional
Expand Down Expand Up @@ -84,6 +84,8 @@ def plot_ppc_pava(
* ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
* title -> passed to :func:`~arviz_plots.visuals.labelled_title`

markers defaults to False, no markers are plotted.
Pass an (empty) mapping to plot markers.

pc_kwargs : mapping
Passed to :class:`arviz_plots.PlotCollection.grid`
Expand Down Expand Up @@ -138,10 +140,12 @@ def plot_ppc_pava(

labeller = BaseLabeller()

plot_kwargs.setdefault("markers", False)

if data_pairs is None:
data_pairs = (var_names, var_names)
data_pairs = {var_names: var_names}

ds_calibration = isotonic_fit(dt, var_names, n_bootstaps, ci_prob)
ds_calibration = isotonic_fit(dt, data_pairs, n_bootstaps, ci_prob)

plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
colors = plot_bknd.get_default_aes("color", 1, {})
Expand Down Expand Up @@ -196,7 +200,6 @@ def plot_ppc_pava(
"reference_line",
data=ds_calibration,
x=ds_calibration.sel(plot_axis="x"),
# y=ds_calibration.sel(plot_axis="y"),
ignore_aes=reference_ls_ignore,
**reference_ls_kwargs,
)
Expand All @@ -205,7 +208,7 @@ def plot_ppc_pava(
calibration_ms_kwargs = copy(plot_kwargs.get("markers", {}))

if calibration_ms_kwargs is not False:
_, _, calibration_ms_ignore = filter_aes(plot_collection, aes_map, "lines", sample_dims)
_, _, calibration_ms_ignore = filter_aes(plot_collection, aes_map, "markers", sample_dims)
calibration_ms_kwargs.setdefault("color", colors[0])
calibration_ms_kwargs.setdefault("marker", markers[6])

Expand Down
14 changes: 12 additions & 2 deletions src/arviz_plots/plots/ppcdistplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

def plot_ppc_dist(
dt,
data_pairs=None,
var_names=None,
filter_vars=None,
group="posterior_predictive",
Expand All @@ -37,6 +38,10 @@ def plot_ppc_dist(
----------
dt : DataTree
Input data
data_pairs : dict, optional
Dictionary of keys prior/posterior predictive data and values observed data variable names.
If None, it will assume that the observed data and the predictive data have
the same variable name.
var_names : str or list of str, optional
One or more variables to be plotted.
Prefix the variables by ~ when you want to exclude them from the plot.
Expand Down Expand Up @@ -155,12 +160,17 @@ def plot_ppc_dist(
dims for dims in dt.posterior_predictive.dims if dims not in sample_dims
]

if data_pairs is None:
data_pairs = (var_names, var_names)
else:
data_pairs = (list(data_pairs.keys()), list(data_pairs.values()))

predictive_dist = process_group_variables_coords(
dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
dt, group=group, var_names=data_pairs[0], filter_vars=filter_vars, coords=coords
)

observed_dist = process_group_variables_coords(
dt, group="observed_data", var_names=var_names, filter_vars=filter_vars, coords=coords
dt, group="observed_data", var_names=data_pairs[1], filter_vars=filter_vars, coords=coords
)

predictive_types = [
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/ppcrootogramplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def plot_ppc_rootogram(
Scale for the y-axis. Defaults to "sqrt", pass "linear" for linear scale.
Currently only "matplotlib" backend is supported. For "bokeh" and "plotly"
the y-axis is linear.
data_pairs : tuple, optional
Tuple of prior/posterior predictive data and observed data variable names.
data_pairs : dict, optional
Dictionary of keys prior/posterior predictive data and values observed data variable names.
If None, it will assume that the observed data and the predictive data have
the same variable name.
var_names : str or list of str, optional
Expand Down