Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Data Generation & Plot Tests (#174) #176

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_compare
plot_convergence_dist
plot_dist
plot_energy
plot_ess
plot_ess_evolution
plot_forest
Expand Down
10 changes: 9 additions & 1 deletion src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,5 +517,13 @@ def remove_axis(target, axis="y"):


def set_y_scale(target, scale):
"""Interface to matplotlib for setting the y scale of a plot."""
"""Interface to bokeh for setting the y scale of a plot."""
target.set_yscale(scale)


def grid(target, axis, color):
"""Interface to bokeh for setting a grid in any axis."""
if axis in ["y", "both"]:
target.ygrid.grid_line_color = color
if axis in ["x", "both"]:
target.xgrid.grid_line_color = color
5 changes: 5 additions & 0 deletions src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,3 +497,8 @@ def remove_axis(target, axis="y"):
def set_y_scale(target, scale):
"""Interface to matplotlib for setting the y scale of a plot."""
target.set_yscale(scale)


def grid(target, axis, color):
"""Interface to matplotlib for setting a grid in any axis."""
target.grid(axis=axis, color=color)
9 changes: 9 additions & 0 deletions src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,12 @@ def remove_axis(target, axis="y"):
target.update_yaxes(visible=False)
if axis in ("x", "both"):
target.update_xaxes(visible=False)


def grid(target, axis="both", color=unset, **artist_kws):
"""Interface to plotly for setting a grid in any axis."""
kwargs = {"showgrid": True, "gridcolor": color}
if axis in ["y", "both"]:
target.update_yaxes(_filter_kwargs(kwargs, artist_kws))
if axis in ["x", "both"]:
target.update_xaxes(_filter_kwargs(kwargs, artist_kws))
17 changes: 17 additions & 0 deletions src/arviz_plots/plots/ppcrootogramplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
from arviz_plots.visuals import (
ci_line_y,
grid,
labelled_title,
labelled_x,
labelled_y,
Expand Down Expand Up @@ -84,6 +85,7 @@ def plot_ppc_rootogram(
* ci -> passed to :func:`~arviz_plots.visuals.ci_line_y`
* xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
* ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
* grid -> passed to :func:`~arviz_plots.visuals.grid`
* title -> passed to :func:`~arviz_plots.visuals.labelled_title`

pc_kwargs : mapping
Expand Down Expand Up @@ -263,6 +265,21 @@ def plot_ppc_rootogram(
**observed_ms_kwargs,
)

## grid
grid_kwargs = copy(plot_kwargs.get("grid", {}))

if grid_kwargs is not False:
_, _, grid_ignore = filter_aes(plot_collection, aes_map, "grid", sample_dims)
grid_kwargs.setdefault("color", "#cccccc")
grid_kwargs.setdefault("axis", "y")

plot_collection.map(
grid,
"grid",
ignore_aes=grid_ignore,
**grid_kwargs,
)

# set xlabel
_, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
Expand Down
8 changes: 7 additions & 1 deletion src/arviz_plots/visuals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,12 @@ def set_xticks(da, target, backend, values, labels, **kwargs):


def set_y_scale(da, target, backend, scale, **kwargs):
"""Dispatch to ``remove_axis`` function in backend."""
"""Set scale for y-axis."""
plot_backend = import_module(f"arviz_plots.backend.{backend}")
plot_backend.set_y_scale(target, scale, **kwargs)


def grid(da, target, backend, **kwargs):
"""Dispatch to ``remove_axis`` function in backend."""
plot_backend = import_module(f"arviz_plots.backend.{backend}")
plot_backend.grid(target, **kwargs)
49 changes: 48 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from scipy.stats import halfnorm, norm

from arviz_plots import (
plot_bf,
plot_compare,
plot_dist,
plot_energy,
plot_ess,
plot_ess_evolution,
plot_forest,
Expand Down Expand Up @@ -46,6 +48,10 @@ def generate_base_data(seed=31):
prior_predictive = rng.normal(size=(1, 100, 7))
posterior_predictive = rng.normal(size=(4, 100, 7))
diverging = rng.choice([True, False], size=(4, 100), p=[0.1, 0.9])
mu_prior_sampled = rng.normal(size=(1, 500))
tau_prior_sampled = np.exp(rng.normal(size=(1, 500)))
theta_prior_sampled = rng.normal(size=(1, 500, 8))
energy = rng.normal(loc=50, scale=10, size=(4, 100))

return {
"posterior": {"mu": mu, "theta": theta, "tau": tau},
Expand All @@ -54,7 +60,8 @@ def generate_base_data(seed=31):
"log_prior": {"mu": mu_prior, "theta": theta_prior, "tau": tau_prior},
"prior_predictive": {"y": prior_predictive},
"posterior_predictive": {"y": posterior_predictive},
"sample_stats": {"diverging": diverging},
"sample_stats": {"diverging": diverging, "energy": energy},
"prior": {"mu": mu_prior_sampled, "theta": theta_prior_sampled, "tau": tau_prior_sampled},
}


Expand Down Expand Up @@ -461,3 +468,43 @@ def test_plot_psense_dist_sample(self, datatree_sample, backend):
assert "component_group" in pc.viz["mu"]["credible_interval"].dims
assert "alpha" in pc.viz["mu"]["credible_interval"].dims
assert "hierarchy" in pc.viz["theta"].dims

def test_plot_bf(self, datatree, backend):
pc = plot_bf(datatree, var_name="mu", backend=backend)
assert "chart" in pc.viz.data_vars
assert "plot" in pc.viz.data_vars
assert "row" in pc.viz.data_vars
assert "col" in pc.viz.data_vars
assert "Groups" in pc.viz["mu"].coords
assert "BF_type" in pc.aes
assert "BF10" in pc.aes["BF_type"].values[0]

def test_plot_energy_dist(self, datatree, backend):
pc = plot_energy(datatree, backend=backend)
assert pc is not None
assert hasattr(pc, "viz")
assert "/energy_" in pc.viz.groups
assert "kde" in pc.viz["/energy_"]
assert "energy" in pc.viz["/energy_"].coords
kde_values = pc.viz["/energy_"]["kde"].values
assert kde_values.size > 0
assert "component_group" not in pc.viz["/energy_"]["kde"].dims
assert "alpha" not in pc.viz["/energy_"]["kde"].dims
energy_coords = pc.viz["/energy_"]["kde"].coords["energy"].values
assert "marginal" in energy_coords
assert "transition" in energy_coords

def test_plot_energy_dist_sample(self, datatree_sample, backend):
pc = plot_energy(datatree_sample, backend=backend)
assert pc is not None
assert hasattr(pc, "viz")
assert "/energy_" in pc.viz.groups
assert "kde" in pc.viz["/energy_"]
assert "energy" in pc.viz["/energy_"].coords
kde_values = pc.viz["/energy_"]["kde"].values
assert kde_values.size > 0
assert "component_group" not in pc.viz["/energy_"]["kde"].dims
assert "alpha" not in pc.viz["/energy_"]["kde"].dims
energy_coords = pc.viz["/energy_"]["kde"].coords["energy"].values
assert "marginal" in energy_coords
assert "transition" in energy_coords