Skip to content

Commit 7f9a025

Browse files
committed
add plot_forest draft
For everything to work, the following features were added to: * support for "__variable__" in aesthetics * default values for common aesthetics * support for dict of datatree input -> multiple models aligned within the same plot
1 parent 5a0cc94 commit 7f9a025

File tree

10 files changed

+469
-41
lines changed

10 files changed

+469
-41
lines changed

docs/source/api/plots.rst

+1
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ A complementary introduction and guide to ``plot_...`` functions is available at
1616
:toctree: generated/
1717

1818
plot_dist
19+
plot_forest
1920
plot_trace
2021
plot_trace_dist

src/arviz_plots/backend/__init__.py

+17
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
to any type of the plotting backend or even custom objects, but all instances
1313
of the same placeholder must use the same type (whatever that is).
1414
"""
15+
import numpy as np
1516

1617
error = NotImplementedError(
1718
"The `arviz_plots.backend` module itself is for reference only. "
@@ -20,6 +21,22 @@
2021
)
2122

2223

24+
# generation of default values for aesthetics
25+
def get_default_aes(aes_key, n, kwargs):
26+
"""Generate `n` default values for a given aesthetics keyword."""
27+
if aes_key not in kwargs:
28+
if aes_key in {"x", "y"}:
29+
return np.arange(n)
30+
if aes_key == "alpha":
31+
return np.linspace(0.2, 0.7, n)
32+
return [None] * n
33+
aes_vals = kwargs[aes_key]
34+
n_aes_vals = len(aes_vals)
35+
if n_aes_vals >= n:
36+
return aes_vals[:n]
37+
return np.tile(aes_vals, (n // n_aes_vals) + 1)[:n]
38+
39+
2340
# object creation and i/o
2441
def show(chart):
2542
"""Show this :term:`chart`.

src/arviz_plots/backend/bokeh/__init__.py

+22
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from bokeh.plotting import figure
88
from bokeh.plotting import show as _show
99

10+
from .. import get_default_aes as get_agnostic_default_aes
1011
from .legend import legend
1112

