
Natgrads #90

Merged (67 commits merged into v0.5_update from natgrads, Sep 20, 2022)

Conversation

@daniel-dodd (Member) commented Jul 21, 2022

This PR seeks to add natural gradients to GPJax, as well as two new Gaussian variational family parameterisations.

Please check the type of change your PR introduces:

  • [ ] Bugfix
  • [x] Feature
  • [ ] Code style update (formatting, renaming)
  • [x] Refactoring (no functional changes, no API changes)
  • [ ] Build related changes
  • [ ] Documentation content changes
  • [ ] Other (please describe):

Current state:

The code currently works only for the natural parameterisation case. The main outstanding item (aside from the obvious simplification of the rough codebase and improvement of the API) is that we lack natural gradients for general parameterisations.
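
For background, the property that makes the natural parameterisation convenient is a standard identity: for an exponential-family variational distribution, the natural gradient of the ELBO with respect to the natural parameters equals the ordinary gradient with respect to the expectation parameters. A minimal JAX sketch of that identity for a Gaussian follows; the function names and the stand-in objective are illustrative only, not this PR's API.

```python
# Minimal sketch (not GPJax API) of the identity behind natural gradients in
# the natural parameterisation: the natural gradient w.r.t. the natural
# parameters (theta) equals the plain gradient w.r.t. the expectation
# parameters (eta), evaluated at the same distribution.
import jax
import jax.numpy as jnp

def natural_to_expectation(theta1, theta2):
    # For a Gaussian N(m, S): theta1 = S^{-1} m, theta2 = -0.5 S^{-1},
    # while eta1 = m and eta2 = S + m m^T.
    S = jnp.linalg.inv(-2.0 * theta2)
    m = S @ theta1
    return m, S + jnp.outer(m, m)

def objective_in_expectation_params(eta1, eta2):
    # Stand-in for the ELBO written as a function of (eta1, eta2).
    return jnp.sum(eta1**2) + jnp.trace(eta2)

def natural_gradient(theta1, theta2):
    # Convert to expectation parameters and differentiate there; the result
    # is the natural gradient with respect to (theta1, theta2).
    eta1, eta2 = natural_to_expectation(theta1, theta2)
    return jax.grad(objective_in_expectation_params, argnums=(0, 1))(eta1, eta2)
```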

There are two notebooks associated with this PR:

  • natgrads.ipynb covers the case where the variational family is chosen as the natural parameterisation. This should work, but I have not tested it since rebasing onto the master branch.
  • Natural Gradient General case.ipynb gives a rough sketch of what the general case might look like; it will involve the user defining a bijection between their parameterisation and the natural parameterisation (see the sketch after this list).
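
To picture the bijection idea from the second bullet: the user supplies forward/inverse maps between their parameterisation and the natural one, and the update is taken in natural coordinates and pulled back. For brevity the step below uses a plain gradient in natural coordinates rather than a true natural gradient, and none of these names exist in GPJax; this is an illustrative-only sketch.

```python
# Illustrative-only sketch of the user-defined bijection idea: map the user's
# parameters to natural parameters, take a (simplified, plain-gradient) step
# there, and map the result back to the user's parameterisation.
import jax
import jax.numpy as jnp

def natgrad_step_via_bijection(objective, forward, inverse, user_params, lr=0.01):
    theta = forward(user_params)  # user parameterisation -> natural
    grads = jax.grad(lambda t: objective(inverse(t)))(theta)
    return inverse(theta - lr * grads)  # updated natural -> user parameterisation

# Example: a log-scale parameterisation, with exp/log as the bijection.
new_params = natgrad_step_via_bijection(
    objective=lambda p: jnp.sum(p**2),
    forward=jnp.exp,
    inverse=jnp.log,
    user_params=jnp.array([0.1, -0.3]),
)
```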

Unit tests are provided for the new variational families, for natural_gradients.py in its current state, and for the changes made in parameters.py.

As a final note, it is likely that some of the functions defined in natural_gradients.py for stopping gradients would be better generalised and added to abstractions.py or parameters.py.
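
These helpers follow the usual JAX stop-gradient pattern: apply jax.lax.stop_gradient to the pytree leaves that should be held fixed. A minimal sketch, assuming a boolean trainables pytree that mirrors params (this interface is an assumption, not necessarily what the PR implements):

```python
# Sketch of the stop-gradient pattern for freezing a subset of parameters.
import jax
from jax import lax

def stop_gradients(params, trainables):
    """Block gradient flow through parameters whose trainable flag is False."""
    return jax.tree_util.tree_map(
        lambda p, t: p if t else lax.stop_gradient(p), params, trainables
    )

# Usage: gradients w.r.t. the lengthscale are zeroed out.
params = {"lengthscale": 1.0, "variance": 2.0}
trainables = {"lengthscale": False, "variance": True}
loss = lambda p: p["lengthscale"] * p["variance"]
grads = jax.grad(lambda p: loss(stop_gradients(p, trainables)))(params)
# grads["lengthscale"] == 0.0, grads["variance"] == 1.0
```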

@daniel-dodd (Member, Author) commented Aug 19, 2022

@thomaspinder it would be good to get your thoughts on the code so far (some of it is rather rough). I believe the implementation works. The main issue is that the notebook needs writing before merging; I'll get on with this. Some tests also need writing. I'll need to run benchmarks against fit_batches once #99 is merged, to see whether performance is up to scratch.

@codecov (bot) commented Aug 19, 2022

Codecov Report

❗ No coverage uploaded for pull request base (v0.5_update@7773eef).
The diff coverage is n/a.

❗ Current head cc1c318 differs from pull request most recent head 24032c0. Consider uploading reports for the commit 24032c0 to get more accurate results.

@@              Coverage Diff               @@
##             v0.5_update      #90   +/-   ##
==============================================
  Coverage               ?   99.22%           
==============================================
  Files                  ?       14           
  Lines                  ?     1154           
  Branches               ?        0           
==============================================
  Hits                   ?     1145           
  Misses                 ?        9           
  Partials               ?        0           
Flag        Coverage Δ
unittests   99.22% <0.00%> (?)


@thomaspinder (Collaborator) left a comment:

Started a review; will finish the rest tomorrow.

[Two resolved review threads on gpjax/natural_gradients.py]
@daniel-dodd (Member, Author) commented:

@thomaspinder I have rebased the branch onto master and would appreciate a review.

On my end, I need to:

  • Update and add some tests
  • Write the notebook

[Resolved review thread on docs/nbs/natgrads.ipynb]
@thomaspinder (Collaborator) left a comment:

Several comments here, many of them suggestions rather than hard requirements. A general comment is that more information is needed in the docstrings; when the docs are built, this is what will be rendered. I'm very happy for the maths within docstrings to then be moved into a notebook.

[Resolved review threads: gpjax/natural_gradients.py (6), gpjax/parameters.py (2), gpjax/variational_families.py (2)]
@daniel-dodd (Member, Author) commented:

@thomaspinder in commits d84ede9, e0b60d4 and bd6e4aa, I have addressed many of your comments and suggestions.

Outstanding items are to write the unit tests, the notebook, and an explanation of the parameter optimisation order.

@daniel-dodd (Member, Author) commented Aug 24, 2022

Hi @thomaspinder, I believe I have addressed your comments so far. I expect all functions to be covered by unit tests now (but we'll see what Codecov says). I would appreciate a review of the latest code changes, and I expect we can improve them further.

One key thing I dislike is that the fit_natgrads abstraction takes the variational inference strategy as its first argument, while fit and fit_batches take objectives, which I find inconsistent. Do you have any thoughts on this?

I am going to start work on writing the notebook today.

@thomaspinder (Collaborator) left a comment:

A few small comments from me. When these are resolved and we're happy with the notebook, I think it's ready to merge.

Comment on lines 231 to 248

nat_grads_fn, hyper_grads_fn = natural_gradients(
    stochastic_vi, train_data, transformations
)
@thomaspinder (Collaborator) commented:

These lines seem to be the only ones that require the stochastic_vi object. Can we simply move this outside of the training loop and then just make the function accept two objective_fns: one for the hyperparameters and one for the natgrads?
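
For concreteness, a rough sketch of what that refactor might look like; the optax optimisers and the (params, batch) signature of the two gradient functions are assumptions for illustration, not this PR's code.

```python
# Rough sketch (not this PR's code) of the suggested refactor: build the two
# gradient functions once, outside the loop, and pass them to a generic
# fitting routine that alternates the two updates.
import optax

def fit_alternating(nat_grads_fn, hyper_grads_fn, params, moment_opt, hyper_opt, batches):
    moment_state = moment_opt.init(params)
    hyper_state = hyper_opt.init(params)
    for batch in batches:
        # Natural-gradient step on the variational moments.
        nat_grads = nat_grads_fn(params, batch)
        updates, moment_state = moment_opt.update(nat_grads, moment_state)
        params = optax.apply_updates(params, updates)
        # Ordinary gradient step on the hyperparameters.
        hyper_grads = hyper_grads_fn(params, batch)
        updates, hyper_state = hyper_opt.update(hyper_grads, hyper_state)
        params = optax.apply_updates(params, updates)
    return params
```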

[Resolved review thread on gpjax/abstractions.py]
    return params


def natural_gradients(
@thomaspinder (Collaborator) commented:

So are we moving this or not?

[Two resolved review threads on gpjax/parameters.py]
Commit note: this commit finishes the rebase issues; the code still needs refactoring before merging into v0.5_update.
@daniel-dodd changed the base branch from master to v0.5_update on September 20, 2022, 13:51
@thomaspinder marked this pull request as ready for review on September 20, 2022, 20:10
@thomaspinder merged commit 14010fc into v0.5_update on Sep 20, 2022
@daniel-dodd deleted the natgrads branch on October 17, 2022, 12:12
Labels: enhancement (New feature or request)