
v0.9.0

@thomaspinder released this 16 Aug 19:53 (commit 9ba68a4)

Key Changes

In v0.9.0, the backend of GPJax has migrated to the NNX module of Flax (h/t @cgarciae). This lets us assign structure to our parameters, hook into Flax's ecosystem, and simplify our underlying code. From a frontend perspective, you can now assign a type to your parameters (e.g., PositiveReal) and easily access the state of your GP.

State

For any component of GPJax, such as a kernel, mean function, prior or posterior GP, you can now realise the state of the component using nnx.split. For example, the state of a Matérn-3/2 kernel is realised by

```python
import gpjax as gpx
from flax import nnx
kernel = gpx.kernels.Matern32()
_, params = nnx.split(kernel, gpx.parameters.Parameter)  # keep only Parameter state
```

This gives users low-level control over exactly which operations are applied to the parameters. We detail this fully in the Parameters section of our Model Guide notebook.

Parameters

We now recognise the support of a parameter, i.e., the set of values it is allowed to take. For example, strictly positive parameters such as the lengthscale or variance are instantiated via the PositiveReal parameter, while parameters that are constrained to be lower-triangular matrices are instantiated through the LowerTriangular parameter.
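To illustrate what knowing a parameter's support makes possible, here is a minimal sketch in plain JAX, not GPJax's implementation. A positive parameter is moved to unconstrained space and back through a softplus map (assumed here as a common choice of bijection), and a lower-triangular value is validated when it is created. The helpers `softplus`, `inverse_softplus`, `check_positive` and `check_lower_triangular` are hypothetical names introduced only for this example.

```python
import jax.numpy as jnp


def softplus(x):
    """Map any real number to a strictly positive one: log(1 + exp(x))."""
    return jnp.logaddexp(x, 0.0)


def inverse_softplus(y):
    """Inverse of softplus for y > 0: log(exp(y) - 1), written stably."""
    return y + jnp.log(-jnp.expm1(-y))


def check_positive(value):
    """Hypothetical validator for a strictly positive parameter."""
    value = jnp.asarray(value)
    if not bool(jnp.all(value > 0)):
        raise ValueError("expected a strictly positive value")
    return value


def check_lower_triangular(value):
    """Hypothetical validator for a square lower-triangular matrix."""
    value = jnp.asarray(value)
    if value.ndim != 2 or value.shape[0] != value.shape[1]:
        raise ValueError("expected a square matrix")
    if not bool(jnp.array_equal(value, jnp.tril(value))):
        raise ValueError("expected a lower-triangular matrix")
    return value


# A lengthscale lives on (0, inf); an optimiser can work on the whole real line.
lengthscale = check_positive(1.5)
unconstrained = inverse_softplus(lengthscale)

# Any update in unconstrained space still maps back to a valid lengthscale.
updated = unconstrained - 10.0
print(softplus(updated))

# A Cholesky-style factor must be lower triangular.
chol = check_lower_triangular([[1.0, 0.0], [0.5, 2.0]])
```

Working in unconstrained space means a gradient step can never push the lengthscale to zero or below, which is the practical payoff of tagging a parameter with its support.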

What's Changed

New Contributors

Full Changelog: v0.8.2...v0.9.0