1213
__all__ = [
@@ -31,6 +32,27 @@ class UnsetDefault:
3132
unset = UnsetDefault()
3233

3334

35+
# generation of default values for aesthetics
36+
def get_default_aes(aes_key, n, kwargs):
37+
"""Generate `n` *bokeh valid* default values for a given aesthetics keyword."""
38+
if aes_key not in kwargs:
39+
if "color" in aes_key:
40+
# fmt: off
41+
vals = [
42+
'#3f90da', '#ffa90e', '#bd1f01', '#94a4a2', '#832db6',
43+
'#a96b59', '#e76300', '#b9ac70', '#717581', '#92dadd'
44+
]
45+
# fmt: on
46+
elif aes_key in {"linestyle", "line_dash"}:
47+
vals = ["solid", "dashed", "dotted", "dashdot"]
48+
elif aes_key == "marker":
49+
vals = ["circle", "cross", "triangle", "x", "diamond"]
50+
else:
51+
return get_agnostic_default_aes(aes_key, n, {})
52+
return get_agnostic_default_aes(aes_key, n, {aes_key: vals})
53+
return get_agnostic_default_aes(aes_key, n, kwargs)
54+
55+
3456
# object creation and i/o
3557
def show(chart):
3658
"""Show the provided bokeh layout."""

src/arviz_plots/backend/matplotlib/__init__.py

+29
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
from matplotlib.cbook import normalize_kwargs
1212
from matplotlib.collections import PathCollection
1313
from matplotlib.lines import Line2D
14+
from matplotlib.pyplot import rcParams
1415
from matplotlib.pyplot import show as _show
1516
from matplotlib.pyplot import subplots
1617
from matplotlib.text import Text
1718

19+
from .. import get_default_aes as get_agnostic_default_aes
1820
from .legend import legend
1921

2022
__all__ = [
@@ -39,6 +41,33 @@ class UnsetDefault:
3941
unset = UnsetDefault()
4042

4143

44+
# generation of default values for aesthetics
45+
def get_default_aes(aes_key, n, kwargs):
46+
"""Generate `n` *bokeh valid* default values for a given aesthetics keyword."""
47+
if aes_key not in kwargs:
48+
default_prop_cycle = rcParams["axes.prop_cycle"].by_key()
49+
if ("color" in aes_key) or aes_key == "c":
50+
# fmt: off
51+
vals = [
52+
'#3f90da', '#ffa90e', '#bd1f01', '#94a4a2', '#832db6',
53+
'#a96b59', '#e76300', '#b9ac70', '#717581', '#92dadd'
54+
]
55+
# fmt: on
56+
vals = default_prop_cycle.get("color", vals)
57+
elif aes_key in {"linestyle", "ls"}:
58+
vals = ["-", "--", ":", "-."]
59+
vals = default_prop_cycle.get("linestyle", vals)
60+
elif aes_key in {"marker", "m"}:
61+
vals = ["o", "+", "^", "x", "d"]
62+
vals = default_prop_cycle.get("marker", vals)
63+
elif aes_key in default_prop_cycle:
64+
vals = default_prop_cycle[aes_key]
65+
else:
66+
return get_agnostic_default_aes(aes_key, n, {})
67+
return get_agnostic_default_aes(aes_key, n, {aes_key: vals})
68+
return get_agnostic_default_aes(aes_key, n, kwargs)
69+
70+
4271
# object creation and i/o
4372
def show(chart): # pylint: disable=unused-argument
4473
"""Show all existing matplotlib figures."""

src/arviz_plots/plot_collection.py

+78-28
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
from datatree import DataTree
99

1010

11+
def concat_model_dict(data):
12+
"""Merge multiple Datasets into a single one along a new model dimension."""
13+
if isinstance(data, dict):
14+
ds_list = data.values()
15+
if not all(isinstance(ds, xr.Dataset) for ds in ds_list):
16+
raise TypeError("Provided data must be a Dataset or dictionary of Datasets")
17+
data = xr.concat(ds_list, dim="model").assign_coords(model=list(data))
18+
return data
19+
20+
1121
def sel_subset(sel, present_dims):
1222
"""Subset a dictionary of dim: coord values.
1323
@@ -97,8 +107,15 @@ def _get_aes_dict_from_dt(aes_dt):
97107
an aes DataTree directly when initializating a PlotCollection object.
98108
This method is used to generate the more basic dictionary from the DataTree.
99109
"""
110+
child_list = list(aes_dt.children.values())
100111
aes = {}
101-
for ds in aes_dt.children.values():
112+
aes_in_all_vars = set.intersection(*[set(child.data_vars) for child in child_list])
113+
aes = {
114+
aes_key: ["__variable__"]
115+
for aes_key in aes_in_all_vars
116+
if any(child[aes_key].item(0) != child_list[0][aes_key].item(0) for child in child_list)
117+
}
118+
for ds in child_list:
102119
for aes_key, values in ds.items():
103120
if not values.dims:
104121
continue
@@ -211,7 +228,7 @@ def data(self):
211228
@data.setter
212229
def data(self, value):
213230
# might want/be possible to make some checks on the data before setting it
214-
self._data = value
231+
self._data = concat_model_dict(value)
215232

216233
@property
217234
def aes_set(self):
@@ -301,31 +318,66 @@ def generate_aes_dt(self, aes=None, **kwargs):
301318
but it will always be possible to set their value manually.
302319
"""
303320
if aes is None:
304-
aes = {}
321+
aes = self._aes
322+
kwargs = self._kwargs
305323
self._aes = aes
306324
self._kwargs = kwargs
307-
self._aes_dt = DataTree()
308-
for var_name, da in self.data.items():
309-
ds = xr.Dataset()
310-
for aes_key, dims in aes.items():
311-
aes_vals = kwargs.get(aes_key, [None])
312-
aes_dims = [dim for dim in dims if dim in da.dims]
313-
aes_raw_shape = [da.sizes[dim] for dim in aes_dims]
314-
if not aes_raw_shape:
315-
ds[aes_key] = aes_vals[0]
316-
continue
317-
n_aes = np.prod(aes_raw_shape)
318-
n_aes_vals = len(aes_vals)
319-
if n_aes_vals > n_aes:
320-
aes_vals = aes_vals[:n_aes]
321-
elif n_aes_vals < n_aes:
322-
aes_vals = np.tile(aes_vals, (n_aes // n_aes_vals) + 1)[:n_aes]
323-
ds[aes_key] = xr.DataArray(
324-
np.array(aes_vals).reshape(aes_raw_shape),
325-
dims=aes_dims,
326-
coords={dim: da.coords[dim] for dim in dims if dim in da.coords},
325+
if not hasattr(self, "backend"):
326+
plot_bknd = import_module(".backend", package="arviz_plots")
327+
else:
328+
plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")
329+
get_default_aes = plot_bknd.get_default_aes
330+
ds_dict = {var_name: xr.Dataset() for var_name in self.data.data_vars}
331+
for aes_key, dims in aes.items():
332+
if "__variable__" in dims:
333+
total_aes_vals = int(
334+
np.sum(
335+
[
336+
np.prod([size for dim, size in da.sizes.items() if dim in dims])
337+
for da in self.data.values()
338+
]
339+
)
340+
)
341+
aes_vals = get_default_aes(aes_key, total_aes_vals, kwargs)
342+
aes_cumulative = 0
343+
for var_name, da in self.data.items():
344+
ds = ds_dict[var_name]
345+
aes_dims = [dim for dim in dims if dim in da.dims]
346+
aes_raw_shape = [da.sizes[dim] for dim in aes_dims]
347+
if not aes_raw_shape:
348+
ds[aes_key] = np.asarray(aes_vals)[
349+
aes_cumulative : aes_cumulative + 1
350+
].squeeze()
351+
aes_cumulative += 1
352+
continue
353+
n_aes = np.prod(aes_raw_shape)
354+
ds[aes_key] = xr.DataArray(
355+
np.array(aes_vals[aes_cumulative : aes_cumulative + n_aes]).reshape(
356+
aes_raw_shape
357+
),
358+
dims=aes_dims,
359+
coords={dim: da.coords[dim] for dim in dims if dim in da.coords},
360+
)
361+
aes_cumulative += n_aes
362+
else:
363+
total_aes_vals = int(
364+
np.prod([self.data.sizes[dim] for dim in self.data.dims if dim in dims])
327365
)
328-
DataTree(name=var_name, parent=self._aes_dt, data=ds)
366+
aes_vals = get_default_aes(aes_key, total_aes_vals, kwargs)
367+
for var_name, da in self.data.items():
368+
ds = ds_dict[var_name]
369+
aes_dims = [dim for dim in dims if dim in da.dims]
370+
aes_raw_shape = [da.sizes[dim] for dim in aes_dims]
371+
if not aes_raw_shape:
372+
ds[aes_key] = aes_vals[0]
373+
continue
374+
n_aes = np.prod(aes_raw_shape)
375+
ds[aes_key] = xr.DataArray(
376+
np.array(aes_vals[:n_aes]).reshape(aes_raw_shape),
377+
dims=aes_dims,
378+
coords={dim: da.coords[dim] for dim in dims if dim in da.coords},
379+
)
380+
self._aes_dt = DataTree.from_dict(ds_dict)
329381

330382
@property
331383
def base_loop_dims(self):
@@ -385,8 +437,7 @@ def wrap(
385437
plot_grid_kws = {}
386438
if backend is None:
387439
backend = rcParams["plot.backend"]
388-
if isinstance(data, dict):
389-
data = xr.concat(data.values(), dim="model").assign_coords(model=list(data))
440+
data = concat_model_dict(data)
390441

391442
n_plots, plots_per_var = _process_facet_dims(data, cols)
392443
if n_plots <= col_wrap:
@@ -501,8 +552,7 @@ def grid(
501552
repeated_dims = [col for col in cols if col in rows]
502553
if repeated_dims:
503554
raise ValueError("The same dimension can't be used for both cols and rows.")
504-
if isinstance(data, dict):
505-
data = xr.concat(data.values(), dim="model").assign_coords(model=list(data))
555+
data = concat_model_dict(data)
506556

507557
n_cols, cols_per_var = _process_facet_dims(data, cols)
508558
n_rows, rows_per_var = _process_facet_dims(data, rows)

src/arviz_plots/plots/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Batteries-included ArviZ plots."""
22

33
from .distplot import plot_dist
4+
from .forestplot import plot_forest
45
from .tracedistplot import plot_trace_dist
56
from .traceplot import plot_trace
67

7-
__all__ = ["plot_dist", "plot_trace", "plot_trace_dist"]
8+
__all__ = ["plot_dist", "plot_forest", "plot_trace", "plot_trace_dist"]

src/arviz_plots/plots/distplot.py

+37-11
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
import xarray as xr
44
from arviz_base import rcParams
55
from arviz_base.labels import BaseLabeller
6-
from arviz_base.utils import _var_names
76

87
from arviz_plots.plot_collection import PlotCollection
9-
from arviz_plots.plots.utils import filter_aes
8+
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
109
from arviz_plots.visuals import (
1110
ecdf_line,
1211
labelled_title,
@@ -23,6 +22,7 @@ def plot_dist(
2322
var_names=None,
2423
filter_vars=None,
2524
group="posterior",
25+
coords=None,
2626
sample_dims=None,
2727
kind=None,
2828
point_estimate=None,
@@ -46,8 +46,9 @@ def plot_dist(
4646
4747
Parameters
4848
----------
49-
dt : DataTree
50-
Input data
49+
dt : DataTree or dict of {str : DataTree}
50+
Input data. In case of dictionary input, the keys are taken to be model names.
51+
In such cases, a dimension "model" is generated and can be used to map to aesthetics.
5152
var_names: str or list of str, optional
5253
One or more variables to be plotted.
5354
Prefix the variables by ~ when you want to exclude them from the plot.
@@ -57,6 +58,7 @@ def plot_dist(
5758
If “regex”, interpret var_names as regular expressions on the real variables names.
5859
group : str, default "posterior"
5960
Group to be plotted.
61+
coords : dict, optional
6062
sample_dims : iterable, optional
6163
Dimensions to reduce unless mapped to an aesthetic.
6264
Defaults to ``rcParams["data.sample_dims"]``
@@ -106,6 +108,20 @@ def plot_dist(
106108
Returns
107109
-------
108110
PlotCollection
111+
112+
Examples
113+
--------
114+
Default plot_dist for a single model:
115+
116+
.. plot::
117+
:context: close-figs
118+
119+
>>> from arviz_plots import plot_dist
120+
>>> from arviz_base import load_arviz_data
121+
>>> centered = load_arviz_data('centered_eight')
122+
>>> non_centered = load_arviz_data('non_centered_eight')
123+
>>> pc = plot_dist(centered)
124+
109125
"""
110126
if ci_kind not in ["hdi", "eti", None]:
111127
raise ValueError("ci_kind must be either 'hdi' or 'eti'")
@@ -130,27 +146,37 @@ def plot_dist(
130146
if stats_kwargs is None:
131147
stats_kwargs = {}
132148

133-
distribution = dt[group].ds
134-
var_names = _var_names(var_names, distribution, filter_vars)
135-
136-
if var_names is not None:
137-
distribution = dt[group].ds[var_names]
149+
distribution = process_group_variables_coords(
150+
dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
151+
)
138152

139153
if plot_collection is None:
140154
if backend is None:
141155
backend = rcParams["plot.backend"]
142156
pc_kwargs.setdefault("col_wrap", 5)
143157
pc_kwargs.setdefault(
144-
"cols", ["__variable__"] + [dim for dim in distribution.dims if dim not in sample_dims]
158+
"cols",
159+
["__variable__"]
160+
+ [dim for dim in distribution.dims if (dim not in sample_dims) and (dim != "model")],
145161
)
162+
if "model" in distribution:
163+
pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
164+
pc_kwargs["aes"].setdefault("color", ["model"])
146165
plot_collection = PlotCollection.wrap(
147166
distribution,
148167
backend=backend,
149168
**pc_kwargs,
150169
)
151170

152171
if aes_map is None:
153-
aes_map = {kind: plot_collection.aes_set}
172+
if "model" in distribution:
173+
aes_map = {
174+
kind: plot_collection.aes_set,
175+
"credible_interval": ["color"],
176+
"point_estimate": ["color"],
177+
}
178+
else:
179+
aes_map = {kind: plot_collection.aes_set}
154180
if "point_estimate" in aes_map and "point_estimate_text" not in aes_map:
155181
aes_map["point_estimate_text"] = aes_map["point_estimate"]
156182
if labeller is None:

0 commit comments

Comments
 (0)