diff --git a/src/arviz_plots/plots/psensequantitiesplot.py b/src/arviz_plots/plots/psensequantitiesplot.py index 09939e5..b81a105 100644 --- a/src/arviz_plots/plots/psensequantitiesplot.py +++ b/src/arviz_plots/plots/psensequantitiesplot.py @@ -7,8 +7,8 @@ from arviz_stats.psense import power_scale_dataset from xarray import concat -from arviz_plots.plot_collection import PlotCollection -from arviz_plots.plots.utils import filter_aes, process_group_variables_coords, set_figure_layout +from arviz_plots.plot_collection import PlotCollection, process_facet_dims +from arviz_plots.plots.utils import filter_aes, process_group_variables_coords from arviz_plots.visuals import hline, labelled_title, labelled_x, line_xy, scatter_xy, set_xticks @@ -34,10 +34,6 @@ def plot_psense_quantities( ): """Plot power scaled posterior quantities. - The posterior sensitivity is assessed by power-scaling the prior or likelihood and - visualizing how quantities computed from the posterior change. - Pareto-smoothed importance sampling is used to avoid refitting as explained in [1]_. - Parameters ---------- dt : DataTree @@ -115,11 +111,6 @@ def plot_psense_quantities( .. minigallery:: plot_psense_quantities - References - ---------- - .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with - power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5 - """ if sample_dims is None: sample_dims = rcParams["data.sample_dims"] @@ -254,10 +245,22 @@ def plot_psense_quantities( if dim not in sample_dims + ["component_group", "alpha"] ], ) + figsize = pc_kwargs["plot_grid_kws"].get("figsize", None) + figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches") + col_dims = pc_kwargs["cols"] + row_dims = pc_kwargs["rows"] + if figsize is None: + figsize = plot_bknd.scale_fig_size( + figsize, + rows=process_facet_dims(ds_quantities, row_dims)[0], + cols=process_facet_dims(ds_quantities, col_dims)[0], + figsize_units=figsize_units, + ) + figsize_units = "dots" + pc_kwargs["plot_grid_kws"]["figsize"] = figsize + pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units - pc_kwargs = set_figure_layout(pc_kwargs, plot_bknd, ds_quantities) - - plot_collection = PlotCollection.wrap( + plot_collection = PlotCollection.grid( ds_quantities, backend=backend, **pc_kwargs, @@ -370,7 +373,7 @@ def plot_psense_quantities( # set xlabel _, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims) - xlabel_kwargs = copy(plot_kwargs.get("xlabel", {})) + xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy() if xlabel_kwargs is not False: if "color" not in xlabels_aes: xlabel_kwargs.setdefault("color", "black") @@ -389,14 +392,13 @@ def plot_psense_quantities( title_kwargs = copy(plot_kwargs.get("title", {})) _, _, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims) - if title_kwargs is not False: - plot_collection.map( - labelled_title, - "title", - ignore_aes=title_ignore, - subset_info=True, - labeller=labeller, - **title_kwargs, - ) + plot_collection.map( + labelled_title, + "title", + ignore_aes=title_ignore, + subset_info=True, + labeller=labeller, + **title_kwargs, + ) return plot_collection