
Callback for divergences #838

Open
wd60622 opened this issue Jul 17, 2024 Discussed in #837 · 8 comments
@wd60622
Contributor

wd60622 commented Jul 17, 2024

Discussed in #837

Originally posted by AlfredoJF July 17, 2024
Hi,

Is it possible to implement callbacks in pymc-marketing? If so, how can I create a callback to monitor the number of divergences?

I'd like to use this for early stopping when the number of divergences exceeds 10, but the PyMC resources I found seemed outdated.

Thanks!

wd60622 added the MMM label Jul 17, 2024
@wd60622
Contributor Author

wd60622 commented Jul 17, 2024

Is this a feature of PyMC? Can you link the resources (even if they are outdated)?

EDIT: Is it this callback parameter? https://www.pymc.io/projects/docs/en/latest/api/generated/pymc.sample.html
Can this not be passed via kwargs to the pymc-marketing fit method? https://www.pymc-marketing.io/en/stable/api/generated/pymc_marketing.mmm.delayed_saturated_mmm.MMM.fit.html#pymc_marketing.mmm.delayed_saturated_mmm.MMM.fit

CC @AlfredoJF

@AlfredoJF

Hi @wd60622

This is one of the examples I found in the PyMC docs https://www.pymc.io/projects/examples/en/2022.12.0/howto/sampling_callback.html

def my_callback(trace, draw):
    if len(trace) >= 100:
        raise KeyboardInterrupt()


with model:
    trace = pm.sample(tune=0, draws=500, callback=my_callback, chains=1)

print(len(trace))

And also this old response from junpenglao https://discourse.pymc.io/t/using-callbacks-in-3-11-2-to-test-for-divergences/7807

class DivergentEarlyStopping:
    def __init__(self):
        self.count = 0

    def __call__(self, trace, draw):
        if draw.tuning:
            return

        self.count += int(draw.stats[0]['diverging'])
        if self.count > 10:
            raise KeyboardInterrupt()

callback = DivergentEarlyStopping()

with pm.Model():
    y = pm.Normal('y', mu=0., sigma=3.)
    x = pm.Normal('x', mu=0., sigma=pm.math.exp(y/2), shape=9)
    trace = pm.sample(1000, callback=callback)

I tried both as part of the MMM class and in the fit method, as params and **kwargs.

If I add my_callback to the sampler_config dictionary, I get the error below:

sampler_config = {
    "progressbar": True,
    "callback": my_callback,
}

Error:

Running chain 0: 100% 1000/1000 [01:56<00:00, 19.79it/s]
Running chain 1: 100% 1000/1000 [01:56<00:00, 15.09it/s]
ERROR:pymc.stats.convergence:There were 22 divergences after tuning. Increase `target_accept` or reparameterize.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
[<ipython-input-37-7b73cd73e8a8>](https://localhost:8080/#) in <cell line: 2>()
      1 # Fit the model on the dataset
----> 2 model.fit(
      3     X=X_train,
      4     y=y_train,
      5     target_accept=0.95,  # comment out or use default 0.95

5 frames
[/usr/local/lib/python3.10/dist-packages/pymc_marketing/model_builder.py](https://localhost:8080/#) in fit(self, X, y, progressbar, predictor_names, random_seed, **kwargs)
    492             )
    493             self.idata.add_groups(fit_data=combined_data.to_xarray())  # type: ignore
--> 494         self.set_idata_attrs(self.idata)
    495         return self.idata  # type: ignore
    496 

[/usr/local/lib/python3.10/dist-packages/pymc_marketing/model_builder.py](https://localhost:8080/#) in set_idata_attrs(self, idata)
    292         idata.attrs["model_type"] = self._model_type
    293         idata.attrs["version"] = self.version
--> 294         idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
    295         idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
    296         # Only classes with non-dataset parameters will implement save_input_params

[/usr/lib/python3.10/json/__init__.py](https://localhost:8080/#) in dumps(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, default, sort_keys, **kw)
    229         cls is None and indent is None and separators is None and
    230         default is None and not sort_keys and not kw):
--> 231         return _default_encoder.encode(obj)
    232     if cls is None:
    233         cls = JSONEncoder

[/usr/lib/python3.10/json/encoder.py](https://localhost:8080/#) in encode(self, o)
    197         # exceptions aren't as detailed.  The list call should be roughly
    198         # equivalent to the PySequence_Fast that ''.join() would do.
--> 199         chunks = self.iterencode(o, _one_shot=True)
    200         if not isinstance(chunks, (list, tuple)):
    201             chunks = list(chunks)

[/usr/lib/python3.10/json/encoder.py](https://localhost:8080/#) in iterencode(self, o, _one_shot)
    255                 self.key_separator, self.item_separator, self.sort_keys,
    256                 self.skipkeys, _one_shot)
--> 257         return _iterencode(o, 0)
    258 
    259 def _make_iterencode(markers, _default, _encoder, _indent, _floatstr,

[/usr/lib/python3.10/json/encoder.py](https://localhost:8080/#) in default(self, o)
    177 
    178         """
--> 179         raise TypeError(f'Object of type {o.__class__.__name__} '
    180                         f'is not JSON serializable')
    181 

TypeError: Object of type function is not JSON serializable

@wd60622
Contributor Author

wd60622 commented Jul 18, 2024

That seems like a model builder bug then. Good catch!

We will likely have to write a custom encoder or have your class work with this.
https://docs.python.org/3/library/json.html#json.JSONEncoder
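A custom encoder along those lines might look like this minimal sketch (the class name and the placeholder format are assumptions for illustration, not pymc-marketing API):

```python
import json

class SamplerConfigEncoder(json.JSONEncoder):
    """Fall back to a readable placeholder for values json.dumps rejects."""

    def default(self, o):
        if callable(o):
            # e.g. a sampling callback: store its name instead of raising
            return f"<callable {getattr(o, '__name__', type(o).__name__)}>"
        return super().default(o)

config = {"progressbar": True, "callback": lambda trace, draw: None}
encoded = json.dumps(config, cls=SamplerConfigEncoder)
```

Note that the round trip is lossy: the callback comes back as a string, so it would have to be re-attached after loading the saved model.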

You can currently work around this by passing kwargs at fit time rather than storing the callback in sampler_config, i.e.

mmm = MMM(
    ..., 
    # Only what is JSON-serializable
    sampler_config={...}, 
)
mmm.fit(X, y, callback=my_callback)

Or by overriding the set_idata_attrs method:

class NewMMM(MMM):
    def set_idata_attrs(self, idata):
        # Same as the base method, but skip or sanitize sampler_config
        ...

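For the override route, the sanitizing step could be sketched like this (serializable_sampler_config is a hypothetical helper, not part of pymc-marketing):

```python
import json

def serializable_sampler_config(config):
    """Keep only the entries that json.dumps can handle."""
    out = {}
    for key, value in config.items():
        try:
            json.dumps(value)
        except TypeError:
            continue  # drop callbacks and other non-serializable values
        out[key] = value
    return out

config = {"progressbar": True, "callback": lambda trace, draw: None}
clean = serializable_sampler_config(config)
```

Inside the overridden set_idata_attrs, idata.attrs["sampler_config"] would then be set from the filtered dict instead of self.sampler_config directly.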
The first solution will likely be easier. Give that one a try and let me know how it goes.

@wd60622
Contributor Author

wd60622 commented Jul 18, 2024

Any thoughts on this, @juanitorduz?
Should we catch this and raise a more informative error? I think handling all types of keys would be impossible.

@AlfredoJF

@wd60622 Using the first option only works with the pymc sampler. I was using numpyro, but it seems it doesn't work properly:

mmm = MMM(
    ..., 
    # Only what is JSON-serializable
    sampler_config={...}, 
)
mmm.fit(X, y, callback=my_callback)

@wd60622
Contributor Author

wd60622 commented Jul 19, 2024

> @wd60622 Using the first option only works with the pymc sampler. I was using numpyro, but it seems it doesn't work properly:
>
> mmm = MMM(
>     ...,
>     # Only what is JSON-serializable
>     sampler_config={...},
> )
> mmm.fit(X, y, callback=my_callback)

Can you share the code? If the numpyro sampler doesn't support a callback function, that'd be pymc functionality.

@AlfredoJF

AlfredoJF commented Jul 24, 2024

Hi @wd60622

Here is the code that worked for me using the pymc sampler.

class DivergentEarlyStopping:
    # only for the pymc sampler (for now)
    def __init__(self, max_divergences):
        self.count = 0
        self.max_divergences = max_divergences

    def __call__(self, trace, draw):
        if draw.tuning:
            return

        self.count += int(draw.stats[0]['diverging'])
        if self.count > self.max_divergences:
            raise RuntimeError(f"My RuntimeError: Early stopping activated! {self.count} > {self.max_divergences}")

callback = DivergentEarlyStopping(10)

mmm = MMM(
    ..., 
)

mmm.fit(X, y, callback=callback, nuts_sampler="pymc")

@AlfredoJF

It seems the callback for nuts_sampler="numpyro" is not supported at the moment. See here.
