Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615038294
Change-Id: I7c0658b0a13506c319fd3e6e00cdf2791d64e26f
  • Loading branch information
fehiepsi authored and copybara-github committed Mar 12, 2024
1 parent 03bd82d commit 4406aaa
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions lightweight_mmm/lightweight_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
"carryover": models.transform_carryover
})
_MODEL_FUNCTION = models.media_mix_model
_DETERMINISTIC_VARIABLES = ("media_transformed", "mu")


def _compare_equality_for_lmmm(item_1: Any, item_2: Any) -> bool:
Expand Down Expand Up @@ -440,6 +441,10 @@ def _predict(
Returns:
The predictions for the given data.
"""
# Remove deterministic variables like "mu" from the posterior.
posterior_samples = posterior_samples.copy()
for name in _DETERMINISTIC_VARIABLES:
posterior_samples.pop(name, None)
return infer.Predictive(
model=model, posterior_samples=posterior_samples)(
rng_key=rng_key,
Expand Down

0 comments on commit 4406aaa

Please sign in to comment.