Skip to content

Commit

Permalink
remove warnings in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jul 10, 2024
1 parent 40aeee9 commit 2be2664
Showing 1 changed file with 50 additions and 54 deletions.
104 changes: 50 additions & 54 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
import xarray as xr
from matplotlib import pyplot as plt

from pymc_marketing.mmm.components.adstock import DelayedAdstock
from pymc_marketing.mmm.components.saturation import MichaelisMentenSaturation
from pymc_marketing.mmm.components.adstock import DelayedAdstock, GeometricAdstock
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
MichaelisMentenSaturation,
)
from pymc_marketing.mmm.delayed_saturated_mmm import MMM, BaseMMM, DelayedSaturatedMMM
from pymc_marketing.prior import Prior

Expand Down Expand Up @@ -117,10 +120,9 @@ def mmm() -> MMM:
return MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)


Expand All @@ -129,10 +131,9 @@ def mmm_with_fourier_features() -> MMM:
return MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
yearly_seasonality=2,
)

Expand Down Expand Up @@ -191,21 +192,23 @@ def deep_equal(dict1, dict2):
return False
return True

l_max = 4
adstock = GeometricAdstock(l_max=l_max)
saturation = LogisticSaturation()
model = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
model_config=model_config_requiring_serialization,
adstock="geometric",
saturation="logistic",
adstock=adstock,
saturation=saturation,
)
model = mock_fit(model, toy_X, toy_y)
model.save("test_save_load")
model2 = MMM.load("test_save_load")
assert model.date_column == model2.date_column
assert model.control_columns == model2.control_columns
assert model.channel_columns == model2.channel_columns
assert model.adstock_max_lag == model2.adstock_max_lag
assert model.adstock.l_max == model2.adstock.l_max
assert model.validate_data == model2.validate_data
assert model.yearly_seasonality == model2.yearly_seasonality
assert deep_equal(model.model_config, model2.model_config)
Expand Down Expand Up @@ -264,12 +267,11 @@ def test_init(
date_column="date",
channel_columns=channel_columns,
control_columns=control_columns,
adstock_max_lag=adstock_max_lag,
yearly_seasonality=yearly_seasonality,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=adstock_max_lag),
saturation=LogisticSaturation(),
)
mmm.build_model(X=toy_X, y=toy_y)
n_channel: int = len(mmm.channel_columns)
Expand Down Expand Up @@ -342,10 +344,9 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
date_column="date",
channel_columns=["channel_1", "channel_2"],
control_columns=["control_1", "control_2"],
adstock_max_lag=2,
yearly_seasonality=2,
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=2),
saturation=LogisticSaturation(),
)
assert mmm.version == "0.0.3"
assert mmm._model_type == "BaseValidateMMM"
Expand Down Expand Up @@ -457,10 +458,9 @@ def test_get_errors_raises_not_fitted(self) -> None:
my_mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)
with pytest.raises(
RuntimeError,
Expand All @@ -472,10 +472,9 @@ def test_posterior_predictive_raises_not_fitted(self) -> None:
my_mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)
with pytest.raises(
RuntimeError,
Expand Down Expand Up @@ -575,9 +574,8 @@ def test_data_setter(self, toy_X, toy_y):
base_delayed_saturated_mmm = BaseMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)
base_delayed_saturated_mmm = mock_fit(base_delayed_saturated_mmm, toy_X, toy_y)

Expand Down Expand Up @@ -617,7 +615,7 @@ def test_save_load(self, mmm_fitted: MMM):
assert model.date_column == model2.date_column
assert model.control_columns == model2.control_columns
assert model.channel_columns == model2.channel_columns
assert model.adstock_max_lag == model2.adstock_max_lag
assert model.adstock.l_max == model2.adstock.l_max
assert model.validate_data == model2.validate_data
assert model.yearly_seasonality == model2.yearly_seasonality
assert model.model_config == model2.model_config
Expand All @@ -633,9 +631,8 @@ def mock_property(self):
DSMMM = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)

# Check that the property returns the new value
Expand Down Expand Up @@ -686,11 +683,10 @@ def test_model_config(
model = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=2,
yearly_seasonality=2,
model_config=model_config,
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=2),
saturation=LogisticSaturation(),
)

model.build_model(X=toy_X, y=toy_y.to_numpy())
Expand Down Expand Up @@ -831,7 +827,7 @@ def new_contributions_property_checks(new_contributions, X, model):
assert coords["channel"].values.tolist() == model.channel_columns
np.testing.assert_allclose(
coords["time_since_spend"].values,
np.arange(-model.adstock_max_lag, model.adstock_max_lag + 1),
np.arange(-model.adstock.l_max, model.adstock.l_max + 1),
)

# Channel contributions are non-negative
Expand All @@ -849,10 +845,9 @@ def test_new_spend_contributions_prior_error() -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
adstock=GeometricAdstock(l_max=4),
saturation=LogisticSaturation(),
)
new_spend = np.ones(len(mmm.channel_columns))
match = "sample_prior_predictive"
Expand Down Expand Up @@ -957,13 +952,14 @@ def test_add_lift_test_measurements(mmm, toy_X, toy_y, df_lift_test) -> None:


def test_add_lift_test_measurements_no_model() -> None:
adstock = GeometricAdstock(l_max=4)
saturation = LogisticSaturation()
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
adstock=adstock,
saturation=saturation,
)
with pytest.raises(RuntimeError, match="The model has not been built yet."):
mmm.add_lift_test_measurements(
Expand All @@ -984,14 +980,15 @@ def test_delayed_saturated_mmm_raises_deprecation_warning() -> None:


def test_initialize_alternative_with_strings() -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="delayed",
saturation="michaelis_menten",
)
with pytest.warns(DeprecationWarning):
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="delayed",
saturation="michaelis_menten",
)

assert isinstance(mmm.adstock, DelayedAdstock)
assert mmm.adstock.l_max == 4
Expand All @@ -1002,7 +999,6 @@ def test_initialize_alternative_with_classes() -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock=DelayedAdstock(l_max=10),
saturation=MichaelisMentenSaturation(),
Expand All @@ -1017,7 +1013,6 @@ def test_initialize_defaults_channel_media_dims() -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock=DelayedAdstock(l_max=10),
saturation=MichaelisMentenSaturation(),
Expand All @@ -1039,12 +1034,13 @@ def test_initialize_defaults_channel_media_dims() -> None:
def test_save_load_with_tvp(
time_varying_intercept, time_varying_media, toy_X, toy_y
) -> None:
adstock = GeometricAdstock(l_max=5)
saturation = LogisticSaturation()
mmm = MMM(
channel_columns=["channel_1", "channel_2"],
date_column="date",
adstock="geometric",
saturation="logistic",
adstock_max_lag=5,
adstock=adstock,
saturation=saturation,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
)
Expand Down

0 comments on commit 2be2664

Please sign in to comment.