Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented y normalization and updated aesthetics #160

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions src/arviz_plots/plots/distplot.py
Original file line number Diff line number Diff line change
@@ -40,7 +40,14 @@ def plot_dist(
plot_kwargs=None,
stats_kwargs=None,
pc_kwargs=None,
yrelative=None,
):
"""
Parameters:
- dt: Data to be plotted
- yrelative: Optional, relative y-values that will be scaled
"""

"""Plot 1D marginal densities in the style of John K. Kruschke’s book.
Generate :term:`faceted` :term:`plots` with: a graphical representation of 1D marginal
@@ -171,9 +178,13 @@ def plot_dist(
pc_kwargs = {}
else:
pc_kwargs = pc_kwargs.copy()

if stats_kwargs is None:
stats_kwargs = {}
if yrelative is not None:
density_max = dt.max()
y = yrelative * density_max
else:
y = None

distribution = process_group_variables_coords(
dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords
@@ -185,6 +196,10 @@ def plot_dist(
backend = plot_collection.backend
plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")

return plot_bknd.plot_dist(
dt, y=y, var_names=var_names, filter_vars=filter_vars, coords=coords, **plot_kwargs
)

if plot_collection is None:
if backend is None:
backend = rcParams["plot.backend"]
@@ -396,9 +411,15 @@ def plot_dist(
labeller=labeller,
**title_kwargs,
)
if yrelative is not None:
plot_collection.map(
lambda y: y * dt.max(), # Scale yrelative values
"y",
)

if plot_kwargs.get("remove_axis", True) is not False:
plot_collection.map(
remove_axis, store_artist=False, axis="y", ignore_aes=plot_collection.aes_set
plot_collection.map(
remove_axis, store_artist=False, axis="y", ignore_aes=plot_collection.aes_set
)

return plot_collection
8 changes: 8 additions & 0 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
@@ -186,6 +186,14 @@ def test_plot_trace_sample(self, datatree_sample, backend):
assert "plot" not in pc.viz.data_vars
assert pc.viz["mu"].trace.shape == ()

def test_plot_dist_yrelative():
data = np.array([0.1, 0.3, 0.5, 0.7])
yrelative = np.array([0.2, 0.4, 0.6, 0.8])
y_expected = yrelative * data.max()
y_computed = process_aesthetics(data, yrelative=yrelative)
assert np.allclose(y_computed, y_expected), "yrelative scaling failed!"


@pytest.mark.parametrize("compact", (True, False))
@pytest.mark.parametrize("combined", (True, False))
def test_plot_trace_dist(self, datatree, backend, compact, combined):