From 64dc842694245f37f369f2f071d13375d1ca7eff Mon Sep 17 00:00:00 2001
From: aloctavodia <aloctavodia@gmail.com>
Date: Wed, 12 Mar 2025 16:43:04 +0200
Subject: [PATCH] fix regression bug

---
 src/arviz_plots/plots/psensequantitiesplot.py | 50 ++++++++++---------
 1 file changed, 26 insertions(+), 24 deletions(-)

diff --git a/src/arviz_plots/plots/psensequantitiesplot.py b/src/arviz_plots/plots/psensequantitiesplot.py
index 09939e5e..b81a1056 100644
--- a/src/arviz_plots/plots/psensequantitiesplot.py
+++ b/src/arviz_plots/plots/psensequantitiesplot.py
@@ -7,8 +7,8 @@
 from arviz_stats.psense import power_scale_dataset
 from xarray import concat
 
-from arviz_plots.plot_collection import PlotCollection
-from arviz_plots.plots.utils import filter_aes, process_group_variables_coords, set_figure_layout
+from arviz_plots.plot_collection import PlotCollection, process_facet_dims
+from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
 from arviz_plots.visuals import hline, labelled_title, labelled_x, line_xy, scatter_xy, set_xticks
 
 
@@ -34,10 +34,6 @@ def plot_psense_quantities(
 ):
     """Plot power scaled posterior quantities.
 
-    The posterior sensitivity is assessed by power-scaling the prior or likelihood and
-    visualizing how quantities computed from the posterior change.
-    Pareto-smoothed importance sampling is used to avoid refitting as explained in [1]_.
-
     Parameters
     ----------
     dt : DataTree
@@ -115,11 +111,6 @@ def plot_psense_quantities(
 
     .. minigallery:: plot_psense_quantities
 
-    References
-    ----------
-    .. [1] Kallioinen et al, *Detecting and diagnosing prior and likelihood sensitivity with
-        power-scaling*, Stat Comput 34, 57 (2024), https://doi.org/10.1007/s11222-023-10366-5
-
     """
     if sample_dims is None:
         sample_dims = rcParams["data.sample_dims"]
@@ -254,10 +245,22 @@ def plot_psense_quantities(
                 if dim not in sample_dims + ["component_group", "alpha"]
             ],
         )
+        figsize = pc_kwargs["plot_grid_kws"].get("figsize", None)
+        figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches")
+        col_dims = pc_kwargs["cols"]
+        row_dims = pc_kwargs["rows"]
+        if figsize is None:
+            figsize = plot_bknd.scale_fig_size(
+                figsize,
+                rows=process_facet_dims(ds_quantities, row_dims)[0],
+                cols=process_facet_dims(ds_quantities, col_dims)[0],
+                figsize_units=figsize_units,
+            )
+            figsize_units = "dots"
+        pc_kwargs["plot_grid_kws"]["figsize"] = figsize
+        pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units
 
-        pc_kwargs = set_figure_layout(pc_kwargs, plot_bknd, ds_quantities)
-
-        plot_collection = PlotCollection.wrap(
+        plot_collection = PlotCollection.grid(
             ds_quantities,
             backend=backend,
             **pc_kwargs,
@@ -370,7 +373,7 @@ def plot_psense_quantities(
 
     # set xlabel
     _, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
-    xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
+    xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy()
     if xlabel_kwargs is not False:
         if "color" not in xlabels_aes:
             xlabel_kwargs.setdefault("color", "black")
@@ -389,14 +392,13 @@ def plot_psense_quantities(
     title_kwargs = copy(plot_kwargs.get("title", {}))
     _, _, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims)
 
-    if title_kwargs is not False:
-        plot_collection.map(
-            labelled_title,
-            "title",
-            ignore_aes=title_ignore,
-            subset_info=True,
-            labeller=labeller,
-            **title_kwargs,
-        )
+    plot_collection.map(
+        labelled_title,
+        "title",
+        ignore_aes=title_ignore,
+        subset_info=True,
+        labeller=labeller,
+        **title_kwargs,
+    )
 
     return plot_collection