Backend doc #462

Merged
merged 9 commits into from
Aug 15, 2024
333 changes: 333 additions & 0 deletions docs/examples/backend.py
@@ -0,0 +1,333 @@
# -*- coding: utf-8 -*-
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# custom_cell_magics: kql
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: gpjax
# language: python
# name: python3
# ---

# %% [markdown]
# # Backend Module Design
#
# Since v0.9, GPJax is built upon Flax's
# [NNX](https://flax.readthedocs.io/en/latest/nnx/index.html) module. This transition
# allows for more efficient parameter handling, improved integration with Flax and
# Flax-based libraries, and enhanced flexibility in model design. This notebook provides
# a high-level overview of the backend module design in GPJax. For an introduction to
# NNX, please refer to the [official
# documentation](https://flax.readthedocs.io/en/latest/nnx/index.html).
#

# %%
# Enable Float64 for more stable matrix inversions.
from jax import config, grad

config.update("jax_enable_x64", True)

import jax.numpy as jnp
import matplotlib as mpl
import matplotlib.pyplot as plt
import gpjax as gpx

plt.style.use(
    "https://github.com/raw/JaxGaussianProcesses/GPJax/main/docs/examples/gpjax.mplstyle"
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

# %% [markdown]
# ## Parameters
#
# The biggest change brought about by the transition to an NNX backend is the increased
# support we now provide for handling parameters. As discussed in our [Sharp Bits -
# Bijectors Doc](https://docs.jaxgaussianprocesses.com/sharp_bits/#bijectors), GPJax
# uses bijectors to transform constrained parameters to unconstrained parameters during
# optimisation. You may now register the support of a parameter using our `Parameter`
# class. To see this, consider the constant mean function, which contains a single
# constant parameter whose value ordinarily exists on the real line. We can register
# this parameter as follows:

# %%
from gpjax.mean_functions import Constant
from gpjax.parameters import Real

constant_param = Real(value=1.0)
meanf = Constant(constant_param)
meanf

# %% [markdown]
# However, suppose you wish your mean function's constant parameter to be strictly
# positive. This is easy to achieve by using the correct Parameter type.

# %%
from gpjax.parameters import PositiveReal

constant_param = PositiveReal(value=1.0)
meanf = Constant(constant_param)
meanf

# %% [markdown]
# Were we to try and instantiate the `PositiveReal` class with a negative value, then an
# explicit error would be raised.

# %%
try:
    PositiveReal(value=-1.0)
except ValueError as e:
    print(e)

# %% [markdown]
# ### Parameter Transforms
#
# With a parameter instantiated, you likely wish to transform the parameter's value from
# its constrained support onto the entire real line. To do this, you can apply the
# `transform` function to the parameter. To control the bijector used to transform the
# parameter, you may pass a set of bijectors into the transform function.
# Under the hood, the `transform` function looks up the bijector of a parameter
# using its `_tag` field in the bijector dictionary, and then applies the bijector to
# the parameter's value using a tree map operation.

# %%
print(constant_param._tag)
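
# %% [markdown]
# To make the lookup described above concrete, here is a minimal sketch of tag-based
# dispatch written with the standard library only. It is not GPJax's implementation:
# `ToyParameter`, `TOY_BIJECTORS` and `toy_transform` are hypothetical names, and a
# Softplus pair is assumed for the positive case.

# %%
import math
from dataclasses import dataclass


@dataclass
class ToyParameter:
    value: float
    tag: str


# Each tag maps to a (forward, inverse) pair. Forward constrains, inverse unconstrains.
TOY_BIJECTORS = {
    "real": (lambda x: x, lambda y: y),
    "positive": (
        lambda x: math.log1p(math.exp(x)),
        lambda y: math.log(math.expm1(y)),
    ),
}


def toy_transform(param: ToyParameter, bijectors: dict, inverse: bool = False):
    # Look up the bijector pair from the parameter's tag, then apply one direction.
    forward, backward = bijectors[param.tag]
    fn = backward if inverse else forward
    return ToyParameter(value=fn(param.value), tag=param.tag)


toy_unconstrained = toy_transform(
    ToyParameter(1.0, "positive"), TOY_BIJECTORS, inverse=True
)
toy_roundtrip = toy_transform(toy_unconstrained, TOY_BIJECTORS)
print(toy_unconstrained.value, toy_roundtrip.value)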

# %% [markdown]
# Most users will not need to worry about this, as we provide a set of default
# bijectors that are defined for all the parameter types we support. However, our
# [Kernel Guide
# Notebook](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/)
# shows how you can define your own bijectors and parameter types.

# %%
from gpjax.parameters import DEFAULT_BIJECTION, transform

print(DEFAULT_BIJECTION[constant_param._tag])

# %% [markdown]
# We see here that the Softplus bijector is specified as the default for strictly
# positive parameters. To apply this, we may invoke the following:

# %%
transform(constant_param, DEFAULT_BIJECTION, inverse=True)

# %% [markdown]
# The parameter's value was changed here from 1. to 0.54132485. This is the result of
# applying the inverse of the Softplus bijector to the parameter's value, which projects
# the value onto the real line. Were the parameter's value to be closer to 0, then the
# transformation would be more pronounced.

# %%
transform(PositiveReal(value=1e-6), DEFAULT_BIJECTION, inverse=True)
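
# %% [markdown]
# The effect near zero can be checked by hand. The sketch below assumes the default
# bijector is the standard Softplus, $\log(1 + e^x)$, and evaluates its inverse,
# $\log(e^y - 1)$, directly with JAX rather than through GPJax.

# %%
import jax
import jax.numpy as jnp

constrained_values = jnp.array([1.0, 1e-2, 1e-6])

# Inverse Softplus: maps positive values onto the real line.
unconstrained_values = jnp.log(jnp.expm1(constrained_values))
print(unconstrained_values)  # the first entry is approximately 0.5413

# Applying Softplus again recovers the constrained values.
print(jax.nn.softplus(unconstrained_values))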

# %% [markdown]
# ### Transforming Multiple Parameters
#
# In the above, we transformed a single parameter. However, in practice your parameters
# may be nested within several functions e.g., a kernel function within a GP model.
# Fortunately, transforming several parameters is a simple operation that we here
# demonstrate for a conjugate GP posterior (see our [Regression
# Notebook](https://docs.jaxgaussianprocesses.com/examples/regression/) for a detailed
# explanation of this model).

# %%
kernel = gpx.kernels.Matern32()
meanf = gpx.mean_functions.Constant()

prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

likelihood = gpx.likelihoods.Gaussian(100)
posterior = likelihood * prior
posterior

# %% [markdown]
# The posterior PyGraph now contains four parameters: the
# kernel's lengthscale and variance, the noise variance of the likelihood, and the
# constant of the mean function. Using NNX, we may realise these parameters through the
# `nnx.split` function. The `split` function decomposes a PyGraph into a `GraphDef` and
# `State` object. As the name suggests, `State` contains information on the parameters'
# state, whilst `GraphDef` contains the information required to reconstruct a PyGraph
# from a given `State`.

# %%
from flax import nnx

graphdef, state = nnx.split(posterior)
state

# %% [markdown]
# The `State` object behaves just like a PyTree and, consequently, we may use JAX's
# `tree_map` function to alter the values of the `State`. The updated `State` can then
# be used to reconstruct our posterior. In the below, we simply increment each
# parameter's value by 1.

# %%
import jax.tree_util as jtu

updated_state = jtu.tree_map(lambda x: x + 1, state)
updated_state

# %% [markdown]
# Let us now use NNX's `merge` function to reconstruct the posterior distribution using
# the updated state.

# %%
updated_posterior = nnx.merge(graphdef, updated_state)
updated_posterior
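
# %% [markdown]
# The split, map and merge pattern has a close analogue in plain JAX, which may help
# build intuition. In the sketch below a nested dictionary stands in for the model:
# `tree_flatten` separates the values from the structure and `tree_unflatten` rebuilds
# it. The dictionary keys and values are made up for illustration, and the analogy with
# `GraphDef` and `State` is only a loose one.

# %%
import jax.numpy as jnp
import jax.tree_util as jtu

toy_model = {
    "kernel": {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
    "likelihood": {"obs_stddev": jnp.array(1.0)},
}

# Separate the leaf values from the description of the container.
toy_leaves, toy_treedef = jtu.tree_flatten(toy_model)

# Update the values without touching the structure.
toy_updated_leaves = [leaf + 1 for leaf in toy_leaves]

# Rebuild a container with the same structure from the updated values.
toy_rebuilt = jtu.tree_unflatten(toy_treedef, toy_updated_leaves)
print(toy_rebuilt)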

# %% [markdown]
# However, we began this discussion with bijectors in mind, so let us now see
# how bijectors may be applied to a collection of parameters in GPJax. Fortunately, this
# is very straightforward, and we may simply use the `transform` function as before.

# %%
transformed_state = transform(state, DEFAULT_BIJECTION, inverse=True)
transformed_state

# %% [markdown]
# We may also (re-)constrain the parameters' values by setting the `inverse` argument of
# `transform` to False.

# %%
retransformed_state = transform(transformed_state, DEFAULT_BIJECTION, inverse=False)
# The round trip should recover the constrained values held in the original `state`.
retransformed_state

# %% [markdown]
# ### Fine-Scale Control
#
# One of the advantages of being able to split and re-merge the PyGraph is that we are
# able to gain fine-scale control over the parameters whose state we wish to realise.
# This is by virtue of the fact that each of our parameters now inherits from
# `gpjax.parameters.Parameter`. In the above, we were simply extracting any `Parameter`
# from the posterior. However, suppose we only wish to extract those parameters whose
# support is the positive real line. This is easily achieved by altering the way in
# which we invoke `nnx.split`.

# %%
from gpjax.parameters import PositiveReal

graphdef, positive_reals, other_params = nnx.split(posterior, PositiveReal, ...)
print(positive_reals)

# %% [markdown]
# Now we see that we have two state objects: one containing the positive real parameters
# and the other containing the remaining parameters. This functionality is exceptionally
# useful as it allows us to efficiently operate on a subset of the parameters whilst
# leaving the others untouched. Looking forward, we hope to use this functionality in
# our [Variational Inference
# Approximations](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/) to
# perform more efficient updates of the variational parameters and then the model's
# hyperparameters.
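
# %% [markdown]
# As an illustration of why partitioning is useful, the sketch below takes a gradient
# step on one group of values whilst holding another group fixed. It uses plain
# dictionaries and a made-up quadratic loss rather than GPJax objects, so all names
# here are hypothetical.

# %%
import jax
import jax.numpy as jnp


def toy_loss(trainable: dict, frozen: dict):
    return (trainable["lengthscale"] - 2.0) ** 2 + (frozen["constant"] - 1.0) ** 2


toy_trainable = {"lengthscale": jnp.array(1.0)}
toy_frozen = {"constant": jnp.array(0.0)}

# Differentiate with respect to the first argument only.
toy_grads = jax.grad(toy_loss, argnums=0)(toy_trainable, toy_frozen)

# One gradient descent step on the trainable group; the frozen group is untouched.
toy_trainable = jax.tree_util.tree_map(
    lambda p, g: p - 0.1 * g, toy_trainable, toy_grads
)
print(toy_trainable, toy_frozen)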

# %% [markdown]
# ## NNX Modules
#
# To conclude this notebook, we will now demonstrate the ease of use and flexibility
# offered by NNX modules. To do this, we will implement a linear mean function using the
# existing abstractions in GPJax.
#
# For inputs $x_n \in \mathbb{R}^d$, the linear mean function $m(x): \mathbb{R}^d \to
# \mathbb{R}$ is defined as:
# $$
# m(x) = \alpha + \sum_{i=1}^d \beta_i x_i
# $$
# where $\alpha \in \mathbb{R}$ and $\beta_i \in \mathbb{R}$ are the parameters of the
# mean function. Let's now implement that using the new NNX backend.

# %%
import typing as tp

from jaxtyping import Float, Num

from gpjax.mean_functions import AbstractMeanFunction
from gpjax.parameters import Parameter, Real
from gpjax.typing import ScalarFloat, Array


class LinearMeanFunction(AbstractMeanFunction):
    def __init__(
        self,
        intercept: tp.Union[ScalarFloat, Float[Array, " O"], Parameter] = 0.0,
        slope: tp.Union[ScalarFloat, Float[Array, " D O"], Parameter] = 0.0,
    ):
        if isinstance(intercept, Parameter):
            self.intercept = intercept
        else:
            self.intercept = Real(jnp.array(intercept))

        if isinstance(slope, Parameter):
            self.slope = slope
        else:
            self.slope = Real(jnp.array(slope))

    def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N O"]:
        return self.intercept.value + jnp.dot(x, self.slope.value)


# %% [markdown]
# As we can see, the implementation is straightforward and concise. The
# `AbstractMeanFunction` module is a subclass of `nnx.Module` and may, therefore, be
# used in any `split` or `merge` call. Further, we have registered the intercept and
# slope parameters as `Real` parameter types. This registers their value in the PyGraph
# and means that they will be part of any operation applied to the PyGraph e.g.,
# transforming and differentiation.
#
# To check our implementation worked, let's now plot the value of our mean function for
# a linearly spaced set of inputs.

# %%
N = 100
X = jnp.linspace(-5.0, 5.0, N)[:, None]

meanf = LinearMeanFunction(intercept=1.0, slope=2.0)
plt.plot(X, meanf(X))
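
# %% [markdown]
# The example above uses one-dimensional inputs. As a further check of the formula
# $m(x) = \alpha + \sum_{i=1}^d \beta_i x_i$ for $d = 2$, the sketch below evaluates it
# with plain `jax.numpy`, independently of the class above. The numbers are arbitrary.

# %%
import jax.numpy as jnp

toy_alpha = 1.0
toy_beta = jnp.array([[2.0], [-1.0]])  # shape (D, 1)
toy_inputs = jnp.array([[0.0, 0.0], [1.0, 2.0], [3.0, 1.0]])  # shape (N, D)

# Matrix form of the sum over input dimensions.
toy_mean = toy_alpha + jnp.dot(toy_inputs, toy_beta)  # shape (N, 1)
print(toy_mean.ravel())  # [1. 1. 6.]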

# %% [markdown]
# Looks good! To conclude this section, let's now parameterise a GP with our new mean
# function and see how gradients may be computed.

# %%
y = jnp.sin(X)
D = gpx.Dataset(X, y)

prior = gpx.gps.Prior(mean_function=meanf, kernel=gpx.kernels.Matern32())
likelihood = gpx.likelihoods.Gaussian(D.n)
posterior = likelihood * prior

# %% [markdown]
# We'll compute derivatives of the conjugate marginal log-likelihood, with respect to
# the unconstrained state of the kernel, mean function, and likelihood parameters.

# %%
graphdef, params, others = nnx.split(posterior, Parameter, ...)
# Map the constrained parameters onto the real line before differentiating.
params = transform(params, DEFAULT_BIJECTION, inverse=True)


def loss_fn(params: nnx.State, data: gpx.Dataset) -> ScalarFloat:
    params = transform(params, DEFAULT_BIJECTION)
    model = nnx.merge(graphdef, params, *others)
    return -gpx.objectives.conjugate_mll(model, data)


grad(loss_fn)(params, D)
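
# %% [markdown]
# Differentiating with respect to unconstrained values while constraining inside the
# loss can be seen in isolation on a toy problem. The sketch below assumes a Softplus
# bijector and a made-up quadratic objective whose minimiser is a variance of 0.5. It
# is not GPJax code.

# %%
import jax
import jax.numpy as jnp


def toy_objective(unconstrained_variance):
    # Constrain inside the loss so the optimiser only ever sees unconstrained values.
    variance = jax.nn.softplus(unconstrained_variance)
    return (variance - 0.5) ** 2


# Unconstrain the initial variance of 1.0 with the inverse Softplus.
toy_u = jnp.log(jnp.expm1(jnp.array(1.0)))

toy_grad_fn = jax.jit(jax.grad(toy_objective))
for _ in range(100):
    toy_u = toy_u - 0.5 * toy_grad_fn(toy_u)

toy_variance = jax.nn.softplus(toy_u)
print(toy_variance)  # positive, and close to 0.5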

# %% [markdown]
# ## Conclusions
#
# In this notebook we have explored how GPJax's Flax-based backend may be easily
# manipulated and extended. For a more applied look at this, see how we construct a
# kernel on polar coordinates in our [Kernel
# Guide](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
# notebook.
#
# ## System configuration

# %%
# %reload_ext watermark
# %watermark -n -u -v -iv -w -a 'Thomas Pinder'
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -33,6 +33,7 @@ nav:
- 📖 Guides for customisation:
- Kernels: examples/constructing_new_kernels.md
- Likelihoods: examples/likelihoods_guide.md
- Model Guide: examples/backend.md
- UCI regression: examples/yacht.md
- 💻 Raw tutorial code: give_me_the_code.md
- Community: