Skip to content

Commit 72af613

Browse files
committed
better support for aesthetics in labels elements
1 parent 280ca42 commit 72af613

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

src/arviz_plots/plots/forestplot.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ def plot_forest(
353353
)
354354

355355
# add labels and shading first, so forest plot is rendered on top
356-
_, lab_aes, lab_ignore = filter_aes(plot_collection, aes_map, "labels", sample_dims)
357356
cumulative_label = []
358357
x = 0
359358
for label in labellable_dims:
@@ -366,13 +365,6 @@ def plot_forest(
366365
shade_extend = 0.5 if add_factor == 0 else 0.3
367366
if label not in labels:
368367
continue
369-
lab_kwargs = plot_kwargs.get("labels", {}).copy()
370-
if "color" not in lab_aes:
371-
lab_kwargs.setdefault("color", "black")
372-
if x == 0:
373-
lab_kwargs.setdefault("horizontal_align", "left")
374-
if x == len(labels) - 1:
375-
lab_kwargs.setdefault("horizontal_align", "right")
376368
if label == "__variable__":
377369
y_max = y_ds.max() + shade_extend
378370
y_min = y_ds.min() - shade_extend
@@ -413,6 +405,23 @@ def plot_forest(
413405
ignore_aes=shade_ignore,
414406
**shade_kwargs,
415407
)
408+
_, lab_aes, lab_ignore = filter_aes(plot_collection, aes_map, "labels", sample_dims)
409+
extra_ignore_aes = []
410+
for aes_key in lab_aes:
411+
if aes_key == "overlay":
412+
continue
413+
aes_ds = plot_collection.get_aes_as_dataset(aes_key)
414+
if set(aes_ds.dims).difference(cumulative_label):
415+
extra_ignore_aes.append(aes_key)
416+
lab_aes = set(lab_aes).difference(extra_ignore_aes)
417+
lab_ignore = set(lab_ignore).union(extra_ignore_aes)
418+
lab_kwargs = plot_kwargs.get("labels", {}).copy()
419+
if "color" not in lab_aes:
420+
lab_kwargs.setdefault("color", "black")
421+
if x == 0:
422+
lab_kwargs.setdefault("horizontal_align", "left")
423+
if x == len(labels) - 1:
424+
lab_kwargs.setdefault("horizontal_align", "right")
416425
plot_collection.map(
417426
annotate_label,
418427
f"{label.strip('_')}_label",

tests/test_plots.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""Test batteries-included plots."""
33
import numpy as np
44
import pytest
5-
from arviz_base import from_dict
5+
from arviz_base import from_dict, load_arviz_data
66

77
from arviz_plots import plot_dist, plot_forest, plot_trace, visuals
88

@@ -120,13 +120,15 @@ def test_plot_forest_extendable(self, datatree, backend):
120120
assert pc.viz["plot"].sizes["column"] == 3
121121
assert all("ess" in child.data_vars for child in pc.viz.children.values())
122122

123-
def test_plot_forest_color_shading(self, datatree2, backend):
124-
pc = plot_forest(
125-
datatree2,
126-
pc_kwargs={"aes": {"color": ["__variable__"]}},
127-
aes_map={"labels": ["color"]},
128-
shade_label="hierarchy",
129-
backend=backend,
130-
)
131-
assert "plot" in pc.viz.data_vars
132-
assert all("shade" in child.data_vars for child in pc.viz.children.values())
123+
def test_plot_forest_aes_labels_shading(self, backend):
124+
post = load_arviz_data("rugby_field").posterior.ds.sel(draw=slice(None, 100))
125+
for pseudo_dim in ("__variable__", "field", "team"):
126+
pc = plot_forest(
127+
post,
128+
pc_kwargs={"aes": {"color": [pseudo_dim]}},
129+
aes_map={"labels": ["color"]},
130+
shade_label=pseudo_dim,
131+
backend=backend,
132+
)
133+
assert "plot" in pc.viz.data_vars
134+
assert all("shade" in child.data_vars for child in pc.viz.children.values())

0 commit comments

Comments
 (0)