Merge pull request #148 from JaxGaussianProcesses/More_Kernels
This PR adds additional kernels and introduces the notion of a "compute_engine" for performing kernel operations, which will in future provide the foundation for alternative matrix-solving algorithms, e.g., conjugate gradients.
daniel-dodd authored Dec 13, 2022
2 parents ebd6cb7 + e5c339a commit c50d34d
Showing 19 changed files with 897 additions and 466 deletions.
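
The "compute_engine" introduced by this PR is, conceptually, a strategy object to which a kernel delegates its Gram and cross-covariance computations, so that the dense, Cholesky-based linear algebra used today can later be swapped for alternatives such as conjugate gradients. The following is a minimal, hypothetical sketch of that pattern in plain JAX; the names DenseComputeEngine and rbf are invented for illustration and are not the GPJax API.

import jax.numpy as jnp
from jax import vmap
from typing import Callable, Dict

def rbf(params: Dict, x, y):
    # Toy RBF kernel evaluated on a single pair of input vectors.
    tau = jnp.sum(((x - y) / params["lengthscale"]) ** 2)
    return params["variance"] * jnp.exp(-0.5 * tau)

class DenseComputeEngine:
    # Dense kernel computations; a conjugate-gradient engine could expose the
    # same two methods and be swapped in without touching the kernel itself.
    @staticmethod
    def cross_covariance(kernel_fn: Callable, params: Dict, x, y):
        # Pairwise evaluation via a double vmap -> an (n, m) covariance matrix.
        return vmap(lambda xi: vmap(lambda yj: kernel_fn(params, xi, yj))(y))(x)

    @classmethod
    def gram(cls, kernel_fn: Callable, params: Dict, x):
        return cls.cross_covariance(kernel_fn, params, x, x)

params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}
x = jnp.linspace(-3.0, 3.0, 5).reshape(-1, 1)
Kxx = DenseComputeEngine.gram(rbf, params, x)  # (5, 5) Gram matrix
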
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
@@ -19,4 +19,12 @@ repos:
- id: nbqa-pyupgrade
args: [--py37-plus]
- id: nbqa-flake8
args: ['--ignore=E501,E203,E302,E402,E731,W503']
args: ['--ignore=E501,E203,E302,E402,E731,W503']
- repo: https://github.com/PyCQA/autoflake
rev: v2.0.0
hooks:
- id: autoflake
args: ["--in-place", "--remove-unused-variables", "--remove-all-unused-imports", "--recursive"]
name: AutoFlake
description: "Format with AutoFlake"
stages: [commit]
48 changes: 24 additions & 24 deletions examples/classification.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: base
# display_name: Python 3.9.7 ('gpjax')
# language: python
# name: python3
# ---
@@ -19,7 +19,7 @@
#
# In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling.

# %% vscode={"languageId": "python"}
# %%
import blackjax
import distrax as dx
import jax
@@ -47,7 +47,7 @@
#
# We store our data $\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later.

# %% vscode={"languageId": "python"}
# %%
x = jnp.sort(jr.uniform(key, shape=(100, 1), minval=-1.0, maxval=1.0), axis=0)
y = 0.5 * jnp.sign(jnp.cos(3 * x + jr.normal(key, shape=x.shape) * 0.05)) + 0.5

@@ -61,15 +61,15 @@
#
# We begin by defining a Gaussian process prior with a radial basis function (RBF) kernel, chosen for the purpose of exposition. Since our observations are binary, we choose a Bernoulli likelihood with a probit link function.

# %% vscode={"languageId": "python"}
# %%
kernel = gpx.RBF()
prior = gpx.Prior(kernel=kernel)
likelihood = gpx.Bernoulli(num_datapoints=D.n)

# %% [markdown]
# We construct the posterior through the product of our prior and likelihood.

# %% vscode={"languageId": "python"}
# %%
posterior = prior * likelihood
print(type(posterior))

@@ -79,7 +79,7 @@
# %% [markdown]
# To begin we obtain an initial parameter state through the `initialise` callable (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers.

# %% vscode={"languageId": "python"}
# %%
parameter_state = gpx.initialise(posterior)
negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))

@@ -97,7 +97,7 @@
# %% [markdown]
# From which we can make predictions at novel inputs, as illustrated below.

# %% vscode={"languageId": "python"}
# %%
map_latent_dist = posterior(map_estimate, D)(xtest)

predictive_dist = likelihood(map_estimate, map_latent_dist)
@@ -158,15 +158,15 @@
#
# that we identify as a Gaussian distribution, $p(\boldsymbol{f}| \mathcal{D}) \approx q(\boldsymbol{f}) := \mathcal{N}(\hat{\boldsymbol{f}}, [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below.

# %% vscode={"languageId": "python"}
# %%
gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
jitter = 1e-6

# Compute (latent) function value map estimates at training points:
Kxx = gram(kernel, map_estimate["kernel"], x)
Kxx = gram(map_estimate["kernel"], x)
Kxx += I(D.n) * jitter
Lx = Kxx.triangular_lower()
f_hat = jnp.matmul(Lx, map_estimate["latent"])
Lx = Kxx.to_root()
f_hat = Lx @ map_estimate["latent"]

# Negative Hessian, H = -∇²p_tilde(y|f):
H = jax.jacfwd(jax.jacrev(negative_mll))(map_estimate)["latent"]["latent"][:, 0, :, 0]
@@ -190,21 +190,21 @@
#
# This is the same approximate distribution $q_{map}(f(\cdot))$, but we have perturbed the covariance by the curvature term $\mathbf{K}_{\boldsymbol{(\cdot)\boldsymbol{x}}} \mathbf{K}_{\boldsymbol{xx}}^{-1} [-\nabla^2 \tilde{p}(\boldsymbol{y}|\boldsymbol{f})|_{\hat{\boldsymbol{f}}} ]^{-1} \mathbf{K}_{\boldsymbol{xx}}^{-1} \mathbf{K}_{\boldsymbol{\boldsymbol{x}(\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\cdot))$.

# %% vscode={"languageId": "python"}
# %%
def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri:

map_latent_dist = posterior(map_estimate, D)(test_inputs)

Kxt = cross_covariance(kernel, map_estimate["kernel"], x, test_inputs)
Kxx = gram(kernel, map_estimate["kernel"], x)
Kxt = cross_covariance(map_estimate["kernel"], x, test_inputs)
Kxx = gram(map_estimate["kernel"], x)
Kxx += I(D.n) * jitter
Lx = Kxx.triangular_lower()
Lx = Kxx.to_root()

# Lx⁻¹ Kxt
Lx_inv_Ktx = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)
Lx_inv_Ktx = Lx.solve(Kxt)

# Kxx⁻¹ Kxt
Kxx_inv_Ktx = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Ktx, lower=False)
Kxx_inv_Ktx = Lx.T.solve(Lx_inv_Ktx)

# Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)
@@ -217,7 +217,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal

# %% [markdown]
# From this we can construct the predictive distribution at the test points.
# %% vscode={"languageId": "python"}
# %%
laplace_latent_dist = construct_laplace(xtest)
predictive_dist = likelihood(map_estimate, laplace_latent_dist)

@@ -267,7 +267,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormal
#
# We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary.

# %% vscode={"languageId": "python"}
# %%
# Adapted from BlackJax's introduction notebook.
num_adapt = 500
num_samples = 500
@@ -304,14 +304,14 @@ def one_step(state, rng_key):
#
# BlackJax gives us easy access to our sampler's efficiency through metrics such as the sampler's _acceptance probability_ (the number of times that our chain accepted a proposed sample, divided by the total number of steps run by the chain). For NUTS and Hamiltonian Monte Carlo sampling, we typically seek an acceptance rate of 60-70% to strike the right balance between having a chain which is _stuck_ and rarely moves versus a chain that is too jumpy with frequent small steps.

# %% vscode={"languageId": "python"}
# %%
acceptance_rate = jnp.mean(infos.acceptance_probability)
print(f"Acceptance rate: {acceptance_rate:.2f}")

# %% [markdown]
# Our acceptance rate is slightly too large, prompting an examination of the chain's trace plots. A well-mixing chain will have very few (if any) flat spots in its trace plot whilst also not having too many steps in the same direction. In addition to the model's hyperparameters, there will be 500 samples for each of the 100 latent function values in the `states.position` dictionary. We depict the chains that correspond to the model hyperparameters and the first value of the latent function for brevity.

# %% vscode={"languageId": "python"}
# %%
fig, (ax0, ax1, ax2) = plt.subplots(ncols=3, figsize=(15, 5), tight_layout=True)
ax0.plot(states.position["kernel"]["lengthscale"])
ax1.plot(states.position["kernel"]["variance"])
@@ -327,7 +327,7 @@ def one_step(state, rng_key):
#
# An ideal Markov chain would have samples completely uncorrelated with their neighbours after a single lag. However, in practice, correlations often exist within our chain's sample set. A commonly used technique to try and reduce this correlation is _thinning_ whereby we select every $n$th sample where $n$ is the minimum lag length at which we believe the samples are uncorrelated. Although further analysis of the chain's autocorrelation is required to find appropriate thinning factors, we employ a thin factor of 10 for demonstration purposes.

# %% vscode={"languageId": "python"}
# %%
thin_factor = 10
samples = []

@@ -351,7 +351,7 @@ def one_step(state, rng_key):
#
# Finally, we end this tutorial by plotting the predictions obtained from our model against the observed data.

# %% vscode={"languageId": "python"}
# %%
fig, ax = plt.subplots(figsize=(16, 5), tight_layout=True)
ax.plot(
x, y, "o", markersize=5, color="tab:red", label="Observations", zorder=2, alpha=0.7
@@ -371,6 +371,6 @@ def one_step(state, rng_key):
# %% [markdown]
# ## System configuration

# %% vscode={"languageId": "python"}
# %%
# %load_ext watermark
# %watermark -n -u -v -iv -w -a "Thomas Pinder & Daniel Dodd"
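
For readers following the classification diff above: the to_root/solve calls that replace the explicit triangular solves compute the same quantities. Below is a standalone sketch of the Laplace covariance correction written with plain jax.scipy triangular solves and toy stand-ins for Kxx, Kxt and the negative Hessian; it illustrates the linear algebra only and is not the notebook's exact code.

import jax.numpy as jnp
import jax.scipy as jsp

jitter = 1e-6
xtrain = jnp.linspace(0.0, 1.0, 4).reshape(-1, 1)   # toy training inputs
xtest = jnp.linspace(0.1, 0.9, 3).reshape(-1, 1)    # toy test inputs

def k(a, b):
    # Toy RBF entries between two column vectors of inputs.
    return jnp.exp(-0.5 * (a - b.T) ** 2)

Kxx = k(xtrain, xtrain) + jitter * jnp.eye(xtrain.shape[0])
Kxt = k(xtrain, xtest)
H_inv = jnp.eye(xtrain.shape[0])                    # stand-in for [-∇²p̃(y|f)]⁻¹

Lx = jnp.linalg.cholesky(Kxx)                       # lower-triangular root of Kxx

# Lx⁻¹ Kxt, then Kxx⁻¹ Kxt = Lx⁻ᵀ (Lx⁻¹ Kxt)
Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)
Kxx_inv_Kxt = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Kxt, lower=False)

# Curvature term  Ktx Kxx⁻¹ H⁻¹ Kxx⁻¹ Kxt  added to the MAP latent covariance.
laplace_cov_term = Kxx_inv_Kxt.T @ H_inv @ Kxx_inv_Kxt   # (3, 3)
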
10 changes: 8 additions & 2 deletions examples/graph_kernels.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: base
# display_name: Python 3.9.7 ('gpjax')
# language: python
# name: python3
# ---
@@ -55,7 +55,7 @@

pos = nx.spring_layout(G, seed=123) # positions for all nodes

nx.draw(G, pos, node_color="tab:blue", with_labels=False, alpha=0.5)
nx.draw(G) # , pos, node_color="tab:blue", with_labels=False, alpha=0.5)

# %% [markdown]
#
@@ -95,6 +95,12 @@

D = gpx.Dataset(X=x, y=y)

# %%
kernel.compute_engine.gram

# %%
kernel.gram(params=kernel._initialise_params(key), inputs=x)

# %% [markdown]
#
# We can visualise this signal in the following cell.
32 changes: 21 additions & 11 deletions examples/haiku.pct.py
@@ -9,7 +9,7 @@
# format_version: '1.3'
# jupytext_version: 1.11.2
# kernelspec:
# display_name: base
# display_name: Python 3.9.7 ('gpjax')
# language: python
# name: python3
# ---
@@ -28,15 +28,18 @@
import jax.random as jr
import matplotlib.pyplot as plt
import optax as ox
from chex import dataclass
from jax.config import config
from scipy.signal import sawtooth
from jaxtyping import Float, Array
from typing import Dict


import gpjax as gpx
from gpjax.kernels import DenseKernelComputation, AbstractKernel
from gpjax.kernels import (
DenseKernelComputation,
AbstractKernelComputation,
AbstractKernel,
)
from gpjax.types import PRNGKeyType

# Enable Float64 for more stable matrix inversions.
@@ -79,16 +79,23 @@
# Although deep kernels are not currently supported natively in GPJax, defining one is straightforward as we now demonstrate. Using the base `AbstractKernel` object given in GPJax, we provide a mixin class named `_DeepKernelFunction` to facilitate the user supplying the neural network and base kernel of their choice. Kernel matrices are then computed using the regular `gram` and `cross_covariance` functions.

# %%
@dataclass
class _DeepKernelFunction:
network: hk.Module
base_kernel: AbstractKernel

class DeepKernelFunction(AbstractKernel):
def __init__(
self,
network: hk.Module,
base_kernel: AbstractKernel,
compute_engine: AbstractKernelComputation = DenseKernelComputation,
active_dims: tp.Optional[tp.List[int]] = None,
) -> None:
super().__init__(compute_engine, active_dims, True, False, "Deep Kernel")
self.network = network
self.base_kernel = base_kernel

@dataclass
class DeepKernelFunction(AbstractKernel, DenseKernelComputation, _DeepKernelFunction):
def __call__(
self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"],
self,
params: Dict,
x: Float[Array, "1 D"],
y: Float[Array, "1 D"],
) -> Float[Array, "1"]:
xt = self.network.apply(params=params, x=x)
yt = self.network.apply(params=params, x=y)
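
The core of the DeepKernelFunction shown in the haiku diff above is simply a base kernel evaluated on neural-network features of the inputs, k_deep(x, y) = k_base(φ(x), φ(y)). The sketch below restates that idea without Haiku or the GPJax class machinery; the two-layer MLP, its parameter shapes, and all names are illustrative only.

import jax.numpy as jnp
import jax.random as jr
from jax import vmap

def init_mlp(key, in_dim=1, hidden=8, out_dim=2):
    # Illustrative two-layer feature map φ; shapes chosen arbitrarily.
    k1, k2 = jr.split(key)
    return {
        "W1": 0.1 * jr.normal(k1, (in_dim, hidden)), "b1": jnp.zeros(hidden),
        "W2": 0.1 * jr.normal(k2, (hidden, out_dim)), "b2": jnp.zeros(out_dim),
    }

def phi(net_params, x):
    h = jnp.tanh(x @ net_params["W1"] + net_params["b1"])
    return h @ net_params["W2"] + net_params["b2"]

def rbf(x, y, lengthscale=1.0, variance=1.0):
    return variance * jnp.exp(-0.5 * jnp.sum(((x - y) / lengthscale) ** 2))

def deep_kernel(net_params, x, y):
    # Transform both inputs through the network, then apply the base kernel.
    return rbf(phi(net_params, x), phi(net_params, y))

key = jr.PRNGKey(123)
net_params = init_mlp(key)
x = jnp.linspace(-5.0, 5.0, 6).reshape(-1, 1)

# Gram matrix of the deep kernel over the inputs.
gram = vmap(lambda xi: vmap(lambda xj: deep_kernel(net_params, xi, xj))(x))(x)  # (6, 6)
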