Skip to content

Commit 31fdf6a

Browse files
committed
add classmethod tests
1 parent 6948668 commit 31fdf6a

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

tests/test_plot_collection.py

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# pylint: disable=no-self-use, redefined-outer-name
2+
"""Test PlotCollection."""
3+
import numpy as np
4+
import pytest
5+
from arviz_base import dict_to_dataset
6+
7+
from arviz_plots import PlotCollection
8+
9+
10+
@pytest.fixture(scope="module")
11+
def dataset(seed=31):
12+
rng = np.random.default_rng(seed)
13+
mu = rng.normal(size=(3, 10))
14+
theta = rng.normal(size=(3, 10, 7))
15+
eta = rng.normal(size=(3, 10, 4, 7))
16+
17+
return dict_to_dataset(
18+
{"mu": mu, "theta": theta, "eta": eta},
19+
dims={"theta": ["hierarchy"], "eta": ["group", "hierarchy"]},
20+
)
21+
22+
23+
@pytest.mark.parametrize("backend", ["matplotlib", "bokeh"])
24+
class TestFacetting:
25+
def test_wrap(self, dataset, backend):
26+
pc = PlotCollection.wrap(
27+
dataset[["theta", "eta"]], backend=backend, cols=["hierarchy"], col_wrap=4
28+
)
29+
assert "plot" in pc.viz.data_vars
30+
assert pc.viz["plot"].shape == (7,)
31+
assert pc.viz["row"].max() == 1
32+
assert pc.viz["col"].max() == 3
33+
34+
def test_wrap_variable(self, dataset, backend):
35+
pc = PlotCollection.wrap(dataset, backend=backend, cols=["__variable__", "group"])
36+
assert "plot" not in pc.viz.data_vars
37+
assert all(f"/{var_name}" in pc.viz.groups for var_name in ("mu", "theta", "eta"))
38+
assert all("plot" in pc.viz[var_name].data_vars for var_name in ("mu", "theta", "eta"))
39+
assert pc.viz["mu"]["plot"].size == 1
40+
assert pc.viz["theta"]["plot"].size == 1
41+
assert pc.viz["eta"]["plot"].size == 4
42+
43+
def test_wrap_only_variable(self, dataset, backend):
44+
pc = PlotCollection.wrap(dataset, backend=backend, cols=["__variable__"])
45+
assert "plot" not in pc.viz.data_vars
46+
assert all(f"/{var_name}" in pc.viz.groups for var_name in ("mu", "theta", "eta"))
47+
assert all("plot" in pc.viz[var_name].data_vars for var_name in ("mu", "theta", "eta"))
48+
assert pc.viz["mu"]["plot"].size == 1
49+
assert pc.viz["theta"]["plot"].size == 1
50+
assert pc.viz["eta"]["plot"].size == 1
51+
52+
def test_grid(self, dataset, backend):
53+
pc = PlotCollection.grid(
54+
dataset[["theta", "eta"]], backend=backend, cols=["chain"], rows=["hierarchy"]
55+
)
56+
assert "plot" in pc.viz.data_vars
57+
assert not pc.viz.children
58+
assert "chain" in pc.viz["plot"].dims
59+
assert pc.viz["plot"].sizes["chain"] == 3
60+
assert "hierarchy" in pc.viz["plot"].dims
61+
assert pc.viz["plot"].sizes["hierarchy"] == 7
62+
assert "group" not in pc.viz["plot"].dims
63+
64+
def test_grid_scalar(self, dataset, backend):
65+
pc = PlotCollection.grid(dataset, backend=backend)
66+
assert "plot" in pc.viz.data_vars
67+
assert not pc.viz.children
68+
assert pc.viz["plot"].size == 1
69+
70+
@pytest.mark.parametrize("axis", ["rows", "cols"])
71+
def test_grid_rows_cols(self, dataset, backend, axis):
72+
pc = PlotCollection.grid(dataset[["theta", "eta"]], backend=backend, **{axis: ["chain"]})
73+
assert "plot" in pc.viz.data_vars
74+
assert not pc.viz.children
75+
assert "chain" in pc.viz["plot"].dims
76+
assert pc.viz["plot"].sizes["chain"] == 3
77+
assert "hierarchy" not in pc.viz["plot"].dims
78+
assert "group" not in pc.viz["plot"].dims
79+
assert pc.viz["row" if axis == "cols" else "col"].max() == 0
80+
assert pc.viz[axis[:3]].max() == 2
81+
82+
def test_grid_variable(self, dataset, backend):
83+
pc = PlotCollection.grid(
84+
dataset[["theta", "eta"]], backend=backend, cols=["hierarchy"], rows=["__variable__"]
85+
)
86+
assert "plot" not in pc.viz.data_vars
87+
assert all(f"/{var_name}" in pc.viz.groups for var_name in ("theta", "eta"))
88+
assert all("plot" in pc.viz[var_name].data_vars for var_name in ("theta", "eta"))
89+
90+
91+
class TestAesthetics:
92+
def test_1d_aes(self, dataset):
93+
pc = PlotCollection.grid(dataset)
94+
assert pc
95+
96+
97+
class TestMap:
98+
def test_map(self, dataset):
99+
pc = PlotCollection.grid(dataset)
100+
assert pc

0 commit comments

Comments
 (0)