Fix bug with multiple minibatch variables #7408
Conversation
Force-pushed from b47c935 to f1e3d9c:

Fixes bug in VI with multiple Minibatch variables, which occurred due to separate calls to `model.logp` (from `model.datalogp` and `model.varlogp`) that create distinct clones of the `RandomIntegersRV` underlying minibatch slicing. `compile_pymc` would not set any updates in this case.
Force-pushed from f1e3d9c to 10f3aef.
Codecov Report: All modified and coverable lines are covered by tests ✅

@@ Coverage Diff @@
##             main    #7408    +/-
==========================================
- Coverage   92.19%   92.18%   -0.02%
==========================================
  Files         103      103
  Lines       17212    17261      +49
==========================================
+ Hits        15869    15912      +43
- Misses       1343     1349       +6
    total_size=len(y),
)
mean_field = pm.fit(10_000, obj_optimizer=pm.adam(learning_rate=0.01), progressbar=False)
np.testing.assert_allclose(mean_field.mean.get_value(), true_weights, rtol=1e-1)
Does this test need to run the whole model and check parameter recovery? It should be enough to compile the function and check that `minibatch_feature` and `minibatch_y` change after each loss function execution, right?
This should run pretty fast with a minibatch of 1. I think it's a useful integration test; we didn't have any linear regression minibatch test in the codebase. Besides, VI has very complex logic leading up to building the function, which I'd rather treat as a black box.
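For reference, the lighter-weight check proposed above could be sketched as follows. This is a hedged plain-Python stand-in, not the actual PyMC/pytensor API: the `CompiledLoss` class and all of its names are invented for illustration, and a real test would instead compile the loss with `compile_pymc` and inspect the minibatch variables between calls.

```python
import random

class CompiledLoss:
    """Hypothetical stand-in for a compiled loss function whose RNG
    updates were registered correctly.

    In the real PR, `compile_pymc` must attach the update for the
    integer RV that slices the minibatch; here the "update" is simply
    advancing a Python RNG on every call.
    """

    def __init__(self, n_rows, batch_size, seed=0):
        self._rng = random.Random(seed)
        self.n_rows = n_rows
        self.batch_size = batch_size

    def __call__(self):
        # Each call re-draws the slice indices (the registered update).
        return tuple(
            self._rng.randrange(self.n_rows) for _ in range(self.batch_size)
        )

loss = CompiledLoss(n_rows=1000, batch_size=10)
first, second = loss(), loss()
assert first != second  # minibatch indices advanced between executions
```

The invariant being tested is exactly the one the bug violated: without registered updates, consecutive loss evaluations would keep slicing the same rows.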
Description
Reported in https://discourse.pymc.io/t/verifying-that-minibatch-is-actually-randomly-sampling/14308
The bug occurred due to separate calls to `model.logp` (from `model.datalogp` and `model.varlogp`), which create distinct clones of the `RandomIntegersRV` underlying minibatch slicing. `compile_pymc` would not set any updates in this case.
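To make the failure mode concrete, here is a hedged, PyMC-free analogy in plain Python (illustrative only; all names are invented): when each logp term draws its own independent slice indices, the minibatch rows of the features and targets come from different positions, whereas a single shared draw keeps them paired, which is what sharing one clone of the slicing random variable achieves.

```python
import random

random.seed(0)
X = list(range(10))       # toy feature column
y = [2 * v for v in X]    # targets aligned row-by-row with X

def cloned_slices():
    # Buggy analogue: two distinct clones of the slicing RV, so the
    # two logp terms slice with *different* random indices.
    idx_a = [random.randrange(10) for _ in range(3)]
    idx_b = [random.randrange(10) for _ in range(3)]
    return [X[i] for i in idx_a], [y[i] for i in idx_b]

def shared_slices():
    # Fixed analogue: one index draw shared by every slicing term.
    idx = [random.randrange(10) for _ in range(3)]
    return [X[i] for i in idx], [y[i] for i in idx]

xb, yb = shared_slices()
assert all(b == 2 * a for a, b in zip(xb, yb))  # rows stay paired
```

With the cloned version, `X` rows and `y` rows are drawn from unrelated positions, so a regression fit on such minibatches sees scrambled pairs; the shared version preserves the row correspondence the model assumes.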
📚 Documentation preview 📚: https://pymc--7408.org.readthedocs.build/en/7408/