Natgrads #90
Conversation
@thomaspinder it would be good to get your thoughts on the code so far (some of it is rather rough). I believe the implementation works. The main issue is that the notebook needs writing before merging; I'll get on with this. Some tests also need writing, and I'll need to run some benchmarks against …
Codecov Report
@@            Coverage Diff             @@
##           v0.5_update      #90   +/- ##
==============================================
  Coverage            ?   99.22%
==============================================
  Files               ?       14
  Lines               ?     1154
  Branches            ?        0
==============================================
  Hits                ?     1145
  Misses              ?        9
  Partials            ?        0
Flags with carried forward coverage won't be shown.
Started a review but will finish the rest tomorrow.
@thomaspinder I have rebased the branch onto master and would appreciate a review. On my end, I still need to:
Several comments here, many of them suggestions rather than hard requirements. A general comment is that more information is needed in the docstrings - when the docs are built, this is what will be rendered. I'm very happy for the maths within the docstrings to then be moved into a notebook.
@thomaspinder in commits d84ede9, e0b60d4, and bd6e4aa I have addressed many of your comments and suggestions. Outstanding items are writing the unit tests, the notebook, and an explanation of the parameter optimisation order.
Hi @thomaspinder, I believe I have addressed your comments so far. I expect all functions to be covered by unit tests now (but we'll see what CodeCov says). I would appreciate a review of the latest code changes - and I expect we can improve it further. One key thing that I dislike is that the … I am going to start work on writing the notebook today.
A few small comments from me. When these are resolved and we're happy with the notebook, I think it's ready to merge.
gpjax/abstractions.py
Outdated
nat_grads_fn, hyper_grads_fn = natural_gradients(
    stochastic_vi, train_data, transformations
)
These lines seem to be the only ones that require the stochastic_vi object. Can we simply move this outside of the training loop and just make the function accept two objective_fns: one for the hyperparameters and one for the natgrads?
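The suggested refactor might look roughly like the sketch below. This is only an illustration of the idea, not GPJax API: the names `fit`, `nat_grads_fn`, and `hyper_grads_fn` are hypothetical, and the plain SGD-style updates stand in for whatever optimiser the real loop uses.

```python
# Illustrative sketch: build the two gradient functions once, outside the
# loop, and have a generic fit() accept them directly, so the training
# loop never needs the stochastic_vi object itself.

def fit(nat_grads_fn, hyper_grads_fn, params, n_iters=100, lr=0.1):
    """Alternate a (natural) gradient step on the variational parameters
    with an ordinary gradient step on the hyperparameters."""
    for _ in range(n_iters):
        for grads_fn in (nat_grads_fn, hyper_grads_fn):
            grads = grads_fn(params)
            # Each grads dict holds only the parameters that this
            # objective updates; the rest are left untouched.
            params = {k: v - lr * grads.get(k, 0.0) for k, v in params.items()}
    return params
```

The point of the design is separation of concerns: `natural_gradients(...)` is called once to produce the two objective gradient functions, and the loop becomes agnostic to how they were constructed.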
return params

def natural_gradients(
So are we moving this or not?
Force-pushed the branch from bd8ce93 to 5a4cee8.
All notebooks are updated, except the TensorFlow Probability and MCMC sections of the classification notebook.
This reverts commit 7d9ed4d.
Fix negative, and add sketch optimisation loop.
This is rough. THERE IS A BUG SOMEWHERE. This does not use the training abstraction in natural_gradients.py
…onship in the notebook.
This commit updates variational families and their tests.
Force-pushed the branch from 5a4cee8 to 9d92f5a.
This commit resolves the rebase issues; the code still needs refactoring before merging into v0.5_update.
This PR seeks to add natural gradients to GPJax, as well as two new Gaussian variational family parameterisations.
Current state:
The code currently only works for the natural parameterisation. The main outstanding item (aside from the obvious simplification of the rough codebase and improvement of the API) is that we lack natural gradients for general parameterisations.
There are two notebooks associated with this PR:

- `natgrads.ipynb` covers the case where the variational family is chosen as the natural parameterisation - this should work, but I have not tested it since I rebased with the master branch.
- `Natural Gradient General case.ipynb` provides a rough sketch of what the general case might look like; it will involve the user defining a bijection between their parameterisation and the natural parameterisation.

There are some unit tests provided for the new variational families and for `natural_gradients.py` in its current state, as well as for the changes made in `parameters.py`.

As a final note, it is likely that some of the functions defined in `natural_gradients.py` for stopping gradients might be better generalised and added to `abstractions.py` or `parameters.py`.
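As background for the bijection mentioned above, the standard identities relating the moment parameters (μ, S) of a Gaussian to its natural parameters are θ₁ = S⁻¹μ and θ₂ = -½S⁻¹. The sketch below is a hedged illustration of those maps (the names `to_natural` and `from_natural` are hypothetical, not GPJax API); a general parameterisation would compose its user-defined bijection with maps like these.

```python
# Illustrative maps between the moment parameters (mu, S) of a Gaussian
# N(mu, S) and its natural parameters (theta1, theta2). These identities
# underpin the natural-gradient trick; names here are not GPJax API.
import numpy as np

def to_natural(mu, S):
    # theta1 = S^{-1} mu,  theta2 = -0.5 * S^{-1}
    Sinv = np.linalg.inv(S)
    return Sinv @ mu, -0.5 * Sinv

def from_natural(theta1, theta2):
    # Inverse map: recover (mu, S) from the natural parameters.
    S = np.linalg.inv(-2.0 * theta2)
    return S @ theta1, S
```

The two maps are mutually inverse, which is exactly the bijection property the general-case notebook asks the user to supply for their own parameterisation.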