v0.9.0
Key Changes
In v0.9.0, the backend of GPJax has migrated to the NNX module of Flax (h/t @cgarciae). This allows us to assign structure to our parameters, hook into Flax's ecosystem, and simplify our underlying code. From a frontend perspective, you may now assign a type to your parameters, e.g., PositiveReal, and easily access the state of your GP.
State
For any component of GPJax, such as a kernel, mean function, prior or posterior GP, you may now obtain the state of the component using nnx.split. For example, the state of a Matérn kernel is obtained by
```python
from flax import nnx
import gpjax as gpx

kernel = gpx.kernels.Matern32()
_, params = nnx.split(kernel, gpx.parameters.Parameter)
```
This gives users low-level control over the exact operations that are applied to the parameters. We detail this fully in the Parameters section of our Model Guide notebook.
Parameters
Parameters now encode their support. For example, strictly positive parameters such as the lengthscale or variance are instantiated via the PositiveReal parameter, while parameters constrained to be lower-triangular matrices are instantiated through the LowerTriangular parameter.
What's Changed
- Point CoC breaches to our contact form. by @thomaspinder in #451
- Quick fix to stop automatic switch to CG by @henrymoss in #454
- Fix some type annotations causing failing tests by @stephen-huan in #456
- Feature: Adds probability of improvement as an acquisition function by @miguelgondu in #458
- Add expected improvement utility function by @Thomas-Christie in #460
- feat(gpjax/kernels/base.py): add diagonal by @stephen-huan in #429
- Flax/nnx backend by @frazane in #440
New Contributors
- @stephen-huan made their first contribution in #456
- @miguelgondu made their first contribution in #458
Full Changelog: v0.8.2...v0.9.0