Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement more robust jitter init (resolves #4107) #4298

Merged
merged 4 commits into from
Dec 5, 2020

Conversation

ricardoV94
Copy link
Member

This PR addresses issue #4107, by allowing the starting jitter to be resampled when the sampled values generate an invalid probability for the model. There is a new optional argument jitter_max_retries in sample() and init_nuts() that controls the maximum number of times that a value can be resampled (per chain) before it gives up and returns whatever was last sampled. I arbitrarily set it to 10, but we can choose another default.

I further refactored the code that applies jitter to the starting point of each chain into a helper function _init_jitter(), to avoid duplicated code between the two init methods where this is used init="jitter+adapt_diag" and init="jitter+adapt_diag". I added a unit_test for this function.

Here is an example that (almost deterministically) shows an improvement following this PR:

import pymc3 as pm

with pm.Model() as m:
    x = pm.HalfNormal('x', transform=None)

try:
    with m:
        trace = pm.sample(tune=1, draws=1, chains=100, jitter_max_retries=0,
                          compute_convergence_checks=False, progressbar=False)
except pm.exceptions.SamplingError:
    print('Exception raised as expected')


with m:
    trace = pm.sample(tune=1, draws=1, chains=100, jitter_max_retries=10,
                      compute_convergence_checks=False, progressbar=False)
print('Exception not raised as expected')

If you happen to know other examples of models that have fragile starting points when jitter is applied it would be great to test it out.

Any thoughts?

@codecov
Copy link

codecov bot commented Dec 5, 2020

Codecov Report

Merging #4298 (7c95082) into master (3fa3d1f) will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@           Coverage Diff           @@
##           master    #4298   +/-   ##
=======================================
  Coverage   87.69%   87.69%           
=======================================
  Files          88       88           
  Lines       14355    14360    +5     
=======================================
+ Hits        12588    12593    +5     
  Misses       1767     1767           
Impacted Files Coverage Δ
pymc3/sampling.py 87.65% <100.00%> (+0.07%) ⬆️

pymc3/sampling.py Outdated Show resolved Hide resolved
pymc3/sampling.py Outdated Show resolved Hide resolved
pymc3/sampling.py Outdated Show resolved Hide resolved
@twiecki twiecki merged commit 580a32a into pymc-devs:master Dec 5, 2020
@ricardoV94 ricardoV94 deleted the robust_jitter_init branch December 6, 2020 06:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants