
Numpyro #122

Merged
merged 16 commits into from
Oct 16, 2022
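This PR swaps the distribution backend used in the examples from Distrax to NumPyro. For orientation, here is a minimal sketch of the call-site differences that recur throughout the diff below — a sketch only, not code from the PR, assuming the distrax and numpyro.distributions APIs as of late 2022:

import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as npd

key = jr.PRNGKey(0)
mean, cov = jnp.zeros(3), jnp.eye(3)

# distrax: dx.MultivariateNormalFullCovariance(mean, cov)
dist = npd.MultivariateNormal(mean, cov)  # second positional argument is the covariance matrix

mu = dist.mean                   # distrax: dist.mean()        (methods become properties)
sigma = jnp.sqrt(dist.variance)  # distrax: dist.stddev()
Sigma = dist.covariance_matrix   # distrax: dist.covariance()

# Sampling takes the PRNG key positionally rather than through the `seed` keyword.
samples = dist.sample(key, sample_shape=(10,))  # distrax: dist.sample(seed=key, sample_shape=(10,))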
2 changes: 1 addition & 1 deletion .github/workflows/workflow-master.yml
@@ -5,7 +5,7 @@ on:
- "master"
pull_request:
branches:
- "master"
- "*"
jobs:
codecov:
name: Codecov Workflow
9 changes: 8 additions & 1 deletion docs/conf.py
@@ -103,10 +103,12 @@ def find_version(*file_paths):
bibtex_bibfiles = ["refs.bib"]
bibtex_style = "unsrt"
bibtex_reference_style = "author_year"
nb_execution_mode = "auto"
nbsphinx_allow_errors = True
nbsphinx_custom_formats = {
".pct.py": ["jupytext.reads", {"fmt": "py:percent"}],
}
jupyter_execute_notebooks = "cache"

# Latex commands
# mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"
@@ -183,6 +185,11 @@ def find_version(*file_paths):
"logo_only": True,
"show_toc_level": 2,
"repository_url": "https://github.com/thomaspinder/GPJax/",
"launch_buttons": {
"binderhub_url": "https://mybinder.org",
"colab_url": "https://colab.research.google.com",
"notebook_interface": "jupyterlab",
},
"use_repository_button": True,
"use_sidenotes": True, # Turns footnotes into sidenotes - https://sphinx-book-theme.readthedocs.io/en/stable/content-blocks.html
}
@@ -199,10 +206,10 @@ def find_version(*file_paths):
# }



# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# html_static_path = ["_static"]

html_static_path = ["_static"]
html_css_files = ["custom.css"]
1 change: 0 additions & 1 deletion docs/index.md
@@ -42,7 +42,6 @@ To learn more, checkout the [regression
notebook](https://gpjax.readthedocs.io/en/latest/examples/regression.html).
:::

Are you rendering

---

1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -6,6 +6,7 @@ jinja2
nbsphinx>=0.8.0
nb-black==1.0.7
matplotlib==3.3.3
tensorflow-probability>=0.16.0
sphinx-copybutton
networkx>=2.0.0
pandoc
36 changes: 18 additions & 18 deletions examples/barycentres.pct.py
@@ -7,7 +7,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.5
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# language: python
@@ -18,18 +18,19 @@
# # Gaussian Processes Barycentres
#
# In this notebook we'll give an implementation of <strong data-cite="mallasto2017learning"></strong>. In this work, the existence of a Wasserstein barycentre between a collection of Gaussian processes is proven. When faced with trying to _average_ a set of probability distributions, the Wasserstein barycentre is an attractive choice as it allows the uncertainty of the individual distributions to be incorporated into the averaged distribution. Compared with a naive _mean of means_ and _mean of variances_ approach to computing the average probability distribution, Wasserstein barycentres offer significantly more favourable uncertainty estimation.
#
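# Concretely, writing the $i$-th posterior as $\mathcal{N}(\mu_i, \Sigma_i)$ with weight $w_i$, the barycentre has mean $\bar{\mu} = \sum_i w_i \mu_i$, and its covariance is characterised (a standard result, stated here for reference) as the fixed point of
#
# $$\bar{\Sigma} = \sum_i w_i \left(\bar{\Sigma}^{1/2} \Sigma_i \bar{\Sigma}^{1/2}\right)^{1/2},$$
#
# which the `step` function below approximates by fixed-point iteration.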

# %%
import typing as tp

import distrax as dx
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.linalg as jsl
import matplotlib.pyplot as plt
import numpyro
import optax as ox

# %%
import gpjax as gpx

key = jr.PRNGKey(123)
@@ -99,24 +100,23 @@ def fit_gp(x: jnp.DeviceArray, y: jnp.DeviceArray):
D = gpx.Dataset(X=x, y=y)
likelihood = gpx.Gaussian(num_datapoints=n)
posterior = gpx.Prior(kernel=gpx.RBF()) * likelihood
params, trainables, constrainers, unconstrainers = gpx.initialise(
parameter_state = gpx.initialise(
posterior, key
).unpack()
params = gpx.transform(params, unconstrainers)
)
params, trainables, bijectors = parameter_state.unpack()
params = gpx.unconstrain(params, bijectors)

objective = jax.jit(
posterior.marginal_log_likelihood(D, constrainers, negative=True)
posterior.marginal_log_likelihood(D, negative=True)
)

opt = ox.adam(learning_rate=0.01)
learned_params, training_history = gpx.fit(
objective=objective,
trainables=trainables,
params=params,
parameter_state=parameter_state,
optax_optim=opt,
n_iters=1000,
).unpack()
learned_params = gpx.transform(learned_params, constrainers)
return likelihood(posterior(D, learned_params)(xtest), learned_params)


@@ -134,9 +134,9 @@ def sqrtm(A: jnp.DeviceArray):


def wasserstein_barycentres(
distributions: tp.List[dx.Distribution], weights: jnp.DeviceArray
distributions: tp.List[numpyro.distributions.Distribution], weights: jnp.DeviceArray
):
covariances = [d.covariance() for d in distributions]
covariances = [d.covariance_matrix for d in distributions]
cov_stack = jnp.stack(covariances)
stack_sqrt = jax.vmap(sqrtm)(cov_stack)

@@ -156,7 +156,7 @@ def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray):
# %%
weights = jnp.ones((n_datasets,)) / n_datasets

means = jnp.stack([d.mean() for d in posterior_preds])
means = jnp.stack([d.mean for d in posterior_preds])
barycentre_mean = jnp.tensordot(weights, means, axes=1)

step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights))
@@ -167,7 +167,7 @@ def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray):
)


barycentre_process = dx.MultivariateNormalFullCovariance(
barycentre_process = numpyro.distributions.MultivariateNormal(
barycentre_mean, barycentre_covariance
)

@@ -179,16 +179,16 @@ def step(covariance_candidate: jnp.DeviceArray, i: jnp.DeviceArray):

# %%
def plot(
dist: dx.Distribution,
dist: numpyro.distributions.Distribution,
ax,
color: str = "tab:blue",
label: str = None,
ci_alpha: float = 0.2,
linewidth: float = 1.0,
):
mu = dist.mean()
sigma = dist.stddev()
ax.plot(xtest, dist.mean(), linewidth=linewidth, color=color, label=label)
mu = dist.mean
sigma = jnp.sqrt(dist.covariance_matrix.diagonal())
ax.plot(xtest, dist.mean, linewidth=linewidth, color=color, label=label)
ax.fill_between(
xtest.squeeze(), mu - sigma, mu + sigma, alpha=ci_alpha, color=color
)
59 changes: 26 additions & 33 deletions examples/classification.pct.py
@@ -7,7 +7,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.5
# jupytext_version: 1.11.2
# kernelspec:
# display_name: Python 3.9.7 ('gpjax')
# language: python
@@ -19,18 +19,16 @@
#
# In this notebook we demonstrate how to perform inference for Gaussian process models with non-Gaussian likelihoods via maximum a posteriori (MAP) and Markov chain Monte Carlo (MCMC). We focus on a classification task here and use [BlackJax](https://github.com/blackjax-devs/blackjax/) for sampling.

import blackjax
import distrax as dx

# %%
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib.pyplot as plt
import numpyro.distributions as npd
import optax as ox
from jaxtyping import Array, Float

import blackjax
import gpjax as gpx
from gpjax.utils import I

@@ -72,37 +70,33 @@
# To begin we obtain a set of initial parameter values through the `initialise` callable, and transform these to the unconstrained space via `transform` (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We also define the negative marginal log-likelihood, and JIT compile this to accelerate training.
# %%
parameter_state = gpx.initialise(posterior)
params, trainable, constrainer, unconstrainer = parameter_state.unpack()
params = gpx.transform(params, unconstrainer)
params, trainable, bijectors = parameter_state.unpack()

mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))
mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))

# %% [markdown]
# We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers.
# %%
opt = ox.adam(learning_rate=0.01)
unconstrained_params, training_history = gpx.fit(
learned_params, training_history = gpx.fit(
mll,
params,
trainable,
parameter_state,
opt,
n_iters=500,
).unpack()

negative_Hessian = jax.jacfwd(jax.jacrev(mll))(unconstrained_params)["latent"][
negative_Hessian = jax.jacfwd(jax.jacrev(mll))(learned_params)["latent"][
"latent"
][:, 0, :, 0]

map_estimate = gpx.transform(unconstrained_params, constrainer)
# %% [markdown]
# From which we can make predictions at novel inputs, as illustrated below.
# %%
latent_dist = posterior(D, map_estimate)(xtest)
latent_dist = posterior(D, learned_params)(xtest)

predictive_dist = likelihood(latent_dist, map_estimate)
predictive_dist = likelihood(latent_dist, learned_params)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
predictive_mean = predictive_dist.mean
predictive_std = jnp.sqrt(predictive_dist.variance)

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", label="Observations", color="tab:red")
Expand Down Expand Up @@ -139,7 +133,7 @@
# The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate.
# Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below.
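# Concretely, writing $\hat{\mathbf{f}}$ for the MAP estimate of the latent function at the training inputs and $H$ for the negative Hessian computed above, the approximation amounts to
#
# $$q(\mathbf{f}) = \mathcal{N}\left(\hat{\mathbf{f}}, H^{-1}\right),$$
#
# where $H^{-1}$ is obtained from the Cholesky factorisation $H = LL^{\top}$ via two triangular solves.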
# %%
f_map_estimate = posterior(D, map_estimate)(x).mean()
f_map_estimate = posterior(D, learned_params)(x).mean

jitter = 1e-6

@@ -150,7 +144,7 @@
L_inv = jsp.linalg.solve_triangular(L, I(D.n), lower=True)
H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)

laplace_approximation = dx.MultivariateNormalFullCovariance(f_map_estimate, H_inv)
laplace_approximation = npd.MultivariateNormal(f_map_estimate, H_inv)

from gpjax.kernels import cross_covariance, gram

@@ -161,26 +155,26 @@


def predict(
laplace_at_data: dx.Distribution,
laplace_at_data: npd.Distribution,
train_data: Dataset,
test_inputs: Float[Array, "N D"],
jitter: float = 1e-6,
) -> dx.Distribution:
) -> npd.Distribution:
"""Compute the predictive distribution of the Laplace approximation at novel inputs.

Args:
laplace_at_data (dict): The Laplace approximation at the datapoints.

Returns:
dx.Distribution: The Laplace approximation at novel inputs.
npd.Distribution: The Laplace approximation at novel inputs.
"""
x, n = train_data.X, train_data.n

t = test_inputs
n_test = t.shape[0]

mu = laplace_at_data.mean().reshape(-1, 1)
cov = laplace_at_data.covariance()
mu = laplace_at_data.mean.reshape(-1, 1)
cov = laplace_at_data.covariance_matrix

Ktt = gram(prior.kernel, t, params["kernel"])
Kxx = gram(prior.kernel, x, params["kernel"])
@@ -214,7 +208,7 @@ def predict(
)
covariance += I(n_test) * jitter

return dx.MultivariateNormalFullCovariance(
return npd.MultivariateNormal(
jnp.atleast_1d(mean.squeeze()), covariance
)

@@ -224,10 +218,10 @@ def predict(
# %%
latent_dist = predict(laplace_approximation, D, xtest)

predictive_dist = likelihood(latent_dist, map_estimate)
predictive_dist = likelihood(latent_dist, learned_params)

predictive_mean = predictive_dist.mean()
predictive_std = predictive_dist.stddev()
predictive_mean = predictive_dist.mean
predictive_std = predictive_dist.variance**0.5

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(x, y, "o", label="Observations", color="tab:red")
@@ -274,9 +268,9 @@ def predict(
# %%
# Adapted from BlackJax's introduction notebook.
num_adapt = 500
num_samples = 500
num_samples = 200

mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))
mll = jax.jit(posterior.marginal_log_likelihood(D, negative=False))

adapt = blackjax.window_adaptation(
blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65
@@ -334,11 +328,10 @@ def one_step(state, rng_key):
ps["kernel"]["lengthscale"] = states.position["kernel"]["lengthscale"][i]
ps["kernel"]["variance"] = states.position["kernel"]["variance"][i]
ps["latent"] = states.position["latent"][i, :, :]
ps = gpx.transform(ps, constrainer)

latent_dist = posterior(D, ps)(xtest)
predictive_dist = likelihood(latent_dist, ps)
samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))
samples.append(predictive_dist.sample(key, sample_shape=(10,)))

samples = jnp.vstack(samples)
