From 2d2c3113973c2e1600cb36a7cb264400e0aa834a Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Wed, 3 Sep 2025 10:21:17 +0200 Subject: [PATCH 1/3] Enforce MuEffect as pydantic subclass to support serialization/deserialization when saving and loading models --- pymc_marketing/mmm/additive_effect.py | 55 +++++++++++++++------------ 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/pymc_marketing/mmm/additive_effect.py b/pymc_marketing/mmm/additive_effect.py index 6236f29ea..bf4bbcb2c 100644 --- a/pymc_marketing/mmm/additive_effect.py +++ b/pymc_marketing/mmm/additive_effect.py @@ -104,12 +104,13 @@ def set_data(self, mmm, model, X): - In `set_data`, update the data variables when dates/dims change. """ -from typing import Any, Protocol +from abc import ABC, abstractmethod +from typing import Annotated, Any, Protocol import pandas as pd import pymc as pm import xarray as xr -from pydantic import BaseModel, InstanceOf +from pydantic import BaseModel, Field, InstanceOf, PlainValidator, WithJsonSchema from pymc_extras.prior import create_dim_handler from pytensor import tensor as pt @@ -131,35 +132,31 @@ def model(self) -> pm.Model: """The PyMC model.""" -class MuEffect(Protocol): - """Protocol for arbitrary additive mu effect.""" +class MuEffect(ABC, BaseModel): + """Abstract base class for arbitrary additive mu effects. + All mu_effects must inherit from this Pydantic BaseModel to ensure proper + serialization and deserialization when saving/loading MMM models. + """ + + @abstractmethod def create_data(self, mmm: Model) -> None: """Create the required data in the model.""" + @abstractmethod def create_effect(self, mmm: Model) -> pt.TensorVariable: """Create the additive effect in the model.""" + @abstractmethod def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None: """Set the data for new predictions.""" -class FourierEffect: +class FourierEffect(MuEffect): """Fourier seasonality additive effect for MMM.""" - def __init__(self, fourier: FourierBase, date_dim_name: str = "date"): - """Initialize the Fourier effect. - - Parameters - ---------- - fourier : FourierBase - The FourierBase instance to use for the effect. - date_dim_name : str, optional - The name of the date dimension in the model, by default "date". - - """ - self.fourier = fourier - self.date_dim_name: str = date_dim_name + fourier: InstanceOf[FourierBase] + date_dim_name: str = Field("date") def create_data(self, mmm: Model) -> None: """Create the required data in the model. @@ -247,7 +244,14 @@ def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None: pm.set_data(new_data=new_data, model=model) -class LinearTrendEffect: +_Timestamp = Annotated[ + pd.Timestamp, + PlainValidator(lambda x: pd.Timestamp(x)), + WithJsonSchema({"type": "date-time"}), +] + + +class LinearTrendEffect(MuEffect): """Wrapper for LinearTrend to use with MMM's MuEffect protocol. This class adapts the LinearTrend component to be used as an additive effect @@ -259,6 +263,8 @@ class LinearTrendEffect: The LinearTrend instance to wrap. prefix : str The prefix to use for variables in the model. + date_dim_name : str + The name of the date dimension in the model. Examples -------- @@ -348,11 +354,10 @@ class MockMMM: """ - def __init__(self, trend: LinearTrend, prefix: str, date_dim_name: str = "date"): - self.trend = trend - self.prefix = prefix - self.linear_trend_first_date: pd.Timestamp - self.date_dim_name: str = date_dim_name + trend: InstanceOf[LinearTrend] + prefix: str + date_dim_name: str = Field("date") + linear_trend_first_date: _Timestamp | None = Field(None, init=False) def create_data(self, mmm: Model) -> None: """Create the required data in the model. @@ -430,7 +435,7 @@ def set_data(self, mmm: Model, model: pm.Model, X: xr.Dataset) -> None: pm.set_data({f"{self.prefix}_t": t}, model=model) -class EventAdditiveEffect(BaseModel): +class EventAdditiveEffect(MuEffect): """Event effect class for the MMM. Parameters From fe2eba19aa3b6e03c695085b623dbf218c56a4bb Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Wed, 3 Sep 2025 16:24:28 +0200 Subject: [PATCH 2/3] Update test_fourier_multidimensional to comply with the new pydantic class --- tests/mmm/test_additive_effect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mmm/test_additive_effect.py b/tests/mmm/test_additive_effect.py index 483d50b41..c6a2d83f5 100644 --- a/tests/mmm/test_additive_effect.py +++ b/tests/mmm/test_additive_effect.py @@ -168,7 +168,7 @@ def test_fourier_effect_multidimensional( prefix = "weekly" prior = Prior("Laplace", mu=0, b=0.1, dims=prior_dims) fourier = WeeklyFourier(n_order=10, prefix=prefix, prior=prior) - fourier_effect = FourierEffect(fourier) + fourier_effect = FourierEffect(fourier=fourier) with mmm.model: fourier_effect.create_data(mmm) From 520db5848925bb91e9433e8fafe934214aa080a2 Mon Sep 17 00:00:00 2001 From: PabloRoque Date: Wed, 3 Sep 2025 16:38:56 +0200 Subject: [PATCH 3/3] Improve tests to comply with new interface --- tests/mmm/test_additive_effect.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/mmm/test_additive_effect.py b/tests/mmm/test_additive_effect.py index c6a2d83f5..548243537 100644 --- a/tests/mmm/test_additive_effect.py +++ b/tests/mmm/test_additive_effect.py @@ -98,7 +98,7 @@ def test_fourier_effect( dims, coords, ) -> None: - effect = FourierEffect(fourier) + effect = FourierEffect(fourier=fourier) mmm = create_mock_mmm( dims=dims, @@ -215,7 +215,7 @@ def test_linear_trend_effect( ) -> None: prefix = "linear_trend" effect = LinearTrendEffect( - LinearTrend(priors=priors, dims=linear_trend_dims), + trend=LinearTrend(priors=priors, dims=linear_trend_dims), prefix=prefix, )