Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Break ModelBuilder into smaller classes #1467

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 8 additions & 61 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import json
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import cast

import arviz as az
Expand All @@ -29,7 +28,6 @@

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_config import ModelConfig, parse_model_config
from pymc_marketing.utils import from_netcdf


class CLVModel(ModelBuilder):
Expand Down Expand Up @@ -165,55 +163,13 @@
return pm.sample(step=pm.DEMetropolisZ(), **sampler_config)

@classmethod
def load(cls, fname: str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1356 should be considered for the load method replacing this.

"""Create a ModelBuilder instance from a file.

Loads inference data for the model.

Parameters
----------
fname : string
This denotes the name with path from where idata should be loaded from.

Returns
-------
Returns an instance of ModelBuilder.

Raises
------
ValueError
If the inference data that is loaded doesn't match with the model.

Examples
--------
>>> class MyModel(ModelBuilder):
>>> ...
>>> name = './mymodel.nc'
>>> imported_model = MyModel.load(name)

"""
filepath = Path(str(fname))
idata = from_netcdf(filepath)
return cls._build_with_idata(idata)

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
dataset = idata.fit_data.to_dataframe()
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
)
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata
model.build_model() # type: ignore
if model.id != idata.attrs["id"]:
raise ValueError(f"Inference data not compatible with {cls._model_type}")
return model
def idata_to_init_kwargs(cls, idata: az.InferenceData) -> dict:
"""Create the initialization kwargs from an InferenceData object."""
return {

Check warning on line 168 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L168

Added line #L168 was not covered by tests
"data": idata.fit_data.to_dataframe(),
"model_config": json.loads(idata.attrs["model_config"]),
"sampler_config": json.loads(idata.attrs["sampler_config"]),
}

def thin_fit_result(self, keep_every: int):
"""Return a copy of the model with a thinned fit result.
Expand Down Expand Up @@ -244,7 +200,7 @@
self.fit_result # noqa: B018 (Raise Error if fit didn't happen yet)
assert self.idata is not None # noqa: S101
new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
return type(self)._build_with_idata(new_idata)
return self.build_from_idata(new_idata)

Check warning on line 203 in pymc_marketing/clv/models/basic.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/clv/models/basic.py#L203

Added line #L203 was not covered by tests

@property
def default_sampler_config(self) -> dict:
Expand All @@ -267,12 +223,3 @@
return res["mean"].rename("value")
else:
return az.summary(self.fit_result, **kwargs)

@property
def output_var(self):
"""Output variable of the model."""
pass

def _data_setter(self):
"""Set the data for the model."""
pass
4 changes: 2 additions & 2 deletions pymc_marketing/customer_choice/mv_its.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
from typing_extensions import Self
from xarray import DataArray

from pymc_marketing.model_builder import ModelBuilder, create_idata_accessor
from pymc_marketing.model_builder import RegressionModelBuilder, create_idata_accessor
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior

HDI_ALPHA = 0.5


class MVITS(ModelBuilder):
class MVITS(RegressionModelBuilder):
"""Multivariate Interrupted Time Series class.

Class to perform a multivariate interrupted time series analysis with the
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@
ValidateDateColumn,
ValidateTargetColumn,
)
from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_builder import RegressionModelBuilder

__all__ = ["BaseValidateMMM", "MMMModelBuilder"]

from pydantic import Field, validate_call


class MMMModelBuilder(ModelBuilder):
class MMMModelBuilder(RegressionModelBuilder):
"""Base class for Marketing Mix Models (MMM)."""

model: pm.Model
Expand Down
Loading
Loading