Improve tuning by skipping the first samples + add new experimental tuning method #5004

Merged
merged 12 commits into from
Sep 22, 2021

aseyboldt
Member

In the current implementation, mass matrix tuning starts right away. In most models, however, the first few samples are spent moving toward the typical set, so they carry no information about the posterior variance. In the worst case we learn a mass matrix that doesn't match the posterior at all, so that sampling in the first adaptation window will be very slow (you can see this as a slowdown of sampling after step 100). Usually we recover from this, but it seems better to just skip those samples during adaptation in the first place.
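To make the effect concrete, here is a minimal toy sketch (not the code in this PR). It assumes a 1-D AR(1) chain as a stand-in for a sampler that starts far from the typical set; `warmup_draws`, `n_skip`, and all numeric values are hypothetical and chosen only for illustration.

```python
import random
import statistics


def warmup_draws(n_draws, start=50.0, target_sd=1.0, rho=0.9, seed=0):
    """Toy 1-D chain that decays from a distant start point towards 0.

    This is an AR(1) process, not NUTS; it only mimics the initial
    approach to the typical set followed by stationary fluctuation.
    """
    rng = random.Random(seed)
    # Choose the noise scale so that the stationary sd equals target_sd.
    noise_sd = target_sd * (1 - rho**2) ** 0.5
    x = start
    draws = []
    for _ in range(n_draws):
        x = rho * x + rng.gauss(0.0, noise_sd)
        draws.append(x)
    return draws


draws = warmup_draws(1000)
n_skip = 100  # hypothetical number of initial draws to ignore

# Variance estimate polluted by the approach phase.
var_all = statistics.variance(draws)
# Variance estimate after discarding the first draws.
var_skipped = statistics.variance(draws[n_skip:])

print(f"all draws: {var_all:.2f}, skipping first {n_skip}: {var_skipped:.2f}")
```

In this toy setup the estimate that includes the approach phase is far larger than the stationary variance, which is the kind of mismatch described above.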

In an example model by @ricardoV94 we can see this behavior clearly when we look at the distance of the currently used mass matrix to the final mass matrix:

image

This PR also contains an experimental tuning implementation that uses both gradients and samples; it can be enabled with init="jitter+adapt_diag_grad". In tests on a few models it seems to be more stable than the purely sample-based tuning we use right now, but there are also a few cases where it performs worse. For normal posteriors it should converge to the same mass matrix as our current implementation (and much faster), but for non-normal posteriors the result can differ. Unfortunately, I don't know of any way to tell which is better other than trying it on a large number of models.
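The claim about normal posteriors can be illustrated with a small sketch. It assumes a 1-D normal target and one possible way of combining draws and gradients, namely sqrt(var(draws) / var(grads)); this is an illustration of why gradients carry variance information, not necessarily the estimator implemented in this PR. The values of `sigma` and `mu` are made up.

```python
import random
import statistics

sigma = 3.0  # true posterior standard deviation (toy value)
mu = 1.5     # true posterior mean (toy value)
rng = random.Random(1)

# A handful of draws from the target.
draws = [rng.gauss(mu, sigma) for _ in range(10)]
# Gradient of the normal log-density at each draw: d/dx logp = -(x - mu) / sigma^2.
grads = [-(x - mu) / sigma**2 for x in draws]

# Sample-only estimate of the variance: noisy with so few draws.
var_from_draws = statistics.variance(draws)
# Combined estimate: for a normal target var(grads) = var(draws) / sigma^4,
# so the square root of the ratio recovers sigma^2 regardless of sample noise.
var_from_both = (statistics.variance(draws) / statistics.variance(grads)) ** 0.5

print(f"true: {sigma**2}, draws only: {var_from_draws:.2f}, draws+grads: {var_from_both:.2f}")
```

For a normal target the combined estimate matches the true variance even with very few draws, which is consistent with the faster convergence mentioned above; for non-normal targets the two estimates need not agree.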

An example notebook can be found here:
https://gist.github.com/aseyboldt/7897fbddacacaa0c86efc917afe9ce3f

@@ -342,6 +360,8 @@ def __init__(

    def add_sample(self, x, weight):
        x = np.asarray(x)
        if weight != 1:
            raise ValueError("weight is unused and broken")
Member

Suggested change
-            raise ValueError("weight is unused and broken")
+            raise ValueError("Setting weight != 1 is not supported.")

Or maybe we should just remove it altogether.

Member Author

done

@@ -87,7 +87,7 @@ def test_sample(self):

    def test_sample_init(self):
        with self.model:
-            for init in ("advi", "advi_map", "map"):
+            for init in ("advi", "advi_map", "map", "jitter+adapt_diag_grad"):
Member

Should we add all the others here too?

Member Author

done

twiecki changed the title from "Improve tuning by skipping the first samples" to "Improve tuning by skipping the first samples + add new experimental tuning method" on Sep 20, 2021
Member

twiecki commented Sep 20, 2021

Needs a line in the release-notes.

Contributor

rlouf commented Sep 20, 2021

I noticed something similar when debugging blackjax's warmup. This is great! It should also be useful for aehmc.

Comment on lines 361 to 364
    def add_sample(self, x, weight):
        x = np.asarray(x)
        if weight != 1:
            raise ValueError("Setting weight != 1 is not supported.")
Member

Suggested change
-    def add_sample(self, x, weight):
-        x = np.asarray(x)
-        if weight != 1:
-            raise ValueError("Setting weight != 1 is not supported.")
+    def add_sample(self, x, weight=None):
+        if weight is not None:
+            warnings.warn(
+                "Setting weight is no longer supported and will raise an error in the future.",
+                DeprecationWarning,
+            )
+        x = np.asarray(x)

Member Author

I think a hard break is fine here. This really was internal, unused, and wrong.

Member

Then I would suggest removing the weight argument altogether

Member

ricardoV94 left a comment

Left a couple of comments

ricardoV94 merged commit 4f8ad5d into pymc-devs:main on Sep 22, 2021
ricardoV94
Member

Thanks @aseyboldt
