Commit f427123

thomaspinder committed Mar 23, 2023
2 parents 6ccb2c7 + 8edbb75

Showing 47 changed files with 1,260 additions and 5,566 deletions.
5 changes: 0 additions & 5 deletions .mailmap

This file was deleted.

45 changes: 19 additions & 26 deletions .pre-commit-config.yaml
@@ -1,30 +1,23 @@
repos:
- repo: https://github.com/psf/black
  rev: 22.3.0
  hooks:
  - id: black
# - repo: https://github.com/pycqa/isort
#   rev: 5.10.1
# - repo: https://github.com/charliermarsh/ruff-pre-commit
#   rev: 'v0.0.254'
#   hooks:
#   - id: isort
#     args: ["--profile", "black"]
- repo: https://github.com/kynan/nbstripout
  rev: 0.5.0
  hooks:
  - id: nbstripout
- repo: https://github.com/nbQA-dev/nbQA
  rev: 1.3.1
#   - id: ruff
#     args: ['--fix']
- repo: local
  hooks:
  - id: nbqa-black
  - id: nbqa-pyupgrade
    args: [--py37-plus]
  - id: nbqa-flake8
    args: ['--ignore=E501,E203,E302,E402,E731,W503']
- repo: https://github.com/PyCQA/autoflake
  rev: v2.0.0
  - id: darglint
    name: darglint
    entry: darglint
    language: system
    types: [python]
    stages: [manual]
# - repo: https://github.com/PyCQA/docformatter
#   rev: v1.5.0
#   hooks:
#   - id: docformatter
#     args: [--in-place --config ./pyproject.toml]
- repo: https://github.com/psf/black
  rev: 23.1.0
  hooks:
  - id: autoflake
    args: ["--in-place", "--remove-unused-variables", "--remove-all-unused-imports", "--recursive"]
    name: AutoFlake
    description: "Format with AutoFlake"
    stages: [commit]
  - id: black
2 changes: 1 addition & 1 deletion README.md
@@ -41,7 +41,7 @@ Feel free to join our [Slack Channel](https://join.slack.com/t/gpjax/shared_invi
> - [**Inference on Graphs**](https://gpjax.readthedocs.io/en/latest/examples/graph_kernels.html)
> - [**Learning Gaussian Process Barycentres**](https://gpjax.readthedocs.io/en/latest/examples/barycentres.html)
> - [**Deep Kernel Regression**](https://gpjax.readthedocs.io/en/latest/examples/haiku.html)
> - [**Natural Gradients**](https://gpjax.readthedocs.io/en/latest/examples/natgrads.html)
<!-- > - [**Natural Gradients**](https://gpjax.readthedocs.io/en/latest/examples/natgrads.html) -->
## Guides for customisation
>
3 changes: 2 additions & 1 deletion codecov.yml
@@ -16,4 +16,5 @@ ignore:
- "tests/*.py"
- "**/_version.py"
- "**/versioneer.py"
- "versioneer.py"
- "versioneer.py"
- "gpjax/_version.py"
115 changes: 91 additions & 24 deletions examples/barycentres.pct.py
@@ -17,7 +17,15 @@
# %% [markdown]
# # Gaussian Processes Barycentres
#
# In this notebook we'll give an implementation of <strong data-cite="mallasto2017learning"></strong>. In this work, the existence of a Wasserstein barycentre between a collection of Gaussian processes is proven. When faced with trying to _average_ a set of probability distributions, the Wasserstein barycentre is an attractive choice as it enables uncertainty amongst the individual distributions to be incorporated into the averaged distribution. When compared to a naive _mean of means_ and _mean of variances_ approach to computing the average probability distributions, it can be seen that Wasserstein barycentres offer significantly more favourable uncertainty estimation.
# In this notebook we'll give an implementation of
# <strong data-cite="mallasto2017learning"></strong>. In this work, the existence of a
# Wasserstein barycentre between a collection of Gaussian processes is proven. When
# faced with trying to _average_ a set of probability distributions, the Wasserstein
# barycentre is an attractive choice as it enables uncertainty amongst the individual
# distributions to be incorporated into the averaged distribution. When compared to a
# naive _mean of means_ and _mean of variances_ approach to computing the average
# probability distribution, Wasserstein barycentres offer
# significantly more favourable uncertainty estimation.
#

# %%
@@ -31,7 +39,7 @@
import matplotlib.pyplot as plt
import optax as ox
from jax.config import config
from jaxutils import Dataset
from jaxutils import Dataset, fit
import jaxkern as jk

import gpjax as gpx
@@ -45,29 +53,58 @@
#
# ### Wasserstein distance
#
# The 2-Wasserstein distance metric between two probability measures $\mu$ and $\nu$ quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$, or vice-versa. Typically, computing this metric requires solving a linear program. However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian distributions, the solution is analytically given by
# $$W_2^2(\mu, \nu) = \lVert m_1- m_2 \rVert^2_2 + \operatorname{Tr}(S_1 + S_2 - 2(S_1^{1/2}S_2S_1^{1/2})^{1/2}),$$
# The 2-Wasserstein distance metric between two probability measures $\mu$ and $\nu$
# quantifies the minimal cost required to transport the unit mass from $\mu$ to $\nu$,
# or vice-versa. Typically, computing this metric requires solving a linear program.
# However, when $\mu$ and $\nu$ both belong to the family of multivariate Gaussian
# distributions, the solution is analytically given by
# $$W_2^2(\mu, \nu) = \lVert m_1 - m_2 \rVert^2_2 +
# \operatorname{Tr}(S_1 + S_2 - 2(S_1^{1/2}S_2S_1^{1/2})^{1/2}),$$
# where $\mu \sim \mathcal{N}(m_1, S_1)$ and $\nu\sim\mathcal{N}(m_2, S_2)$.
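
# As a quick, self-contained sketch (not part of this commit), the closed form
# above could be evaluated directly in JAX; the helper names below are
# illustrative only.

import jax.numpy as jnp


def psd_sqrt(A):
    # Symmetric square root of a PSD matrix via an eigendecomposition.
    evals, evecs = jnp.linalg.eigh(A)
    sqrt_evals = jnp.sqrt(jnp.maximum(evals, 0.0))
    return (evecs * sqrt_evals) @ evecs.T


def w2_squared(m1, S1, m2, S2):
    # W2^2(mu, nu) = ||m1 - m2||^2 + Tr(S1 + S2 - 2 (S1^{1/2} S2 S1^{1/2})^{1/2})
    S1_half = psd_sqrt(S1)
    cross = psd_sqrt(S1_half @ S2 @ S1_half)
    return jnp.sum((m1 - m2) ** 2) + jnp.trace(S1 + S2 - 2.0 * cross)
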
#
# ### Wasserstein barycentre
#
# For a collection of $T$ measures $\lbrace\mu_t\rbrace_{t=1}^T \in \mathcal{P}_2(\theta)$, the Wasserstein barycentre $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all other measures in the set. More formally, the Wasserstein barycentre is the Fréchet mean on a Wasserstein space that we can write as
# $$\bar{\mu} = \operatorname{argmin}_{\mu\in\mathcal{P}_2(\theta)}\sum_{t=1}^T \alpha_t W_2^2(\mu, \mu_t),$$
# For a collection of $T$ measures
# $\lbrace\mu_t\rbrace_{t=1}^T \in \mathcal{P}_2(\theta)$, the Wasserstein barycentre
# $\bar{\mu}$ is the measure that minimises the average Wasserstein distance to all
# other measures in the set. More formally, the Wasserstein barycentre is the Fréchet
# mean on a Wasserstein space that we can write as
# $$
# \bar{\mu} = \operatorname{argmin}_{\mu\in\mathcal{P}_2(\theta)}
# \sum_{t=1}^T \alpha_t W_2^2(\mu, \mu_t),$$
# where $\alpha\in\mathbb{R}^T$ is a weight vector that sums to 1.
#
# As with the Wasserstein distance, identifying the Wasserstein barycentre $\bar{\mu}$ is often a computationally demanding optimisation problem. However, when all the measures admit a multivariate Gaussian density, the barycentre $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical solutions
# $$\bar{m} = \sum_{t=1}^T \alpha_t m_t\,, \quad \bar{S}=\sum_{t=1}^T\alpha_t (\bar{S}^{1/2}S_t\bar{S}^{1/2})^{1/2}\,. \qquad (\star)$$
# As with the Wasserstein distance, identifying the Wasserstein barycentre
# $\bar{\mu}$ is often a computationally demanding optimisation problem.
# However, when all the measures admit a multivariate Gaussian density,
# the barycentre $\bar{\mu} = \mathcal{N}(\bar{m}, \bar{S})$ has analytical
# solutions
# $$\bar{m} = \sum_{t=1}^T \alpha_t m_t\,, \quad \bar{S}=\sum_{t=1}^T\alpha_t
# (\bar{S}^{1/2}S_t\bar{S}^{1/2})^{1/2}\,. \qquad (\star)$$
# Identifying $\bar{S}$ is achieved through a fixed-point iterative update.
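
# A minimal sketch (illustrative only, not the notebook's implementation) of the
# barycentre mean and of one fixed-point update for the barycentre covariance in
# $(\star)$, reusing the `psd_sqrt` helper sketched above; `means` and `covs` are
# assumed to stack the T component means and covariances along the leading axis.

import jax
import jax.numpy as jnp


def barycentre_mean(weights, means):
    # m_bar = sum_t alpha_t m_t
    return jnp.tensordot(weights, means, axes=1)


def barycentre_covariance_step(S_bar, weights, covs):
    # S_bar <- sum_t alpha_t (S_bar^{1/2} S_t S_bar^{1/2})^{1/2}
    S_bar_half = psd_sqrt(S_bar)
    inner = jax.vmap(lambda S_t: psd_sqrt(S_bar_half @ S_t @ S_bar_half))(covs)
    return jnp.tensordot(weights, inner, axes=1)
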
#
# ## Barycentre of Gaussian processes
#
# It was shown in <strong data-cite="mallasto2017learning"></strong> that the barycentre $\bar{f}$ of a collection of Gaussian processes $\lbrace f_i\rbrace_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$ can be found using the same solutions as in $(\star)$. For a full theoretical understanding, we recommend reading the original paper. However, the central argument to this result is that one can first show that the barycentre GP $\bar{f}\sim\mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set of GPs $\lbrace f_t\rbrace_{t=1}^T$, i.e., $T<\infty$. With this established, one can show that for an $n$-dimensional finite Gaussian distribution $f_{i,n}$, the Wasserstein metric between any two Gaussian distributions $f_{i, n}, f_{j, n}$ converges to the Wasserstein metric between GPs as $n\to\infty$.
# It was shown in <strong data-cite="mallasto2017learning"></strong> that the
# barycentre $\bar{f}$ of a collection of Gaussian processes
# $\lbrace f_i\rbrace_{i=1}^T$ such that $f_i \sim \mathcal{GP}(m_i, K_i)$
# can be found using the same solutions as in $(\star)$. For a full theoretical
# understanding, we recommend reading the original paper. However, the central argument
# to this result is that one can first show that the barycentre GP
# $\bar{f}\sim\mathcal{GP}(\bar{m}, \bar{S})$ is non-degenerate for any finite set of
# GPs $\lbrace f_t\rbrace_{t=1}^T$, i.e., $T<\infty$. With this established, one can
# show that for an $n$-dimensional finite Gaussian distribution $f_{i,n}$, the
# Wasserstein metric between any two Gaussian distributions $f_{i, n}, f_{j, n}$
# converges to the Wasserstein metric between GPs as $n\to\infty$.
#
# In this notebook, we will demonstrate how this can be achieved in GPJax.
#
# ## Dataset
#
# We'll simulate five datasets and develop a Gaussian process posterior before identifying the Gaussian process barycentre at a set of test points. Each dataset will be a sine function with a different vertical shift, periodicity, and quantity of noise.
# We'll simulate five datasets and develop a Gaussian process posterior before
# identifying the Gaussian process barycentre at a set of test points. Each dataset
# will be a sine function with a different vertical shift, periodicity, and
# quantity of noise.

# %%
n = 100
@@ -76,7 +113,11 @@

x = jnp.linspace(-5.0, 5.0, n).reshape(-1, 1)
xtest = jnp.linspace(-5.5, 5.5, n_test).reshape(-1, 1)
f = lambda x, a, b: a + jnp.sin(b * x)


def f(x, a, b):
    return a + jnp.sin(b * x)


ys = []
for i in range(n_datasets):
@@ -96,7 +137,14 @@
# %% [markdown]
# ## Learning a posterior distribution
#
# We'll now independently learn Gaussian process posterior distributions for each dataset. We won't spend any time here discussing how GP hyperparameters are optimised. For advice on achieving this, see the [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html) for advice on optimisation and the [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for advice on selecting an appropriate kernel.
# We'll now independently learn Gaussian process posterior distributions for each
# dataset. We won't spend any time here discussing how GP hyperparameters are
# optimised; see the
# [Regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)
# for advice on optimisation and the
# [Kernels notebook](https://gpjax.readthedocs.io/en/latest/nbs/kernels.html) for
# advice on selecting an appropriate kernel.


# %%
def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:
@@ -107,18 +155,16 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:
    likelihood = gpx.Gaussian(num_datapoints=n)
    posterior = gpx.Prior(kernel=jk.RBF()) * likelihood

    parameter_state = gpx.initialise(posterior, key)
    negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))
    optimiser = ox.adam(learning_rate=0.01)
    objective = gpx.ConjugateMLL(posterior=posterior, negative=True)

    inference_state = gpx.fit(
        objective=negative_mll,
        parameter_state=parameter_state,
        optax_optim=optimiser,
        num_iters=1000,
    learned_params = fit(
        objective=objective,
        train_data=D,
        optim=optimiser,
        num_iters=500,
    )

    learned_params, training_history = inference_state.unpack()
    return likelihood(learned_params, posterior(learned_params, D)(xtest))


@@ -127,7 +173,13 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri:
# %% [markdown]
# ## Computing the barycentre
#
# In GPJax, the predictive distribution of a GP is given by a [Distrax](https://github.com/deepmind/distrax) distribution, making it straightforward to extract the mean vector and covariance matrix of each GP for learning a barycentre. We implement the fixed point scheme given in (3) in the following cell by utilising Jax's `vmap` operator to speed up large matrix operations using broadcasting in `tensordot`.
# In GPJax, the predictive distribution of a GP is given by a
# [Distrax](https://github.com/deepmind/distrax) distribution, making it
# straightforward to extract the mean vector and covariance matrix of each GP for
# learning a barycentre. We implement the fixed-point scheme given in $(\star)$ in the
# following cell by utilising JAX's `vmap` operator to speed up large matrix operations
# using broadcasting in `tensordot`.


# %%
def sqrtm(A: jax.Array):
@@ -152,7 +204,12 @@ def step(covariance_candidate: jax.Array, idx: None):


# %% [markdown]
# With a function defined for learning a barycentre, we'll now compute it using the `lax.scan` operator that drastically speeds up for loops in Jax (see the [Jax documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)). The iterative update will be executed 100 times, with convergence measured by the difference between the previous and current iteration that we can confirm by inspecting the `sequence` array in the following cell.
# With a function defined for learning a barycentre, we'll now compute it using the
# `lax.scan` operator that drastically speeds up for loops in JAX (see the
# [JAX documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)).
# The iterative update will be executed 100 times, with convergence measured by the
# difference between the previous and current iteration, which we can confirm by
# inspecting the `sequence` array in the following cell.
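
# A sketch (again illustrative rather than this commit's code) of driving such a
# fixed-point update with `jax.lax.scan`, assuming a step function like
# `barycentre_covariance_step` above; the stacked `sequence` output is what lets
# successive iterates be compared for convergence, e.g. via
# `jnp.max(jnp.abs(sequence[-1] - sequence[-2]))`.

import jax
import jax.numpy as jnp


def run_fixed_point(weights, covs, num_iters=100):
    def scan_step(S_bar, _):
        S_next = barycentre_covariance_step(S_bar, weights, covs)
        return S_next, S_next  # carry the new iterate and also record it

    init = jnp.eye(covs.shape[-1])
    S_bar, sequence = jax.lax.scan(scan_step, init, jnp.arange(num_iters))
    return S_bar, sequence
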

# %%
weights = jnp.ones((n_datasets,)) / n_datasets
@@ -173,7 +230,10 @@ def step(covariance_candidate: jax.Array, idx: None):
# %% [markdown]
# ## Plotting the result
#
# With a barycentre learned, we can visualise the result. We can see that the result looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and the uncertainty bands are sensible.
# With a barycentre learned, we can visualise the result. We can see that the result
# looks reasonable as it follows the sinusoidal curve of all the inferred GPs, and
# the uncertainty bands are sensible.
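
# A minimal sketch (not the notebook's `plot` helper below) of drawing a posterior
# mean with two-standard-deviation bands from a mean vector and covariance matrix;
# `xtest`, `mean` and `cov` are assumed inputs extracted from one of the fitted
# distributions.

import jax.numpy as jnp
import matplotlib.pyplot as plt


def plot_mean_and_bands(xtest, mean, cov, ax=None):
    ax = ax or plt.gca()
    std = jnp.sqrt(jnp.diag(cov))
    ax.plot(xtest.squeeze(), mean, label="mean")
    ax.fill_between(xtest.squeeze(), mean - 2.0 * std, mean + 2.0 * std, alpha=0.2)
    ax.legend()
    return ax
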


# %%
def plot(
@@ -206,10 +266,17 @@ def plot(
# %% [markdown]
# ## Displacement interpolation
#
# In the above example, we assigned uniform weights to each of the posteriors within the barycentre. In practice, we may have prior knowledge of which posterior is most likely to be the correct one. Regardless of the weights chosen, the barycentre remains a Gaussian process. We can interpolate between a pair of posterior distributions $\mu_1$ and $\mu_2$ to visualise the corresponding barycentre $\bar{\mu}$.
# In the above example, we assigned uniform weights to each of the posteriors within
# the barycentre. In practice, we may have prior knowledge of which posterior is most
# likely to be the correct one. Regardless of the weights chosen, the barycentre
# remains a Gaussian process. We can interpolate between a pair of posterior
# distributions $\mu_1$ and $\mu_2$ to visualise the corresponding
# barycentre $\bar{\mu}$.
#
# ![](figs/barycentre_gp.gif)

# %%
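# A sketch (illustrative only) of sweeping the weight on a pair of posteriors to
# trace out the displacement interpolation; it assumes mean/covariance pairs
# (m1, S1) and (m2, S2) extracted from two fitted distributions, plus the helpers
# sketched earlier in this notebook.

import jax.numpy as jnp


def interpolate_barycentres(m1, S1, m2, S2, num_steps=11):
    # Returns (mean, covariance) pairs along the path from mu_2 (w=0) to mu_1 (w=1).
    path = []
    for w in jnp.linspace(0.0, 1.0, num_steps):
        weights = jnp.array([w, 1.0 - w])
        m_bar = barycentre_mean(weights, jnp.stack([m1, m2]))
        S_bar, _ = run_fixed_point(weights, jnp.stack([S1, S2]))
        path.append((m_bar, S_bar))
    return path
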

# %% [markdown]
# ## System configuration
