Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement and switch to lazy initval evaluation framework #4983

Merged
merged 10 commits into from
Oct 14, 2021

Conversation

michaelosthege
Copy link
Member

@michaelosthege michaelosthege commented Sep 4, 2021

Changes

  • The initial values are now evaluated lazily, all at once. Model.initial_values and Distribution(initval=...) can now take None, ndarray, Variable, "moment" or "prior".
  • Model.initial_values are now managed by RV instead of value var tensor
  • init_nuts signature changes
  • pm.sample() start kwarg change of name and meaning
  • Model.update_start_vals was removed (with informative error message)
  • All new initial point logic was moved to it's own module at initial_point.py

ToDo

  • Add additional tests for the new re-size flexibility that the lazy evaluation allows for. (test_initvals.py::TestInitvalEvaluation::test_initval_resizing)
  • Mention changes in the RELEASE-NOTES.md Doing this in the HackMD document

Closes #4924
Closes #4484

@codecov
Copy link

codecov bot commented Sep 4, 2021

Codecov Report

Merging #4983 (4e40756) into main (ab1178b) will increase coverage by 16.76%.
The diff coverage is 96.11%.

❗ Current head 4e40756 differs from pull request most recent head ff3b7f7. Consider uploading reports for the commit ff3b7f7 to get more accurate results
Impacted file tree graph

@@             Coverage Diff             @@
##             main    #4983       +/-   ##
===========================================
+ Coverage   61.45%   78.21%   +16.76%     
===========================================
  Files         130      131        +1     
  Lines       24461    24525       +64     
===========================================
+ Hits        15033    19183     +4150     
+ Misses       9428     5342     -4086     
Impacted Files Coverage Δ
pymc/variational/approximations.py 29.83% <0.00%> (+0.85%) ⬆️
pymc/sampling.py 87.60% <90.00%> (+0.71%) ⬆️
pymc/model.py 83.18% <92.30%> (+0.26%) ⬆️
pymc/tests/test_sampling.py 97.47% <93.33%> (+0.18%) ⬆️
pymc/distributions/continuous.py 95.83% <100.00%> (-0.10%) ⬇️
pymc/distributions/distribution.py 94.47% <100.00%> (+0.38%) ⬆️
pymc/initvalues.py 100.00% <100.00%> (ø)
pymc/tests/test_distributions.py 96.34% <100.00%> (+0.43%) ⬆️
pymc/tests/test_distributions_random.py 86.28% <100.00%> (ø)
pymc/tests/test_initvals.py 100.00% <100.00%> (+8.53%) ⬆️
... and 61 more

@michaelosthege
Copy link
Member Author

Could the remaining test failure pymc3/tests/test_ndarray_backend.py::TestSaveLoad::test_sample_posterior_predictive be related to seeds/rng of the initial values?

pymc3/model.py Outdated Show resolved Hide resolved
@michaelosthege
Copy link
Member Author

I added two more tests - dependent initvals and resizing works.

But something is odd with the random state. This is from the failing test_ndarray_backend.py::TestSaveLoad::test_sample_posterior_predictive:
image

That backend is deprecated and the test doesn't seem to be very targeted. It looks like the MultiTrace constructor accesses Model.initial_point and this random states used for the initval evaluation actually do interfer with the models rng_seeder.
@ricardoV94 can you take a look at the evaluation method?

Also what exactly do we want in terms of rng for initial value evaluation? Should it be independent, optionally with their own rng, or should it re-use the models rng?

@michaelosthege michaelosthege force-pushed the lazy-initval-evaluation branch 2 times, most recently from 94764d5 to 756acec Compare September 5, 2021 21:59
@michaelosthege michaelosthege added help wanted needs info Additional information required pytensor labels Sep 13, 2021
@michaelosthege michaelosthege force-pushed the lazy-initval-evaluation branch 3 times, most recently from 59f6c3c to 5f2ee98 Compare September 21, 2021 17:48
@aseyboldt aseyboldt force-pushed the lazy-initval-evaluation branch 3 times, most recently from 4b25e3d to f5aa61a Compare September 21, 2021 20:58
@aseyboldt
Copy link
Member

aseyboldt commented Sep 21, 2021

The remaining test failures seem to be a result of #5007, but hopefully that would be fixed by #4887, so maybe we wait for that to be merged?

UPDATE: I marked them as XFAIL to unblock this PR.

pymc3/distributions/discrete.py Outdated Show resolved Hide resolved
pymc3/distributions/distribution.py Outdated Show resolved Hide resolved
pymc3/model.py Outdated Show resolved Hide resolved
pymc3/model.py Outdated Show resolved Hide resolved
pymc3/model.py Outdated Show resolved Hide resolved
pymc3/model.py Outdated Show resolved Hide resolved
@ricardoV94
Copy link
Member

ricardoV94 commented Sep 22, 2021

Good time to update the brittle model.update_start_vals

https://github.com/pymc-devs/pymc3/blob/4f8ad5d2772c8fd98b403d7c1d141e8f8394ba84/pymc3/model.py#L1594

There's more context here (with a different pseudo-API): #4924 (comment)

@twiecki
Copy link
Member

twiecki commented Oct 13, 2021

LGTM. What happens when an init point is -inf?

@ricardoV94
Copy link
Member

LGTM. What happens when an init point is -inf?

What do you mean? The model logp at the initial point? If so, same as before

@michaelosthege
Copy link
Member Author

Both pre-commit and docs worked on master. But that should be easy to fix(?)

@ricardoV94
Copy link
Member

Both pre-commit and docs worked on master. But that should be easy to fix(?)

Rebased after #5070. Should be fixed now

@ricardoV94 ricardoV94 force-pushed the lazy-initval-evaluation branch 3 times, most recently from ea49486 to 1f6e8f9 Compare October 14, 2021 09:05
ricardoV94 and others added 10 commits October 14, 2021 13:52
With this commit "moment" or "prior" become legal initvals.
Furthermore rv.tag.test_value is no longer assigned or used for initvals.

The tolerance on test_mle_jacobian was eased to account for non-
deterministic starting points of the optimization.
…odel

This function can also handle variable specific jittering and user defined overrides

The pm.sampling module was adapted to use the new functionality.
This changed the signature of `init_nuts`:
+ `start` kwarg becomes `initvals`
+ `initvals` are required to be complete for all chains
+ `seeds` can now be specified for all chains
The test relied on monkey patching the jitter so that the model initial logp would fail predictably.
This does not seem to be possible with the new numpy random generators, so a different test strategy has to be developed
To unblock this PR/branch from the aeppl integration.
Copy link
Member Author

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't approve since I'm the original author.

That one nitpick comment I made shouldn't hold us back from anything

rvs_to_jitter : set
The random variables for which jitter should be added.
"""
# TODO: implement this
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created an issue for this: #5077

Comment on lines +150 to +161
def find_rng_nodes(variables):
return [
node
for node in graph_inputs(variables)
if isinstance(
node,
(
at.random.var.RandomStateSharedVariable,
at.random.var.RandomGeneratorSharedVariable,
),
)
]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be extracted

@twiecki twiecki merged commit 3b763a1 into pymc-devs:main Oct 14, 2021
@twiecki
Copy link
Member

twiecki commented Oct 14, 2021

🥳

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Initval refactoring Invalid logic in sample prior predictive for transformed variables
5 participants