diff --git a/docs/source/api/plots.rst b/docs/source/api/plots.rst index a17609e3..476e8e52 100644 --- a/docs/source/api/plots.rst +++ b/docs/source/api/plots.rst @@ -20,6 +20,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at plot_compare plot_convergence_dist plot_dist + plot_energy plot_ess plot_ess_evolution plot_forest diff --git a/src/arviz_plots/backend/bokeh/__init__.py b/src/arviz_plots/backend/bokeh/__init__.py index 279097a8..7452c317 100644 --- a/src/arviz_plots/backend/bokeh/__init__.py +++ b/src/arviz_plots/backend/bokeh/__init__.py @@ -517,5 +517,13 @@ def remove_axis(target, axis="y"): def set_y_scale(target, scale): - """Interface to matplotlib for setting the y scale of a plot.""" + """Interface to bokeh for setting the y scale of a plot.""" target.set_yscale(scale) + + +def grid(target, axis, color): + """Interface to bokeh for setting a grid in any axis.""" + if axis in ["y", "both"]: + target.ygrid.grid_line_color = color + if axis in ["x", "both"]: + target.xgrid.grid_line_color = color diff --git a/src/arviz_plots/backend/matplotlib/__init__.py b/src/arviz_plots/backend/matplotlib/__init__.py index c8dc6e5e..f539b888 100644 --- a/src/arviz_plots/backend/matplotlib/__init__.py +++ b/src/arviz_plots/backend/matplotlib/__init__.py @@ -497,3 +497,8 @@ def remove_axis(target, axis="y"): def set_y_scale(target, scale): """Interface to matplotlib for setting the y scale of a plot.""" target.set_yscale(scale) + + +def grid(target, axis, color): + """Interface to matplotlib for setting a grid in any axis.""" + target.grid(axis=axis, color=color) diff --git a/src/arviz_plots/backend/plotly/__init__.py b/src/arviz_plots/backend/plotly/__init__.py index a3f06bca..548c0ba1 100644 --- a/src/arviz_plots/backend/plotly/__init__.py +++ b/src/arviz_plots/backend/plotly/__init__.py @@ -595,3 +595,12 @@ def remove_axis(target, axis="y"): target.update_yaxes(visible=False) if axis in ("x", "both"): target.update_xaxes(visible=False) + + +def grid(target, axis="both", color=unset, **artist_kws): + """Interface to plotly for setting a grid in any axis.""" + kwargs = {"showgrid": True, "gridcolor": color} + if axis in ["y", "both"]: + target.update_yaxes(_filter_kwargs(kwargs, artist_kws)) + if axis in ["x", "both"]: + target.update_xaxes(_filter_kwargs(kwargs, artist_kws)) diff --git a/src/arviz_plots/plots/ppcrootogramplot.py b/src/arviz_plots/plots/ppcrootogramplot.py index 2b87a0bb..5d7ad7cf 100644 --- a/src/arviz_plots/plots/ppcrootogramplot.py +++ b/src/arviz_plots/plots/ppcrootogramplot.py @@ -10,6 +10,7 @@ from arviz_plots.plots.utils import filter_aes, process_group_variables_coords from arviz_plots.visuals import ( ci_line_y, + grid, labelled_title, labelled_x, labelled_y, @@ -84,6 +85,7 @@ def plot_ppc_rootogram( * ci -> passed to :func:`~arviz_plots.visuals.ci_line_y` * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x` * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y` + * grid -> passed to :func:`~arviz_plots.visuals.grid` * title -> passed to :func:`~arviz_plots.visuals.labelled_title` pc_kwargs : mapping @@ -263,6 +265,21 @@ def plot_ppc_rootogram( **observed_ms_kwargs, ) + ## grid + grid_kwargs = copy(plot_kwargs.get("grid", {})) + + if grid_kwargs is not False: + _, _, grid_ignore = filter_aes(plot_collection, aes_map, "grid", sample_dims) + grid_kwargs.setdefault("color", "#cccccc") + grid_kwargs.setdefault("axis", "y") + + plot_collection.map( + grid, + "grid", + ignore_aes=grid_ignore, + **grid_kwargs, + ) + # set xlabel _, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims) xlabel_kwargs = copy(plot_kwargs.get("xlabel", {})) diff --git a/src/arviz_plots/visuals/__init__.py b/src/arviz_plots/visuals/__init__.py index 7435df7c..ab3c4447 100644 --- a/src/arviz_plots/visuals/__init__.py +++ b/src/arviz_plots/visuals/__init__.py @@ -304,6 +304,12 @@ def set_xticks(da, target, backend, values, labels, **kwargs): def set_y_scale(da, target, backend, scale, **kwargs): - """Dispatch to ``remove_axis`` function in backend.""" + """Set scale for y-axis.""" plot_backend = import_module(f"arviz_plots.backend.{backend}") plot_backend.set_y_scale(target, scale, **kwargs) + + +def grid(da, target, backend, **kwargs): + """Dispatch to ``remove_axis`` function in backend.""" + plot_backend = import_module(f"arviz_plots.backend.{backend}") + plot_backend.grid(target, **kwargs) diff --git a/tests/test_plots.py b/tests/test_plots.py index e06a6bfd..bba1274e 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -7,8 +7,10 @@ from scipy.stats import halfnorm, norm from arviz_plots import ( + plot_bf, plot_compare, plot_dist, + plot_energy, plot_ess, plot_ess_evolution, plot_forest, @@ -46,6 +48,10 @@ def generate_base_data(seed=31): prior_predictive = rng.normal(size=(1, 100, 7)) posterior_predictive = rng.normal(size=(4, 100, 7)) diverging = rng.choice([True, False], size=(4, 100), p=[0.1, 0.9]) + mu_prior_sampled = rng.normal(size=(1, 500)) + tau_prior_sampled = np.exp(rng.normal(size=(1, 500))) + theta_prior_sampled = rng.normal(size=(1, 500, 8)) + energy = rng.normal(loc=50, scale=10, size=(4, 100)) return { "posterior": {"mu": mu, "theta": theta, "tau": tau}, @@ -54,7 +60,8 @@ def generate_base_data(seed=31): "log_prior": {"mu": mu_prior, "theta": theta_prior, "tau": tau_prior}, "prior_predictive": {"y": prior_predictive}, "posterior_predictive": {"y": posterior_predictive}, - "sample_stats": {"diverging": diverging}, + "sample_stats": {"diverging": diverging, "energy": energy}, + "prior": {"mu": mu_prior_sampled, "theta": theta_prior_sampled, "tau": tau_prior_sampled}, } @@ -461,3 +468,43 @@ def test_plot_psense_dist_sample(self, datatree_sample, backend): assert "component_group" in pc.viz["mu"]["credible_interval"].dims assert "alpha" in pc.viz["mu"]["credible_interval"].dims assert "hierarchy" in pc.viz["theta"].dims + + def test_plot_bf(self, datatree, backend): + pc = plot_bf(datatree, var_name="mu", backend=backend) + assert "chart" in pc.viz.data_vars + assert "plot" in pc.viz.data_vars + assert "row" in pc.viz.data_vars + assert "col" in pc.viz.data_vars + assert "Groups" in pc.viz["mu"].coords + assert "BF_type" in pc.aes + assert "BF10" in pc.aes["BF_type"].values[0] + + def test_plot_energy_dist(self, datatree, backend): + pc = plot_energy(datatree, backend=backend) + assert pc is not None + assert hasattr(pc, "viz") + assert "/energy_" in pc.viz.groups + assert "kde" in pc.viz["/energy_"] + assert "energy" in pc.viz["/energy_"].coords + kde_values = pc.viz["/energy_"]["kde"].values + assert kde_values.size > 0 + assert "component_group" not in pc.viz["/energy_"]["kde"].dims + assert "alpha" not in pc.viz["/energy_"]["kde"].dims + energy_coords = pc.viz["/energy_"]["kde"].coords["energy"].values + assert "marginal" in energy_coords + assert "transition" in energy_coords + + def test_plot_energy_dist_sample(self, datatree_sample, backend): + pc = plot_energy(datatree_sample, backend=backend) + assert pc is not None + assert hasattr(pc, "viz") + assert "/energy_" in pc.viz.groups + assert "kde" in pc.viz["/energy_"] + assert "energy" in pc.viz["/energy_"].coords + kde_values = pc.viz["/energy_"]["kde"].values + assert kde_values.size > 0 + assert "component_group" not in pc.viz["/energy_"]["kde"].dims + assert "alpha" not in pc.viz["/energy_"]["kde"].dims + energy_coords = pc.viz["/energy_"]["kde"].coords["energy"].values + assert "marginal" in energy_coords + assert "transition" in energy_coords