Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions pymc_marketing/mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@
"""

import itertools
from collections import namedtuple
from collections.abc import Iterable

import arviz as az
Expand Down Expand Up @@ -243,6 +244,83 @@ def _build_subplot_title(
return ", ".join(title_parts)
return fallback_title

def _align_y_axes(self, ax_left, ax_right) -> None:
"""Align y=0 of primary and secondary y-axis."""
# Store limits of both axes in named tuples.
YLimits = namedtuple("YLimits", ["bottom", "top"])
ylims_left = YLimits(*ax_left.axes.get_ylim())
ylims_right = YLimits(*ax_right.axes.get_ylim())

# Calculate the relative position of zero on both axes.
# 0 means all values are positive (zero is at the bottom)
# 1 means all values are negative (zero is at the top)
zero_rel_pos_left = -ylims_left.bottom / (ylims_left.top - ylims_left.bottom)
zero_rel_pos_right = -ylims_right.bottom / (
ylims_right.top - ylims_right.bottom
)

# If relative positions are equal, no action needed
if zero_rel_pos_left == zero_rel_pos_right:
return

# If both axes include mixed values, edit one to match the other by solving
# rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom.
if zero_rel_pos_left not in [0, 1] and zero_rel_pos_right not in [0, 1]:
if zero_rel_pos_left < zero_rel_pos_right:
ax_left.set_ylim(
bottom=-zero_rel_pos_right
* ylims_left.top
/ (1 - zero_rel_pos_right)
)
else:
ax_right.set_ylim(
bottom=-zero_rel_pos_left
* ylims_right.top
/ (1 - zero_rel_pos_left)
)

# If one relative position is 1, edit the top by solving
# rel_pos_other = (0 - bottom) / (new_top - bottom) for new_top.
if zero_rel_pos_left == 1:
# Left axis is all negative, right axis has positive values
# if other axis is fully positive, place y=0 at the center of this axis.
ax_left.set_ylim(
top=ylims_left.bottom * (1 - 1 / zero_rel_pos_right)
if zero_rel_pos_right
else -ylims_left.bottom
)
# Update lims and zero_rel_pos in case we need to edit the other axis.
ylims_left = YLimits(*ax_left.axes.get_ylim())
zero_rel_pos_left = -ylims_left.bottom / (
ylims_left.top - ylims_left.bottom
)
elif zero_rel_pos_right == 1:
# Right axis is all negative, left axis has positive values
# if other axis is fully positive, place y=0 at the center of this axis.
ax_right.set_ylim(
top=ylims_right.bottom * (1 - 1 / zero_rel_pos_left)
if zero_rel_pos_left
else -ylims_right.bottom
)
# Update lims and zero_rel_pos in case we need to edit the other axis.
ylims_right = YLimits(*ax_right.axes.get_ylim())
zero_rel_pos_right = -ylims_right.bottom / (
ylims_right.top - ylims_right.bottom
)

# If one relative position is 0, edit bottom by solving
# rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom.
if zero_rel_pos_left == 0:
# Left axis is all positive, right axis has negative values
ax_left.set_ylim(
bottom=-zero_rel_pos_right * ylims_left.top / (1 - zero_rel_pos_right)
)
elif zero_rel_pos_right == 0:
# Right axis is all positive, left axis has negative values
ax_right.set_ylim(
bottom=-zero_rel_pos_left * ylims_right.top / (1 - zero_rel_pos_left)
)

def _get_additional_dim_combinations(
self,
data: xr.Dataset,
Expand Down Expand Up @@ -1267,6 +1345,9 @@ def _plot_budget_allocation_bars(
ax.set_xticklabels(channels)
ax.tick_params(axis="x", rotation=90)

# Ensure that y=0 are aligned between ax and ax2.
self._align_y_axes(ax, ax2)

# Turn off grid and add legend
ax.grid(False)
ax2.grid(False)
Expand Down
161 changes: 161 additions & 0 deletions tests/mmm/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -1023,3 +1024,163 @@ def test__dim_list_handler_mixed():
keys, combos = suite._dim_list_handler({"country": ["A", "B"], "region": "X"})
assert keys == ["country"]
assert set(combos) == {("A",), ("B",)}


class TestAlignYAxes:
"""Test the _align_y_axes method with various axis configurations."""

@pytest.fixture
def mock_axes(self):
"""Create mock axes for testing."""
_, ax = plt.subplots()
ax2 = ax.twinx()
return ax, ax2

@pytest.mark.parametrize(
"ax1_limits,ax2_limits,expected_ax1_limits,expected_ax2_limits,description",
[
((0, 10), (0, 20), (0, 10), (0, 20), "both_positive"),
((-10, 0), (-20, 0), (-10, 0), (-20, 0), "both_negative"),
((-5, 10), (0, 20), (-5, 10), (-10, 20), "ax1_mixed_ax2_positive"),
((-10, 0), (-5, 20), (-10, 40), (-5, 20), "ax1_negative_ax2_mixed"),
((-2, 10), (-5, 20), (-2.5, 10), (-5, 20), "both_mixed_ax1_ratio_smaller"),
((-5, 10), (-2, 20), (-5, 10), (-10, 20), "both_mixed_ax2_ratio_smaller"),
((-10, 0), (0, 20), (-10, 10), (-20, 20), "ax1_fully_negative"),
((-2, 0), (0, 10), (-2, 2), (-10, 10), "ax2_fully_negative"),
],
)
def test_align_y_axes_various_scenarios(
self,
mock_axes,
ax1_limits,
ax2_limits,
expected_ax1_limits,
expected_ax2_limits,
description,
):
"""Test _align_y_axes with various axis limit scenarios."""
ax, ax2 = mock_axes
suite = MMMPlotSuite(az.InferenceData())

# Set initial limits
ax.set_ylim(ax1_limits)
ax2.set_ylim(ax2_limits)

# Call the method
suite._align_y_axes(ax, ax2)

# Check results
actual_ax1_limits = ax.get_ylim()
actual_ax2_limits = ax2.get_ylim()

assert actual_ax1_limits == pytest.approx(expected_ax1_limits, rel=1e-10), (
f"Axis 1 limits incorret in: {description}"
)
assert actual_ax2_limits == pytest.approx(expected_ax2_limits, rel=1e-10), (
f"Axis 2 limits incorret in: {description}"
)


class TestPlotBudgetAllocationBars:
"""Test the _plot_budget_allocation_bars method and its integration with _align_y_axes."""

@pytest.fixture
def mock_plot_suite(self):
"""Create a mock MMMPlotSuite for testing."""
return MMMPlotSuite(az.InferenceData())

@pytest.fixture
def sample_data(self):
"""Create sample data for budget allocation plotting."""
channels = np.array(["TV", "Radio", "Social Media", "Search", "Email"])
allocated_spend = np.array([5000, 3000, 2000, 4000, 1000])
channel_contribution = np.array([0.4, 0.2, 0.15, 0.2, 0.05])
return channels, allocated_spend, channel_contribution

def test_plot_budget_allocation_bars_basic(self, mock_plot_suite, sample_data):
"""Test basic functionality of _plot_budget_allocation_bars."""
channels, allocated_spend, channel_contribution = sample_data
_, ax = plt.subplots()

# Should not raise any errors
mock_plot_suite._plot_budget_allocation_bars(
ax, channels, allocated_spend, channel_contribution
)

# Verify the plot was created
assert len(ax.patches) > 0 # Should have bar patches
assert ax.get_xlabel() == "Channels"
assert ax.get_ylabel() == "Allocated Spend"

@pytest.mark.parametrize(
"allocated_spend,channel_contribution,description",
[
(
np.array([5000, 3000, 2000, 4000, 1000]),
np.array([0.4, 0.2, 0.15, 0.2, 0.05]),
"all_positive",
),
(
np.array([5000, 3000, 2000, 4000, -1000]),
np.array([0.4, 0.2, 0.15, 0.2, 0.05]),
"mixed_spend_positive_contribution",
),
(
np.array([-5000, -3000, -2000, -4000, -1000]),
np.array([-0.4, -0.2, -0.15, -0.2, -0.05]),
"all_negative",
),
(
np.array([5000, 3000, 2000, 4000, 1000]),
np.array([-0.4, -0.2, -0.15, -0.2, -0.05]),
"positive_spend_negative_contribution",
),
(
np.array([-5000, -3000, -2000, -4000, -1000]),
np.array([0.4, -0.2, -0.15, -0.2, -0.05]),
"mixed_both",
),
],
)
def test_plot_budget_allocation_bars_axis_alignment(
self, mock_plot_suite, allocated_spend, channel_contribution, description
):
"""Test that _plot_budget_allocation_bars properly aligns y-axes for different data scenarios."""
channels = np.array(["TV", "Radio", "Social Media", "Search", "Email"])
_, ax = plt.subplots()

# Call the method
mock_plot_suite._plot_budget_allocation_bars(
ax, channels, allocated_spend, channel_contribution
)

# Get the twin axis (created for channel contributions)
ax2 = ax.right_ax if hasattr(ax, "right_ax") else None
if ax2 is None:
# Find the twin axis manually
for other_ax in ax.figure.axes:
if (
other_ax != ax
and hasattr(other_ax, "yaxis")
and other_ax.yaxis.get_scale() == ax.yaxis.get_scale()
):
ax2 = other_ax
break

assert ax2 is not None, f"Twin axis not found for {description}"

# Verify that both axes exist and have been processed
assert ax.get_ylim() is not None
assert ax2.get_ylim() is not None

# The method should have called _align_y_axes, so we can verify
# that the axes are in a reasonable state (no NaN or infinite values)
ax1_limits = ax.get_ylim()
ax2_limits = ax2.get_ylim()

assert all(np.isfinite(ax1_limits)), (
f"ax1 limits not finite for {description}: {ax1_limits}"
)
assert all(np.isfinite(ax2_limits)), (
f"ax2 limits not finite for {description}: {ax2_limits}"
)