Add copyright, update typing.
daniel-dodd committed Sep 18, 2022
1 parent f1d3f81 commit cc1c318
Showing 17 changed files with 769 additions and 190 deletions.
26 changes: 9 additions & 17 deletions examples/natgrads.ipynb
@@ -87,7 +87,7 @@
"metadata": {},
"outputs": [],
"source": [
"z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n",
"z = jnp.linspace(-5.0, 5.0, 5000).reshape(-1, 1)\n",
"\n",
"fig, ax = plt.subplots(figsize=(12, 5))\n",
"ax.plot(x, y, \"o\", alpha=0.3)\n",
@@ -126,17 +126,9 @@
"\n",
"\n",
"natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n",
"natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "60293d59",
"metadata": {},
"outputs": [],
"source": [
"params, trainables, bijectors = gpx.initialise(natural_svgp).unpack()"
"natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)\n",
"\n",
"parameter_state = gpx.initialise(natural_svgp)"
]
},
{
@@ -154,13 +146,13 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"inference_state = gpx.fit_natgrads(natural_svgp,\n",
" params,\n",
" trainables,\n",
" bijectors,\n",
" parameter_state=parameter_state,\n",
" train_data = D,\n",
" n_iters = 10000,\n",
" batch_size=1000,\n",
" n_iters = 4000,\n",
" batch_size=128,\n",
" key = jr.PRNGKey(42),\n",
" moment_optim = ox.sgd(1.0),\n",
" hyper_optim = ox.adam(1e-3),\n",
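The notebook cells above replace the manual `params, trainables, bijectors` unpacking with a single `parameter_state` argument passed straight to `fit_natgrads`. Below is a minimal end-to-end sketch of the updated calling pattern. The toy data and the exact constructor arguments for `gpx.Dataset`, `gpx.Prior`, `gpx.Gaussian` and `gpx.construct_posterior` are assumptions based on the 0.4.x-era GPJax API rather than taken from this diff; the `fit_natgrads` call itself mirrors the notebook.

import jax.numpy as jnp
import jax.random as jr
import optax as ox
import gpjax as gpx

# Hypothetical toy data standing in for the notebook's x, y.
key = jr.PRNGKey(123)
x = jnp.sort(jr.uniform(key, shape=(500, 1), minval=-5.0, maxval=5.0), axis=0)
y = jnp.sin(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

# Prior, likelihood, posterior and inducing points, mirroring the notebook.
prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=D.n)
p = gpx.construct_posterior(prior, likelihood)
z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)

natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)
natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)

# New API: one ParameterState object, no manual params/trainables/bijectors unpacking.
parameter_state = gpx.initialise(natural_svgp)

inference_state = gpx.fit_natgrads(
    natural_svgp,
    parameter_state=parameter_state,
    train_data=D,
    n_iters=4000,
    batch_size=128,
    key=jr.PRNGKey(42),
    moment_optim=ox.sgd(1.0),
    hyper_optim=ox.adam(1e-3),
)

# InferenceState exposes the optimised parameters and the loss history (per the dataclass below).
learned_params, history = inference_state.params, inference_state.history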
56 changes: 53 additions & 3 deletions gpjax/__init__.py
@@ -1,9 +1,22 @@
# Copyright 2022 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from jax.config import config

# Enable Float64 - this is crucial for more stable matrix inversions.
# Enable Float64 for more stable matrix inversions.
config.update("jax_enable_x64", True)
# Highlight any potentially unintended broadcasting rank promoting ops.
# config.update("jax_numpy_rank_promotion", "warn")

from .abstractions import fit, fit_batches, fit_natgrads
from .gps import Prior, construct_posterior
@@ -30,4 +43,41 @@
)
from .variational_inference import CollapsedVI, StochasticVI

__license__ = "MIT"
__description__ = "Didactic Gaussian processes in JAX"
__url__ = "https://github.com/thomaspinder/GPJax"
__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors"
__version__ = "0.4.13"


__all__ = [
"fit",
"fit_batches",
"fit_natgrads",
"Prior",
"construct_posterior",
"RBF",
"GraphKernel",
"Matern12",
"Matern32",
"Matern52",
"Polynomial",
"ProductKernel",
"SumKernel",
"Bernoulli",
"Gaussian",
"Constant",
"Zero",
"constrain",
"copy_dict_structure",
"initialise",
"unconstrain",
"Dataset",
"CollapsedVariationalGaussian",
"ExpectationVariationalGaussian",
"NaturalVariationalGaussian",
"VariationalGaussian",
"WhitenedVariationalGaussian",
"CollapsedVI",
"StochasticVI",
]
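A quick sanity check of the module-level changes above: the x64 flag set at import time and the new `__all__` export list are both observable from a plain interpreter session. This is a small sketch; the assumption that `__version__` survives the metadata clean-up is inferred from the diff, not confirmed by it.

import jax.numpy as jnp
import gpjax as gpx

# Importing gpjax enables JAX's 64-bit mode, so fresh arrays default to float64.
print(jnp.ones(1).dtype)                # expected: float64

# Module metadata, e.g. "0.4.13" per the version string shown in the diff above.
print(gpx.__version__)

# Every name listed in __all__ should resolve as an attribute of the package.
assert all(hasattr(gpx, name) for name in gpx.__all__)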
68 changes: 50 additions & 18 deletions gpjax/abstractions.py
@@ -1,9 +1,24 @@
# Copyright 2022 The GPJax Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable, Dict, Optional

import jax
import jax.numpy as jnp
import jax.random as jr
import optax
import optax as ox
from chex import dataclass
from jax import lax
from jax.experimental import host_callback
@@ -100,18 +115,20 @@ def wrapper_progress_bar(carry, x):
def fit(
objective: Callable,
parameter_state: ParameterState,
optax_optim,
n_iters: int = 100,
log_rate: int = 10,
optax_optim: ox.GradientTransformation,
n_iters: Optional[int] = 100,
log_rate: Optional[int] = 10,
) -> InferenceState:
"""Abstracted method for fitting a GP model with respect to a supplied objective function.
Optimisers used here should originate from Optax.
Args:
objective (Callable): The objective function that we are optimising with respect to.
parameter_state (ParameterState): The initial parameter state.
optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set.
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
InferenceState: An InferenceState object comprising the optimised parameters and training history respectively.
"""
@@ -135,7 +152,7 @@ def step(carry, iter_num):
params, opt_state = carry
loss_val, loss_gradient = jax.value_and_grad(loss)(params)
updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
params = optax.apply_updates(params, updates)
params = ox.apply_updates(params, updates)
carry = params, opt_state
return carry, loss_val
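For context, this is how the typed `fit` signature above is intended to be called: the objective is a function of the parameter dictionary alone, and any Optax `GradientTransformation` can be supplied. The toy data and the `marginal_log_likelihood(D, negative=True)` objective construction are assumptions based on the 0.4.x-era GPJax API, not part of this diff.

import jax.numpy as jnp
import jax.random as jr
import optax as ox
import gpjax as gpx

# Hypothetical toy regression data.
key = jr.PRNGKey(0)
x = jnp.linspace(-3.0, 3.0, 100).reshape(-1, 1)
y = jnp.sin(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=D.n)
posterior = gpx.construct_posterior(prior, likelihood)

# `fit` consumes a ParameterState and an Optax optimiser; the objective maps params -> scalar loss.
parameter_state = gpx.initialise(posterior)
negative_mll = posterior.marginal_log_likelihood(D, negative=True)

inference_state = gpx.fit(
    objective=negative_mll,
    parameter_state=parameter_state,
    optax_optim=ox.adam(learning_rate=0.01),
    n_iters=500,
    log_rate=50,
)
learned_params = inference_state.params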

@@ -153,14 +170,15 @@ def fit_batches(
objective: Callable,
parameter_state: ParameterState,
train_data: Dataset,
optax_optim,
optax_optim: ox.GradientTransformation,
key: PRNGKeyType,
batch_size: int,
n_iters: Optional[int] = 100,
log_rate: Optional[int] = 10,
) -> InferenceState:
"""Abstracted method for fitting a GP model with mini-batches respect to a supplied objective function.
Optimisers used here should originate from Optax.
Args:
objective (Callable): The objective function that we are optimising with respect to.
parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with.
Expand All @@ -170,6 +188,7 @@ def fit_batches(
batch_size(int): The batch_size.
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
InferenceState: An InferenceState object comprising the optimised parameters and training history respectively.
"""
@@ -196,7 +215,7 @@ def step(carry, iter_num__and__key):

loss_val, loss_gradient = jax.value_and_grad(loss)(params, batch)
updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
params = optax.apply_updates(params, updates)
params = ox.apply_updates(params, updates)

carry = params, opt_state
return carry, loss_val
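`fit_batches` mirrors `fit` but draws a fresh random mini-batch at every step, so its objective is a function of `(params, batch)`. A hedged usage sketch follows; the `VariationalGaussian` constructor arguments, the `svgp.elbo(D, negative=True)` objective, and the toy data are assumptions from the 0.4.x-era GPJax API rather than content of this diff.

import jax.numpy as jnp
import jax.random as jr
import optax as ox
import gpjax as gpx

# Hypothetical toy data and a sparse stochastic-VI model.
key = jr.PRNGKey(0)
x = jr.uniform(key, shape=(1000, 1), minval=-5.0, maxval=5.0)
y = jnp.sin(x) + 0.1 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

prior = gpx.Prior(kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=D.n)
posterior = gpx.construct_posterior(prior, likelihood)

z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)
q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z)
svgp = gpx.StochasticVI(posterior=posterior, variational_family=q)

parameter_state = gpx.initialise(svgp)
negative_elbo = svgp.elbo(D, negative=True)   # assumed: returns a function of (params, batch)

inference_state = gpx.fit_batches(
    objective=negative_elbo,
    parameter_state=parameter_state,
    train_data=D,
    optax_optim=ox.adam(1e-2),
    key=jr.PRNGKey(42),
    batch_size=64,
    n_iters=2000,
)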
@@ -211,9 +230,11 @@

def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset:
"""Batch the data into mini-batches.
Args:
train_data (Dataset): The training dataset.
batch_size (int): The batch size.
Returns:
Dataset: The batched dataset.
"""
@@ -226,33 +247,35 @@ def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset

def fit_natgrads(
stochastic_vi: StochasticVI,
params: Dict,
trainables: Dict,
bijectors: Dict,
parameter_state: ParameterState,
train_data: Dataset,
moment_optim,
hyper_optim,
moment_optim: ox.GradientTransformation,
hyper_optim: ox.GradientTransformation,
key: PRNGKeyType,
batch_size: int,
n_iters: Optional[int] = 100,
log_rate: Optional[int] = 10,
) -> Dict:
"""This is a training loop for natural gradients. See Salimbeni et al. (2018) Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models
Each iteration comprises a hyperparameter gradient step followed by natural gradient step to avoid a stale posterior.
Args:
stochastic_vi (StochasticVI): The stochastic variational inference algorithm to be used for training.
params (Dict): The parameters for which we would like to minimise our objective function with.
trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained.
bijectors (Dict): The bijectors to be applied to the parameters.
parameter_state (ParameterState): The initial parameter state.
train_data (Dataset): The training dataset.
batch_size(int): The batch_size.
moment_optim (GradientTransformation): The Optax optimiser for the natural gradient updates on the moments.
hyper_optim (GradientTransformation): The Optax optimiser for gradient updates on the hyperparameters.
key (PRNGKeyType): The PRNG key for the mini-batch sampling.
batch_size(int): The batch_size.
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
InferenceState: A dataclass comprising optimised parameters and training history.
"""

params, trainables, bijectors = parameter_state.unpack()

params = unconstrain(params, bijectors)

hyper_state = hyper_optim.init(params)
@@ -275,12 +298,12 @@ def step(carry, iter_num__and__key):
# Hyper-parameters update:
loss_val, loss_gradient = hyper_grads_fn(params, batch)
updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params)
params = optax.apply_updates(params, updates)
params = ox.apply_updates(params, updates)

# Natural gradients update:
loss_val, loss_gradient = nat_grads_fn(params, batch)
updates, moment_state = moment_optim.update(loss_gradient, moment_state, params)
params = optax.apply_updates(params, updates)
params = ox.apply_updates(params, updates)

carry = params, hyper_state, moment_state
return carry, loss_val
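The step above alternates a hyperparameter update with a natural-gradient moment update, as the docstring notes, so the moments never lag behind freshly updated hyperparameters. The same alternating pattern, stripped of GPJax specifics and with hypothetical stand-in losses, looks like this in plain Optax:

import jax
import jax.numpy as jnp
import optax as ox

# Two transformations over the same parameter pytree: a conservative optimiser for the
# hyperparameters and unit-step SGD for the natural-gradient moments.
hyper_optim = ox.adam(1e-3)
moment_optim = ox.sgd(1.0)

params = {"lengthscale": jnp.array(1.0), "moments": jnp.array(0.0)}
hyper_state = hyper_optim.init(params)
moment_state = moment_optim.init(params)

def hyper_loss(p):    # stand-in for the hyperparameter objective (hypothetical)
    return (p["lengthscale"] - 2.0) ** 2

def moment_loss(p):   # stand-in for the natural-gradient objective (hypothetical)
    return (p["moments"] - 1.0) ** 2

for _ in range(100):
    # Hyperparameter step first...
    _, g = jax.value_and_grad(hyper_loss)(params)
    updates, hyper_state = hyper_optim.update(g, hyper_state, params)
    params = ox.apply_updates(params, updates)
    # ...then the moment step against the refreshed hyperparameters.
    _, g = jax.value_and_grad(moment_loss)(params)
    updates, moment_state = moment_optim.update(g, moment_state, params)
    params = ox.apply_updates(params, updates)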
@@ -291,3 +314,12 @@ def step(carry, iter_num__and__key):
params = constrain(params, bijectors)
inf_state = InferenceState(params=params, history=history)
return inf_state


__all__ = [
"fit",
"fit_natgrads",
"get_batch",
"natural_gradients",
"progress_bar_scan",
]
(Additional changed files not shown.)
