Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stabilize covariance learning with FillScaleTriL and update config behaviour #163

Merged
merged 7 commits into from
Dec 21, 2022

Conversation

patel-zeel
Copy link
Contributor

@patel-zeel patel-zeel commented Dec 18, 2022

Pull request type

Please check the type of change your PR introduces:

  • Bugfix
  • Feature
  • Code style update (formatting, renaming)
  • Refactoring (no functional changes, no api changes)
  • Build related changes
  • Documentation content changes
  • Other (please describe):

What is the current v/s new behavior?

Issue Number: #127

As discussed in #127, the following is the summary of changes:

  1. Replace tfb.FillTriangular with tfb.FillScaleTriL in config.py
  2. Add the "Custom transformations" section in examples/uncollapsed_vi.pct.py to illustrate the use of custom transformation for lower triangular Cholesky parameters.
  3. Change the way config is created currently. Replace get_defaults() with get_global_config(). Add reset_global_config() and get_default_config() methods.
  4. Add the "Conversion between .ipynb and .py" section in README.md

"Custom transformations" section in examples/uncollapsed_vi.pct.py

I am showing how to use a custom triangular_transform in gpjax (the method also generalizes for any transform).

gpx_config = get_global_config()
transformations = gpx_config.transformations
jitter = gpx_config.jitter

triangular_transform = dx.Chain(
    [tfb.FillScaleTriL(diag_bijector=tfb.Square(), diag_shift=jnp.array(jitter))]
)

transformations.update({"triangular_transform": triangular_transform})

I have added a point that the Square bijector may lead to a faster convergence but can be unstable compared to Softplus bijector.

Softplus Square
image image

Config in GPJAX

How is it done currently?

  1. get_defaults method returns the global config. If the config is unavailable, it creates one.

What is the new behavior?

  1. get_global_config returns the global config. If the config is unavailable, it creates one. If JAX precision changes, it makes appropriate changes to the current global config e.g. update FillScaleTriL
  2. get_default_config returns the default config and does not update any global config.
  3. reset_global_config resets global config to default config.

Conversion between .ipynb and .py

The following quick commands are introduced in README.md to convert between .ipynb and .py.
image

Some changes are attempted in this PR without a detailed prior discussion. Please consider those as suggestions and suggest your thoughts on the required modifications.

What is not done?

I have not ensured yet that get_global_config does not get invoked at import gpjax as gpx (O1 in #127). Execution of

from .abstractions import fit, fit_batches, fit_natgrads
runs natural_gradients.py which invokes get_global_config. Any suggestions on how to solve this?

@codecov-commenter
Copy link

Codecov Report

Merging #163 (8166ae3) into master (c50d34d) will decrease coverage by 0.03%.
The diff coverage is 97.77%.

@@            Coverage Diff             @@
##           master     #163      +/-   ##
==========================================
- Coverage   96.89%   96.85%   -0.04%     
==========================================
  Files          15       15              
  Lines        1385     1400      +15     
==========================================
+ Hits         1342     1356      +14     
- Misses         43       44       +1     
Flag Coverage Δ
unittests 96.85% <97.77%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
gpjax/kernels.py 91.15% <ø> (-0.07%) ⬇️
gpjax/config.py 98.27% <95.83%> (-1.73%) ⬇️
gpjax/gps.py 100.00% <100.00%> (ø)
gpjax/natural_gradients.py 100.00% <100.00%> (ø)
gpjax/parameters.py 95.61% <100.00%> (ø)
gpjax/variational_families.py 100.00% <100.00%> (ø)
gpjax/variational_inference.py 97.77% <100.00%> (ø)

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@daniel-dodd
Copy link
Member

Thanks, @patel-zeel! This PR looks great. Will start my review soon.

What is not done?

I have not ensured yet that get_global_config does not get invoked at import gpjax as gpx (O1 in #127). Execution of

from .abstractions import fit, fit_batches, fit_natgrads

runs natural_gradients.py which invokes get_global_config. Any suggestions on how to solve this?

Is get_global_config invoked in natural_gradients.py solely via DEFAULT_JITTER? We can then move the jitter inside the relevant functions/objects as done in gps.py and variational_families.py, to resolve this.

Copy link
Member

@daniel-dodd daniel-dodd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @patel-zeel, this is an excellent PR. Just a few minor things (please see my comments), and to see if O1 in #127 might be resolved via removing the default jitter and define jitter inside objects/functions akin to gps.py/variational_families.py. It would be nice to get O1 done here if we can! :)

README.md Show resolved Hide resolved
examples/uncollapsed_vi.pct.py Outdated Show resolved Hide resolved
examples/uncollapsed_vi.pct.py Outdated Show resolved Hide resolved
gpjax/config.py Show resolved Hide resolved
gpjax/config.py Outdated Show resolved Hide resolved
@patel-zeel
Copy link
Contributor Author

Thank you for the review, @daniel-dodd. I have resolved O1 in #127 by moving the jitter inside the function. I have also addressed the other comments.

@patel-zeel
Copy link
Contributor Author

patel-zeel commented Dec 19, 2022

@daniel-dodd since we have addressed O1 in #127, we may need to ensure it does not get violated in the future. I think it is hard to detect such things via unit testing (due to the shared global scope among all tests). However, I have added a relevant test (test_config_on_library_import) in test_config which calls get_global_config_if_exists function to check if the global config exists. This test seems to be passing only because of two reasons: i) test_config_on_library_import test is defined on top and thus it is executed first within test_config.py ii) All tests running before test_config.py do not invoke the creation of global config. What are your thoughts on this?

@daniel-dodd
Copy link
Member

Hi @patel-zeel, thank you for addressing my comments.

@daniel-dodd since we have addressed O1 in #127, we may need to ensure it does not get violated in the future.

In agreement with this. However, the latest test workflow failed on test_config_on_library_import (https://github.com/JaxGaussianProcesses/GPJax/actions/runs/3729043958/jobs/6330357412)? Happy to merge this PR with this test removed for now, and open a separate issue regarding checks for O1 violation?

@patel-zeel
Copy link
Contributor Author

Yes, that looks like a way to go :)

@daniel-dodd
Copy link
Member

@patel-zeel Nice! Running the test workflow now, will merge soon as the tests pass. :)

@daniel-dodd daniel-dodd merged commit 5cf7143 into JaxGaussianProcesses:master Dec 21, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants