Merge pull request #457 from JaxGaussianProcesses/nnx_update

Nnx update

thomaspinder authored Jul 1, 2024
2 parents c359936 + 65bd81a commit 9495947
Showing 35 changed files with 380 additions and 299 deletions.
94 changes: 0 additions & 94 deletions docs/examples/README.md

This file was deleted.

42 changes: 29 additions & 13 deletions docs/examples/bayesian_optimisation.py
@@ -1,3 +1,20 @@
+# -*- coding: utf-8 -*-
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.11.2
+#   kernelspec:
+#     display_name: gpjax
+#     language: python
+#     name: python3
+# ---
+
 # %% [markdown]
 # # Introduction to Bayesian Optimisation
 #
@@ -208,23 +225,22 @@ def standardised_forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:


 # %%
+from gpjax.parameters import Static
+
+
 def return_optimised_posterior(
-    data: gpx.Dataset, prior: gpx.base.Module, key: Array
-) -> gpx.base.Module:
+    data: gpx.Dataset, prior: gpx.gps.Prior, key: Array
+) -> gpx.gps.AbstractPosterior:
+    # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
     likelihood = gpx.likelihoods.Gaussian(
-        num_datapoints=data.n, obs_stddev=jnp.array(1e-6)
-    )  # Our function is noise-free, so we set the observation noise's standard deviation to a very small value
-    likelihood = likelihood.replace_trainable(obs_stddev=False)
+        num_datapoints=data.n, obs_stddev=Static(jnp.array(1e-6))
+    )

     posterior = prior * likelihood

-    negative_mll = gpx.objectives.ConjugateMLL(negative=True)
-    negative_mll(posterior, train_data=data)
-    negative_mll = jit(negative_mll)
-
     opt_posterior, _ = gpx.fit(
         model=posterior,
-        objective=negative_mll,
+        objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
         train_data=data,
         optim=ox.adam(learning_rate=0.01),
         num_iters=1000,
@@ -237,7 +253,7 @@


 mean = gpx.mean_functions.Zero()
-kernel = gpx.kernels.Matern52()
+kernel = gpx.kernels.Matern52(n_dims=1)
 prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 opt_posterior = return_optimised_posterior(D, prior, key)

@@ -323,7 +339,7 @@ def optimise_sample(

 # %%
 def plot_bayes_opt(
-    posterior: gpx.base.Module,
+    posterior: gpx.gps.AbstractPosterior,
     sample: FunctionalSample,
     dataset: gpx.Dataset,
     queried_x: ScalarFloat,
@@ -408,7 +424,7 @@ def plot_bayes_opt(

 # Generate optimised posterior using previously observed data
 mean = gpx.mean_functions.Zero()
-kernel = gpx.kernels.Matern52()
+kernel = gpx.kernels.Matern52(n_dims=1)
 prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)
 opt_posterior = return_optimised_posterior(D, prior, subkey)

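Taken together, the changes to this file swap the class-based `ConjugateMLL(negative=True)` objective for the plain function `gpx.objectives.conjugate_mll` (negated in a lambda, since `gpx.fit` minimises), and fix the observation noise with the new `Static` wrapper instead of `replace_trainable`. A minimal end-to-end sketch of the new style follows; the toy dataset and the `key` argument to `gpx.fit` are assumptions modelled on the surrounding example, not part of this diff.

import jax.numpy as jnp
import jax.random as jr
import optax as ox

import gpjax as gpx

key = jr.PRNGKey(0)
x = jr.uniform(key, shape=(50, 1))
y = jnp.sin(10 * x) + 0.1 * jr.normal(key, shape=(50, 1))
D = gpx.Dataset(X=x, y=y)

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(),
    kernel=gpx.kernels.Matern52(n_dims=1),
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Objectives are now plain functions of (model, data); gpx.fit minimises,
# so the conjugate marginal log-likelihood is negated inside the lambda.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
    train_data=D,
    optim=ox.adam(learning_rate=0.01),
    num_iters=1000,
    key=key,
)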
25 changes: 17 additions & 8 deletions docs/examples/classification.py
@@ -15,6 +15,10 @@
 #     name: python3
 # ---

+# %%
+# %load_ext autoreload
+# %autoreload 2
+
 # %% [markdown]
 # # Classification
 #
@@ -31,6 +35,7 @@

 from time import time
 import blackjax
+from flax import nnx
 import jax
 import jax.numpy as jnp
 import jax.random as jr
@@ -131,6 +136,11 @@
     key=key,
 )

+# %%
+import jax
+
+jax.__version__
+
 # %% [markdown]
 # From which we can make predictions at novel inputs, as illustrated below.

@@ -168,7 +178,6 @@
 )

 ax.legend()
-plt.savefig("fit.png")
 # %% [markdown]
 # Here we projected the map estimates $\hat{\boldsymbol{f}}$ for the function values
 # $\boldsymbol{f}$ at the data points $\boldsymbol{x}$ to get predictions over the
@@ -223,11 +232,13 @@
 f_hat = Lx @ opt_posterior.latent.value

 # Negative Hessian, H = -∇²p_tilde(y|f):
-params, *static_state, graphdef = opt_posterior.split(gpx.parameters.Parameter, ...)
+graphdef, params, *static_state = nnx.split(
+    opt_posterior, gpx.parameters.Parameter, ...
+)


 def loss(params, D):
-    model = graphdef.merge(params, *static_state)
+    model = nnx.merge(graphdef, params, *static_state)
     return -gpx.objectives.log_posterior_density(model, D)


@@ -355,7 +366,7 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
 num_adapt = 600
 num_samples = 600

-params, *static_state, graphdef = posterior.split(gpx.parameters.Parameter, ...)
+graphdef, params, *static_state = nnx.split(posterior, gpx.parameters.Parameter, ...)
 params_bijection = gpx.parameters.DEFAULT_BIJECTION

 # Transform the parameters to the unconstrained space
@@ -364,7 +375,7 @@

 def logprob_fn(params):
     params = gpx.parameters.transform(params, params_bijection)
-    model = graphdef.merge(params, *static_state)
+    model = nnx.merge(graphdef, params, *static_state)
     return gpx.objectives.log_posterior_density(model, D)


@@ -429,7 +440,6 @@ def one_step(state, rng_key):
 ax0.set_title("Kernel Lengthscale")
 ax1.set_title("Kernel Variance")
 ax2.set_title("Latent Function (index = 1)")
-plt.savefig("trace_plots.png")

 # %% [markdown]
 # ## Prediction
@@ -453,7 +463,7 @@ def one_step(state, rng_key):
 for i in trange(0, num_samples, thin_factor, desc="Drawing posterior samples"):
     sample_params = jtu.tree_map(lambda samples, i=i: samples[i], states.position)
     sample_params = gpx.parameters.transform(sample_params, params_bijection)
-    model = graphdef.merge(sample_params, *static_state)
+    model = nnx.merge(graphdef, sample_params, *static_state)
     latent_dist = model.predict(xtest, train_data=D)
     predictive_dist = model.likelihood(latent_dist)
     posterior_samples.append(predictive_dist.sample(seed=key, sample_shape=(10,)))
@@ -494,7 +504,6 @@ def one_step(state, rng_key):
     linewidth=1,
 )
 ax.legend()
-plt.savefig("mcmc_fit.png")

 # %% [markdown]
 # ## System configuration
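Both the Laplace and MCMC cells above migrate from the old `Module.split`/`graphdef.merge` methods to the functional `nnx.split`/`nnx.merge` API, which also reorders the results so `graphdef` comes first. The round trip looks like this; `posterior` and `D` stand in for the fitted model and dataset from the example, so treat it as a sketch rather than a verbatim excerpt:

from flax import nnx

import gpjax as gpx

# Split the module: the first filter captures trainable Parameter state,
# and the trailing ellipsis collects all remaining state as static.
graphdef, params, *static_state = nnx.split(
    posterior, gpx.parameters.Parameter, ...
)


# `params` is now a pure pytree, so it can be fed to jax.grad, jax.hessian,
# or an MCMC sampler, then recombined into a working model with nnx.merge:
def loss(params, D):
    model = nnx.merge(graphdef, params, *static_state)
    return -gpx.objectives.log_posterior_density(model, D)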
20 changes: 18 additions & 2 deletions docs/examples/decision_making.py
@@ -1,3 +1,19 @@
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.11.2
+#   kernelspec:
+#     display_name: gpjax
+#     language: python
+#     name: python3
+# ---
+
 # %% [markdown]
 # # Introduction to Decision Making with GPJax
 #
@@ -137,7 +153,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:

 # %%
 mean = gpx.mean_functions.Zero()
-kernel = gpx.kernels.Matern52()
+kernel = gpx.kernels.Matern52(n_dims=1)
 prior = gpx.gps.Prior(mean_function=mean, kernel=kernel)

 # %% [markdown]
@@ -174,7 +190,7 @@ def forrester(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
 posterior_handler = PosteriorHandler(
     prior,
     likelihood_builder=likelihood_builder,
-    optimization_objective=gpx.objectives.ConjugateMLL(negative=True),
+    optimization_objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
     optimizer=ox.adam(learning_rate=0.01),
     num_optimization_iters=1000,
 )
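The `PosteriorHandler` constructor is otherwise unchanged; only its objective argument moves to the functional style. For context, the `likelihood_builder` it consumes maps a number of datapoints to a likelihood. The builder itself is not shown in this diff, so the following is an assumed sketch, mirroring the noise-free Gaussian setting used in the Bayesian optimisation example above:

import jax.numpy as jnp

import gpjax as gpx
from gpjax.parameters import Static


def likelihood_builder(n: int) -> gpx.likelihoods.Gaussian:
    # Noise-free objective: fix the observation noise at a tiny value,
    # wrapped in Static so it is excluded from optimisation.
    return gpx.likelihoods.Gaussian(
        num_datapoints=n, obs_stddev=Static(jnp.array(1e-6))
    )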
2 changes: 1 addition & 1 deletion docs/examples/deep_kernels.py
@@ -36,7 +36,7 @@
     field,
 )

-from flax.experimental import nnx
+from flax import nnx
 from gpjax.kernels.computations import (
     AbstractKernelComputation,
     DenseKernelComputation,
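Here only the import path changes: `nnx` has graduated out of `flax.experimental` into stable `flax`. As a minimal sketch of the kind of `nnx.Module` feature extractor a deep kernel wraps (the class name and layer sizes are assumptions, not taken from this example):

import jax
import jax.numpy as jnp
from flax import nnx


class FeatureExtractor(nnx.Module):
    def __init__(self, in_dim: int, hidden: int, out_dim: int, *, rngs: nnx.Rngs):
        # Two dense layers mapping inputs to a low-dimensional feature space.
        self.hidden_layer = nnx.Linear(in_dim, hidden, rngs=rngs)
        self.output_layer = nnx.Linear(hidden, out_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        return self.output_layer(jax.nn.relu(self.hidden_layer(x)))


network = FeatureExtractor(in_dim=1, hidden=32, out_dim=2, rngs=nnx.Rngs(42))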
2 changes: 1 addition & 1 deletion docs/examples/gpjax.mplstyle
@@ -26,7 +26,7 @@ xtick.direction: out
 ytick.direction: out

 # Colour palettes
-axes.prop_cycle: cycler('color', ['2F83B4','B5121B', 'F77F00', '0B6E4F', '7A68A6', 'C5BB36', '8c564b', 'e377c2'])
+axes.prop_cycle: cycler("color", ["2F83B4","B5121B", "F77F00", "0B6E4F", "7A68A6", "C5BB36", "8c564b", "e377c2"])
 lines.color: B5121B
 scatter.marker: x
 image.cmap: inferno
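Only the quoting convention in the style sheet changes; matplotlib's rc parser accepts either form. For reference, a sheet like this is applied in the usual way (whether the examples load it exactly like this is an assumption; the path comes from this repository's docs tree):

import matplotlib.pyplot as plt

# Apply the repository's plotting style for all subsequent figures.
plt.style.use("docs/examples/gpjax.mplstyle")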
16 changes: 16 additions & 0 deletions docs/examples/intro_to_gps.py
@@ -1,3 +1,19 @@
+# ---
+# jupyter:
+#   jupytext:
+#     cell_metadata_filter: -all
+#     custom_cell_magics: kql
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#     jupytext_version: 1.11.2
+#   kernelspec:
+#     display_name: gpjax
+#     language: python
+#     name: python3
+# ---
+
 # %% [markdown]
 # # New to Gaussian Processes?
 #
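The header added here (and to the other examples above) is jupytext metadata pairing each percent-format script with a notebook representation. Assuming a standard jupytext setup, the conversion can be driven from Python like so:

import jupytext

# Read the percent-format script and write a paired .ipynb next to it.
notebook = jupytext.read("docs/examples/intro_to_gps.py")
jupytext.write(notebook, "docs/examples/intro_to_gps.ipynb")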