From c1d1204de975908614b14dbbf14374e24a234e34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= Date: Tue, 30 Sep 2025 12:24:09 +0300 Subject: [PATCH 1/5] Ensure y=0 are aligned between the two axes --- pymc_marketing/mmm/plot.py | 39 +++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index b7a2ddae4..fd808ad73 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -243,6 +243,40 @@ def _build_subplot_title( return ", ".join(title_parts) return fallback_title + def _align_y_axes(self, ax, ax2, include_zero=False): + """Align y=0 of primary and secondary y-axis.""" + if ax.axes.get_ylim()[0] < 0 or ax2.axes.get_ylim()[0] < 0: + ylims1 = ax.axes.get_ylim() + ylims2 = ax2.axes.get_ylim() + # Find the ratio of negative vs. positive part of the axes. + if ylims1[1]: + ax1_yratio = ylims1[0] / ylims1[1] + else: + # Fully negative axis. + ax1_yratio = -1 + + if ylims2[1]: + ax2_yratio = ylims2[0] / ylims2[1] + else: + # Fully negative axis, may need to reflect the other + ax2_yratio = -1 + + # Make axis adjustments. If both axes fully negative, no adjustment. + if ax1_yratio < ax2_yratio: + ax2.set_ylim(bottom = ylims2[1]*ax1_yratio) + if ax1_yratio == -1: + # if the axis is fully negative, center zero. + ax.set_ylim(top=-ylims1[0]) + elif ax2_yratio < ax1_yratio: + ax.set_ylim(bottom = ylims1[1]*ax2_yratio) + if ax2_yratio == -1: + # if the axis is fully negative, center zero. + ax2.set_ylim(top=-ylims2[0]) + elif include_zero: + # Ensure both axes start at zero + ax.set_ylim(bottom=0) + ax2.set_ylim(bottom=0) + def _get_additional_dim_combinations( self, data: xr.Dataset, @@ -1179,7 +1213,7 @@ def _plot_budget_allocation_bars( alpha=opacity, label="Channel Contribution", ) - + # Labels and formatting ax.set_xlabel("Channels") ax.set_ylabel("Allocated Spend", color="C0", labelpad=10) @@ -1190,6 +1224,9 @@ def _plot_budget_allocation_bars( ax.set_xticklabels(channels) ax.tick_params(axis="x", rotation=90) + # Ensure that y=0 are aligned between ax and ax2. + self._align_y_axes(ax, ax2, include_zero=True) + # Turn off grid and add legend ax.grid(False) ax2.grid(False) From 6da79ecbd1e5c4af97cb94ced2b31ee71f505799 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= Date: Tue, 30 Sep 2025 12:45:03 +0300 Subject: [PATCH 2/5] run pre-commit formatting --- pymc_marketing/mmm/plot.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index fd808ad73..e461b765a 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -260,15 +260,15 @@ def _align_y_axes(self, ax, ax2, include_zero=False): else: # Fully negative axis, may need to reflect the other ax2_yratio = -1 - + # Make axis adjustments. If both axes fully negative, no adjustment. if ax1_yratio < ax2_yratio: - ax2.set_ylim(bottom = ylims2[1]*ax1_yratio) + ax2.set_ylim(bottom=ylims2[1] * ax1_yratio) if ax1_yratio == -1: # if the axis is fully negative, center zero. ax.set_ylim(top=-ylims1[0]) elif ax2_yratio < ax1_yratio: - ax.set_ylim(bottom = ylims1[1]*ax2_yratio) + ax.set_ylim(bottom=ylims1[1] * ax2_yratio) if ax2_yratio == -1: # if the axis is fully negative, center zero. ax2.set_ylim(top=-ylims2[0]) @@ -276,7 +276,7 @@ def _align_y_axes(self, ax, ax2, include_zero=False): # Ensure both axes start at zero ax.set_ylim(bottom=0) ax2.set_ylim(bottom=0) - + def _get_additional_dim_combinations( self, data: xr.Dataset, @@ -1213,7 +1213,7 @@ def _plot_budget_allocation_bars( alpha=opacity, label="Channel Contribution", ) - + # Labels and formatting ax.set_xlabel("Channels") ax.set_ylabel("Allocated Spend", color="C0", labelpad=10) From 112851abe8750c596ea322f62d8a82dc5c6f1dce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= Date: Mon, 20 Oct 2025 23:31:14 +0300 Subject: [PATCH 3/5] add tests for double y-axis realignments --- tests/mmm/test_plot.py | 196 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 196 insertions(+) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index 6accbb698..9ad063c0d 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -14,6 +14,7 @@ import warnings import arviz as az +import matplotlib.pyplot as plt import numpy as np import pandas as pd import pytest @@ -716,3 +717,198 @@ def test_plot_sensitivity_analysis_error_on_missing_results(mock_idata): suite = MMMPlotSuite(idata=mock_idata) with pytest.raises(ValueError, match=r"No sensitivity analysis results found"): suite.plot_sensitivity_analysis() + + +class TestAlignYAxes: + """Test the _align_y_axes method with various axis configurations.""" + + @pytest.fixture + def mock_axes(self): + """Create mock axes for testing.""" + _, ax = plt.subplots() + ax2 = ax.twinx() + return ax, ax2 + + @pytest.mark.parametrize( + "ax1_limits,ax2_limits,expected_ax1_limits,expected_ax2_limits,description", + [ + ((0, 10), (0, 20), (0, 10), (0, 20), "both_positive"), + ((-10, 0), (-20, 0), (-10, 0), (-20, 0), "both_negative"), + ((-5, 10), (0, 20), (-5, 10), (-10, 20), "ax1_mixed_ax2_positive"), + ((-10, 0), (-5, 20), (-10, 40), (-5, 20), "ax1_negative_ax2_mixed"), + ((-2, 10), (-5, 20), (-2.5, 10), (-5, 20), "both_mixed_ax1_ratio_smaller"), + ((-5, 10), (-2, 20), (-5, 10), (-10, 20), "both_mixed_ax2_ratio_smaller"), + ((-10, 0), (0, 20), (-10, 10), (-20, 20), "ax1_fully_negative"), + ((-2, 0), (0, 10), (-2, 2), (-10, 10), "ax2_fully_negative"), + ], + ) + def test_align_y_axes_various_scenarios( + self, + mock_axes, + ax1_limits, + ax2_limits, + expected_ax1_limits, + expected_ax2_limits, + description, + ): + """Test _align_y_axes with various axis limit scenarios.""" + ax, ax2 = mock_axes + suite = MMMPlotSuite(az.InferenceData()) + + # Set initial limits + ax.set_ylim(ax1_limits) + ax2.set_ylim(ax2_limits) + + # Call the method + suite._align_y_axes(ax, ax2) + + # Check results + actual_ax1_limits = ax.get_ylim() + actual_ax2_limits = ax2.get_ylim() + + assert actual_ax1_limits == pytest.approx(expected_ax1_limits, rel=1e-10), ( + f"Axis 1 limits incorret in: {description}" + ) + assert actual_ax2_limits == pytest.approx(expected_ax2_limits, rel=1e-10), ( + f"Axis 2 limits incorret in: {description}" + ) + + +class TestPlotBudgetAllocationBars: + """Test the _plot_budget_allocation_bars method and its integration with _align_y_axes.""" + + @pytest.fixture + def mock_plot_suite(self): + """Create a mock MMMPlotSuite for testing.""" + return MMMPlotSuite(az.InferenceData()) + + @pytest.fixture + def sample_data(self): + """Create sample data for budget allocation plotting.""" + channels = np.array(["TV", "Radio", "Social Media", "Search", "Email"]) + allocated_spend = np.array([5000, 3000, 2000, 4000, 1000]) + channel_contribution = np.array([0.4, 0.2, 0.15, 0.2, 0.05]) + return channels, allocated_spend, channel_contribution + + def test_plot_budget_allocation_bars_basic(self, mock_plot_suite, sample_data): + """Test basic functionality of _plot_budget_allocation_bars.""" + channels, allocated_spend, channel_contribution = sample_data + _, ax = plt.subplots() + + # Should not raise any errors + mock_plot_suite._plot_budget_allocation_bars( + ax, channels, allocated_spend, channel_contribution + ) + + # Verify the plot was created + assert len(ax.patches) > 0 # Should have bar patches + assert ax.get_xlabel() == "Channels" + assert ax.get_ylabel() == "Allocated Spend" + + @pytest.mark.parametrize( + "allocated_spend,channel_contribution,description", + [ + ( + np.array([5000, 3000, 2000, 4000, 1000]), + np.array([0.4, 0.2, 0.15, 0.2, 0.05]), + "all_positive", + ), + ( + np.array([5000, 3000, 2000, 4000, -1000]), + np.array([0.4, 0.2, 0.15, 0.2, 0.05]), + "mixed_spend_positive_contribution", + ), + ( + np.array([-5000, -3000, -2000, -4000, -1000]), + np.array([-0.4, -0.2, -0.15, -0.2, -0.05]), + "all_negative", + ), + ( + np.array([5000, 3000, 2000, 4000, 1000]), + np.array([-0.4, -0.2, -0.15, -0.2, -0.05]), + "positive_spend_negative_contribution", + ), + ( + np.array([-5000, -3000, -2000, -4000, -1000]), + np.array([0.4, -0.2, -0.15, -0.2, -0.05]), + "mixed_both", + ), + ], + ) + def test_plot_budget_allocation_bars_axis_alignment( + self, mock_plot_suite, allocated_spend, channel_contribution, description + ): + """Test that _plot_budget_allocation_bars properly aligns y-axes for different data scenarios.""" + channels = np.array(["TV", "Radio", "Social Media", "Search", "Email"]) + _, ax = plt.subplots() + + # Call the method + mock_plot_suite._plot_budget_allocation_bars( + ax, channels, allocated_spend, channel_contribution + ) + + # Get the twin axis (created for channel contributions) + ax2 = ax.right_ax if hasattr(ax, "right_ax") else None + if ax2 is None: + # Find the twin axis manually + for other_ax in ax.figure.axes: + if ( + other_ax != ax + and hasattr(other_ax, "yaxis") + and other_ax.yaxis.get_scale() == ax.yaxis.get_scale() + ): + ax2 = other_ax + break + + assert ax2 is not None, f"Twin axis not found for {description}" + + # Verify that both axes exist and have been processed + assert ax.get_ylim() is not None + assert ax2.get_ylim() is not None + + # The method should have called _align_y_axes, so we can verify + # that the axes are in a reasonable state (no NaN or infinite values) + ax1_limits = ax.get_ylim() + ax2_limits = ax2.get_ylim() + + assert all(np.isfinite(ax1_limits)), ( + f"ax1 limits not finite for {description}: {ax1_limits}" + ) + assert all(np.isfinite(ax2_limits)), ( + f"ax2 limits not finite for {description}: {ax2_limits}" + ) + + def test_plot_budget_allocation_bars_include_zero_behavior(self, mock_plot_suite): + """Test that _plot_budget_allocation_bars uses include_zero=True for axis alignment.""" + channels = np.array(["TV", "Radio", "Social Media"]) + allocated_spend = np.array([5000, 3000, 2000]) + channel_contribution = np.array([0.4, 0.2, 0.15]) + + _, ax = plt.subplots() + + # Call the method + mock_plot_suite._plot_budget_allocation_bars( + ax, channels, allocated_spend, channel_contribution + ) + + # Find the twin axis + ax2 = None + for other_ax in ax.figure.axes: + if other_ax != ax and hasattr(other_ax, "yaxis"): + ax2 = other_ax + break + + assert ax2 is not None + + # Since include_zero=True is used, both axes should start at 0 or below + # (depending on whether there are negative values) + ax1_limits = ax.get_ylim() + ax2_limits = ax2.get_ylim() + + # For positive data, both should start at 0 + assert ax1_limits[0] <= 0, ( + f"ax1 should start at or below 0, got {ax1_limits[0]}" + ) + assert ax2_limits[0] <= 0, ( + f"ax2 should start at or below 0, got {ax2_limits[0]}" + ) From 772756f20ee658ef9def4bb9521330602a4bda7b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= Date: Mon, 20 Oct 2025 23:31:58 +0300 Subject: [PATCH 4/5] Refactor, and fix function. --- pymc_marketing/mmm/plot.py | 106 ++++++++++++++++++++++++++----------- 1 file changed, 75 insertions(+), 31 deletions(-) diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index e461b765a..85382eb0d 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -170,6 +170,7 @@ """ import itertools +from collections import namedtuple from collections.abc import Iterable import arviz as az @@ -243,39 +244,82 @@ def _build_subplot_title( return ", ".join(title_parts) return fallback_title - def _align_y_axes(self, ax, ax2, include_zero=False): + def _align_y_axes(self, ax_left, ax_right) -> None: """Align y=0 of primary and secondary y-axis.""" - if ax.axes.get_ylim()[0] < 0 or ax2.axes.get_ylim()[0] < 0: - ylims1 = ax.axes.get_ylim() - ylims2 = ax2.axes.get_ylim() - # Find the ratio of negative vs. positive part of the axes. - if ylims1[1]: - ax1_yratio = ylims1[0] / ylims1[1] - else: - # Fully negative axis. - ax1_yratio = -1 + # Store limits of both axes in named tuples. + YLimits = namedtuple("YLimits", ["bottom", "top"]) + ylims_left = YLimits(*ax_left.axes.get_ylim()) + ylims_right = YLimits(*ax_right.axes.get_ylim()) + + # Calculate the relative position of zero on both axes. + # 0 means all values are positive (zero is at the bottom) + # 1 means all values are negative (zero is at the top) + zero_rel_pos_left = -ylims_left.bottom / (ylims_left.top - ylims_left.bottom) + zero_rel_pos_right = -ylims_right.bottom / ( + ylims_right.top - ylims_right.bottom + ) - if ylims2[1]: - ax2_yratio = ylims2[0] / ylims2[1] + # If relative positions are equal, no action needed + if zero_rel_pos_left == zero_rel_pos_right: + return + + # If both axes include mixed values, edit one to match the other by solving + # rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom. + if zero_rel_pos_left not in [0, 1] and zero_rel_pos_right not in [0, 1]: + if zero_rel_pos_left < zero_rel_pos_right: + ax_left.set_ylim( + bottom=-zero_rel_pos_right + * ylims_left.top + / (1 - zero_rel_pos_right) + ) else: - # Fully negative axis, may need to reflect the other - ax2_yratio = -1 - - # Make axis adjustments. If both axes fully negative, no adjustment. - if ax1_yratio < ax2_yratio: - ax2.set_ylim(bottom=ylims2[1] * ax1_yratio) - if ax1_yratio == -1: - # if the axis is fully negative, center zero. - ax.set_ylim(top=-ylims1[0]) - elif ax2_yratio < ax1_yratio: - ax.set_ylim(bottom=ylims1[1] * ax2_yratio) - if ax2_yratio == -1: - # if the axis is fully negative, center zero. - ax2.set_ylim(top=-ylims2[0]) - elif include_zero: - # Ensure both axes start at zero - ax.set_ylim(bottom=0) - ax2.set_ylim(bottom=0) + ax_right.set_ylim( + bottom=-zero_rel_pos_left + * ylims_right.top + / (1 - zero_rel_pos_left) + ) + + # If one relative position is 1, edit the top by solving + # rel_pos_other = (0 - bottom) / (new_top - bottom) for new_top. + if zero_rel_pos_left == 1: + # Left axis is all negative, right axis has positive values + # if other axis is fully positive, place y=0 at the center of this axis. + ax_left.set_ylim( + top=ylims_left.bottom * (1 - 1 / zero_rel_pos_right) + if zero_rel_pos_right + else -ylims_left.bottom + ) + # Update lims and zero_rel_pos in case we need to edit the other axis. + ylims_left = YLimits(*ax_left.axes.get_ylim()) + zero_rel_pos_left = -ylims_left.bottom / ( + ylims_left.top - ylims_left.bottom + ) + elif zero_rel_pos_right == 1: + # Right axis is all negative, left axis has positive values + # if other axis is fully positive, place y=0 at the center of this axis. + ax_right.set_ylim( + top=ylims_right.bottom * (1 - 1 / zero_rel_pos_left) + if zero_rel_pos_left + else -ylims_right.bottom + ) + # Update lims and zero_rel_pos in case we need to edit the other axis. + ylims_right = YLimits(*ax_right.axes.get_ylim()) + zero_rel_pos_right = -ylims_right.bottom / ( + ylims_right.top - ylims_right.bottom + ) + + # If one relative position is 0, edit bottom by solving + # rel_pos_other = (0 - new_bottom) / (top - new_bottom) for new_bottom. + if zero_rel_pos_left == 0: + # Left axis is all positive, right axis has negative values + ax_left.set_ylim( + bottom=-zero_rel_pos_right * ylims_left.top / (1 - zero_rel_pos_right) + ) + elif zero_rel_pos_right == 0: + # Right axis is all positive, left axis has negative values + ax_right.set_ylim( + bottom=-zero_rel_pos_left * ylims_right.top / (1 - zero_rel_pos_left) + ) def _get_additional_dim_combinations( self, @@ -1225,7 +1269,7 @@ def _plot_budget_allocation_bars( ax.tick_params(axis="x", rotation=90) # Ensure that y=0 are aligned between ax and ax2. - self._align_y_axes(ax, ax2, include_zero=True) + self._align_y_axes(ax, ax2) # Turn off grid and add legend ax.grid(False) From 10dd9243a38ac4943a83878f5a005746a6fe1c45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Teemu=20S=C3=A4ilynoja?= Date: Tue, 21 Oct 2025 00:19:01 +0300 Subject: [PATCH 5/5] Remove outdated test --- tests/mmm/test_plot.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index 951b34efc..17f1adf97 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -1184,38 +1184,3 @@ def test_plot_budget_allocation_bars_axis_alignment( assert all(np.isfinite(ax2_limits)), ( f"ax2 limits not finite for {description}: {ax2_limits}" ) - - def test_plot_budget_allocation_bars_include_zero_behavior(self, mock_plot_suite): - """Test that _plot_budget_allocation_bars uses include_zero=True for axis alignment.""" - channels = np.array(["TV", "Radio", "Social Media"]) - allocated_spend = np.array([5000, 3000, 2000]) - channel_contribution = np.array([0.4, 0.2, 0.15]) - - _, ax = plt.subplots() - - # Call the method - mock_plot_suite._plot_budget_allocation_bars( - ax, channels, allocated_spend, channel_contribution - ) - - # Find the twin axis - ax2 = None - for other_ax in ax.figure.axes: - if other_ax != ax and hasattr(other_ax, "yaxis"): - ax2 = other_ax - break - - assert ax2 is not None - - # Since include_zero=True is used, both axes should start at 0 or below - # (depending on whether there are negative values) - ax1_limits = ax.get_ylim() - ax2_limits = ax2.get_ylim() - - # For positive data, both should start at 0 - assert ax1_limits[0] <= 0, ( - f"ax1 should start at or below 0, got {ax1_limits[0]}" - ) - assert ax2_limits[0] <= 0, ( - f"ax2 should start at or below 0, got {ax2_limits[0]}" - )