
Merge pull request #402 from JaxGaussianProcesses/jaxop_2
henrymoss authored Nov 7, 2023
2 parents ae9cfa3 + d255638 commit 5758238
Showing 14 changed files with 708 additions and 631 deletions.
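Taken together, the changes below migrate the example notebooks from gradient-based training with `gpx.fit` and an `optax` optimiser to the new `gpx.fit_scipy` routine. As an orientation aid, here is a minimal before/after sketch of the call pattern, assembled from the call sites in the diffs below; the toy dataset and model construction are illustrative stand-ins rather than code from this commit.

```python
from jax import config

config.update("jax_enable_x64", True)

import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import optax as ox

# Illustrative data and model, mirroring the constructions used in the notebooks.
key = jr.PRNGKey(42)
x = jnp.linspace(0.0, 1.0, 50).reshape(-1, 1)
y = jnp.sin(10.0 * x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)
likelihood = gpx.Gaussian(num_datapoints=D.n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

# Before: optimise the negative marginal log-likelihood with an optax optimiser.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),
    train_data=D,
    optim=ox.adamw(learning_rate=0.01),
    num_iters=500,
    key=key,
)

# After: hand the same objective to scipy's BFGS; no optimiser, step count, or key needed.
opt_posterior, history = gpx.fit_scipy(
    model=posterior,
    objective=gpx.ConjugateMLL(negative=True),
    train_data=D,
)
```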
4 changes: 2 additions & 2 deletions README.md
@@ -40,13 +40,13 @@ Another way you can contribute to GPJax is through [issue
triaging](https://www.codetriage.com/what). This can include reproducing bug reports,
asking for vital information such as version numbers and reproduction instructions, or
identifying stale issues. If you would like to begin triaging issues, an easy way to get
started is to
[subscribe to GPJax on CodeTriage](https://www.codetriage.com/jaxgaussianprocesses/gpjax).

As a contributor to GPJax, you are expected to abide by our [code of
conduct](docs/CODE_OF_CONDUCT.md). If you feel that you have either experienced or
witnessed behaviour that violates this standard, then we ask that you report any such
behaviours through [this form](https://jaxgaussianprocesses.com/contact/) or reach out to
one of the project's [_gardeners_](https://docs.jaxgaussianprocesses.com/GOVERNANCE/#roles).

Feel free to join our [Slack
7 changes: 2 additions & 5 deletions docs/examples/barycentres.py
@@ -137,13 +137,10 @@ def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
likelihood = gpx.Gaussian(num_datapoints=n)
posterior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF()) * likelihood

opt_posterior, _ = gpx.fit(
opt_posterior, _ = gpx.fit_scipy(
model=posterior,
objective=jax.jit(gpx.ConjugateMLL(negative=True)),
objective=gpx.ConjugateMLL(negative=True),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=500,
key=key,
)
latent_dist = opt_posterior.predict(xtest, train_data=D)
return opt_posterior.likelihood(latent_dist)
13 changes: 4 additions & 9 deletions docs/examples/constructing_new_kernels.py
@@ -40,7 +40,6 @@
)
import matplotlib.pyplot as plt
import numpy as np
import optax as ox
from simple_pytree import static_field
import tensorflow_probability.substrates.jax as tfp

@@ -214,13 +213,13 @@ def angular_distance(x, y, c):
return jnp.abs((x - y + c) % (c * 2) - c)


bij = tfb.Chain([tfb.Softplus(), tfb.Shift(np.array(4.0).astype(np.float64))])
bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))


@dataclass
class Polar(gpx.kernels.AbstractKernel):
period: float = static_field(2 * jnp.pi)
tau: float = param_field(jnp.array([4.0]), bijector=bij)
tau: float = param_field(jnp.array([5.0]), bijector=bij)

def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"]
@@ -267,14 +266,11 @@ def __call__(
likelihood = gpx.Gaussian(num_datapoints=n)
circular_posterior = gpx.Prior(mean_function=meanf, kernel=PKern) * likelihood

# Optimise GP's marginal log-likelihood using Adam
opt_posterior, history = gpx.fit(
# Optimise GP's marginal log-likelihood using BFGS
opt_posterior, history = gpx.fit_scipy(
model=circular_posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.05),
num_iters=500,
key=key,
)

# %% [markdown]
@@ -314,7 +310,6 @@ def __call__(
ax.plot(angles, mu, label="Posterior mean")
ax.scatter(D.X, D.y, alpha=1, label="Observations")
ax.legend()

# %% [markdown]
# ## System configuration

8 changes: 2 additions & 6 deletions docs/examples/graph_kernels.py
@@ -23,7 +23,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import optax as ox

with install_import_hook("gpjax", "beartype.beartype"):
import gpjax as gpx
@@ -134,7 +133,7 @@
# For this reason, we simply perform gradient descent on the GP's marginal
# log-likelihood term as in the
# [regression notebook](https://docs.jaxgaussianprocesses.com/examples/regression/).
# We do this using the Adam optimiser provided in `optax`.
# We do this using the BFGS optimiser provided by `scipy`, via `jaxopt`.

# %%
likelihood = gpx.Gaussian(num_datapoints=D.n)
@@ -155,13 +154,10 @@
# With a posterior defined, we can now optimise the model's hyperparameters.

# %%
opt_posterior, training_history = gpx.fit(
opt_posterior, training_history = gpx.fit_scipy(
model=posterior,
objective=jit(gpx.ConjugateMLL(negative=True)),
train_data=D,
optim=ox.adamw(learning_rate=0.01),
num_iters=1000,
key=key,
)

# %% [markdown]
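The note above about BFGS being provided by `scipy` via `jaxopt` can be made concrete with a rough, illustrative sketch of that route. This is not the GPJax implementation — `gpx.fit_scipy` presumably also handles parameter constraints and trainability — but it shows the kind of call the new training path delegates to, assuming a `posterior` and dataset `D` as built in this notebook:

```python
import jaxopt
import gpjax as gpx

objective = gpx.ConjugateMLL(negative=True)

# Wrap the negative MLL as a function of the model pytree and hand it to scipy's
# BFGS through jaxopt. Note: this naive version optimises in constrained space,
# so it can step positive parameters negative; gpx.fit_scipy is the supported API.
solver = jaxopt.ScipyMinimize(
    fun=lambda model: objective(model, train_data=D),
    method="BFGS",
)
result = solver.run(posterior)
opt_posterior = result.params
```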
21 changes: 7 additions & 14 deletions docs/examples/intro_to_kernels.py
@@ -10,7 +10,6 @@

config.update("jax_enable_x64", True)

from jax import jit
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import install_import_hook, Float
@@ -217,8 +216,8 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
mean = gpx.mean_functions.Zero()
kernel = gpx.kernels.Matern52(
lengthscale=jnp.array(2.0)
) # Initialise our kernel lengthscale to 2.0
lengthscale=jnp.array(0.1)
) # Initialise our kernel lengthscale to 0.1

prior = gpx.Prior(mean_function=mean, kernel=kernel)

@@ -235,16 +234,11 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(no_opt_posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=no_opt_posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=2000,
safe=True,
key=key,
)


@@ -524,7 +518,7 @@ def forrester(x: Float[Array, "N"]) -> Float[Array, "N"]:
mean = gpx.mean_functions.Zero()
rbf_kernel = gpx.kernels.RBF(lengthscale=100.0)
periodic_kernel = gpx.kernels.Periodic()
linear_kernel = gpx.kernels.Linear()
linear_kernel = gpx.kernels.Linear(variance=0.001)
sum_kernel = gpx.kernels.SumKernel(kernels=[linear_kernel, periodic_kernel])
final_kernel = gpx.kernels.SumKernel(kernels=[rbf_kernel, sum_kernel])

@@ -540,18 +534,17 @@
# %%
negative_mll = gpx.objectives.ConjugateMLL(negative=True)
negative_mll(posterior, train_data=D)
negative_mll = jit(negative_mll)

opt_posterior, history = gpx.fit(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=1000,
safe=True,
optim=ox.adamw(learning_rate=1e-2),
num_iters=500,
key=key,
)


# %% [markdown]
# Now we can obtain the model's prediction over a period of time which includes the
# training data, as well as 8 years before and after the training data:
20 changes: 4 additions & 16 deletions docs/examples/oceanmodelling.py
@@ -23,7 +23,6 @@
)
from matplotlib import rcParams
import matplotlib.pyplot as plt
import optax as ox
import pandas as pd
import tensorflow_probability as tfp

@@ -239,30 +238,19 @@ def initialise_gp(kernel, mean, dataset):


# %% [markdown]
# With a model now defined, we can proceed to optimise the hyperparameters of our likelihood over $D_0$. This is done by minimising the MLL using `optax`. We also plot its value at each step to visually confirm that we have found the minimum. See the [introduction to Gaussian Processes](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/) notebook for more information on optimising the MLL.
# With a model now defined, we can proceed to optimise the hyperparameters of our likelihood over $D_0$. This is done by minimising the MLL using `BFGS`. We also plot its value at each step to visually confirm that we have found the minimum. See the [introduction to Gaussian Processes](https://docs.jaxgaussianprocesses.com/examples/intro_to_gps/) notebook for more information on optimising the MLL.


# %%
def optimise_mll(posterior, dataset, NIters=1000, key=key, plot_history=True):
def optimise_mll(posterior, dataset, NIters=1000, key=key):
# define the MLL using dataset_train
objective = gpx.objectives.ConjugateMLL(negative=True)
# Optimise to minimise the MLL
optimiser = ox.adam(learning_rate=0.1)
opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=objective,
train_data=dataset,
optim=optimiser,
num_iters=NIters,
safe=True,
key=key,
)
# plot MLL value at each iteration
if plot_history:
fig, ax = plt.subplots(1, 1)
ax.plot(history, color=colors[1])
ax.set(xlabel="Training iteration", ylabel="Negative MLL")

return opt_posterior


@@ -471,7 +459,7 @@ def __call__(
# Redefine Gaussian process with Helmholtz kernel
kernel = HelmholtzKernel()
helmholtz_posterior = initialise_gp(kernel, mean, dataset_train)
# Optimise hyperparameters using optax
# Optimise hyperparameters using BFGS
opt_helmholtz_posterior = optimise_mll(helmholtz_posterior, dataset_train)


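Because the hunk above interleaves removed and retained lines, it may help to see the `optimise_mll` helper as it reads after this change, assembled from the lines shown (note that the `NIters` and `key` arguments are no longer used once the scipy-based fit takes over, and the loss-history plot is gone):

```python
def optimise_mll(posterior, dataset, NIters=1000, key=key):
    # define the MLL using dataset_train
    objective = gpx.objectives.ConjugateMLL(negative=True)
    # Optimise to minimise the MLL
    opt_posterior, history = gpx.fit_scipy(
        model=posterior,
        objective=objective,
        train_data=dataset,
    )
    return opt_posterior
```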
18 changes: 2 additions & 16 deletions docs/examples/regression.py
@@ -210,30 +210,16 @@
# accelerate training.

# %% [markdown]
# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# We can now define an optimiser. For this example we'll use the `BFGS`
# optimiser.

# %%
opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=500,
safe=True,
key=key,
)

# %% [markdown]
# The calling of `fit` returns two objects: the optimised posterior and a history of
# training losses. We can plot the training loss to see how the optimisation has
# progressed.

# %%
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Training iteration", ylabel="Negative marginal log likelihood")

# %% [markdown]
# ## Prediction
#
18 changes: 2 additions & 16 deletions docs/examples/regression_mo.py
@@ -228,30 +228,16 @@
# accelerate training.

# %% [markdown]
# We can now define an optimiser with `optax`. For this example we'll use the `adam`
# We can now define an optimiser with `scipy`. For this example we'll use the `BFGS`
# optimiser.

# %%
opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
train_data=D,
optim=ox.adam(learning_rate=0.01),
num_iters=500,
safe=True,
key=key,
)

# %% [markdown]
# The calling of `fit` returns two objects: the optimised posterior and a history of
# training losses. We can plot the training loss to see how the optimisation has
# progressed.

# %%
fig, ax = plt.subplots()
ax.plot(history, color=cols[1])
ax.set(xlabel="Training iteration", ylabel="Negative marginal log likelihood")

# %% [markdown]
# ## Prediction
#
13 changes: 5 additions & 8 deletions docs/examples/yacht.py
@@ -35,7 +35,6 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import optax as ox
import pandas as pd
from sklearn.metrics import (
mean_squared_error,
@@ -169,7 +168,9 @@
# %%
n_train, n_covariates = scaled_Xtr.shape
kernel = gpx.RBF(
active_dims=list(range(n_covariates)), lengthscale=np.ones((n_covariates,))
active_dims=list(range(n_covariates)),
variance=np.var(scaled_ytr),
lengthscale=0.1 * np.ones((n_covariates,)),
)
meanf = gpx.mean_functions.Zero()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)
@@ -182,21 +183,17 @@
# ### Model Optimisation
#
# With a model now defined, we can proceed to optimise the hyperparameters of our
# model using Optax.
# model using Scipy.

# %%
training_data = gpx.Dataset(X=scaled_Xtr, y=scaled_ytr)

negative_mll = jit(gpx.ConjugateMLL(negative=True))
optimiser = ox.adamw(0.05)

opt_posterior, history = gpx.fit(
opt_posterior, history = gpx.fit_scipy(
model=posterior,
objective=negative_mll,
train_data=training_data,
optim=ox.adamw(learning_rate=0.05),
num_iters=500,
key=key,
)

# %% [markdown]
Expand Down
6 changes: 5 additions & 1 deletion gpjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
)
from gpjax.citation import cite
from gpjax.dataset import Dataset
from gpjax.fit import fit
from gpjax.fit import (
fit,
fit_scipy,
)
from gpjax.gps import (
Prior,
construct_posterior,
@@ -87,6 +90,7 @@
"decision_making",
"kernels",
"fit",
"fit_scipy",
"Prior",
"construct_posterior",
"integrators",
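With the `__init__.py` change above, both training routines are exposed at the top level of the package. A quick sketch of the resulting public surface, assuming the exports shown in this diff:

```python
import gpjax as gpx

# The optax-driven loop and the new scipy-driven routine now live side by side.
print(callable(gpx.fit), callable(gpx.fit_scipy))        # True True
print("fit" in gpx.__all__, "fit_scipy" in gpx.__all__)  # True True
```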