From 5c33301bdf7f86a4d756b86a1aca695350845b09 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 12:24:15 +0100 Subject: [PATCH 01/66] Initial commit. --- gpjax/abstractions.py | 22 +++++++++- gpjax/gps.py | 9 +---- gpjax/parameters.py | 63 +++++++---------------------- gpjax/variational_inference.py | 15 ++----- tests/test_abstractions.py | 26 ++++++------ tests/test_gp.py | 16 ++++---- tests/test_kernels.py | 10 ++--- tests/test_likelihoods.py | 6 +-- tests/test_mean_functions.py | 4 +- tests/test_parameters.py | 41 +++++++++---------- tests/test_variational_inference.py | 18 +++------ 11 files changed, 96 insertions(+), 134 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 7cff94bf..b68f9eb8 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -10,7 +10,7 @@ from jaxtyping import Array, Float from tqdm.auto import tqdm -from .parameters import trainable_params +from .parameters import trainable_params, transform from .types import Dataset, PRNGKeyType @@ -99,6 +99,7 @@ def fit( objective: tp.Callable, params: tp.Dict, trainables: tp.Dict, + bijectors: tp.Dict, optax_optim, n_iters: int = 100, log_rate: int = 10, @@ -109,6 +110,7 @@ def fit( objective (tp.Callable): The objective function that we are optimising with respect to. params (dict): The parameters for which we would like to minimise our objective function with. trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. + bijectors (dict): Dictionary of bijectors for each parameter. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. @@ -119,10 +121,14 @@ def fit( def loss(params): params = trainable_params(params, trainables) + params = transform(params, bijectors, forward=True) return objective(params) iter_nums = jnp.arange(n_iters) + # Tranform params to unconstrained space: + params = transform(params, bijectors, forward=False) + @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num): params, opt_state = carry @@ -133,7 +139,12 @@ def step(carry, iter_num): return carry, loss_val (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) + + # Tranform params to constrained space: + params = transform(params, bijectors, forward=True) + inf_state = InferenceState(params=params, history=history) + return inf_state @@ -141,6 +152,7 @@ def fit_batches( objective: tp.Callable, params: tp.Dict, trainables: tp.Dict, + bijectors: tp.Dict, train_data: Dataset, optax_optim, key: PRNGKeyType, @@ -167,12 +179,16 @@ def fit_batches( opt_state = optax_optim.init(params) def loss(params, batch): + params = transform(params, bijectors, forward=True) params = trainable_params(params, trainables) return objective(params, batch) keys = jax.random.split(key, n_iters) iter_nums = jnp.arange(n_iters) + # Tranform params to unconstrained space: + params = transform(params, bijectors, forward=False) + @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num__and__key): iter_num, key = iter_num__and__key @@ -188,7 +204,11 @@ def step(carry, iter_num__and__key): return carry, loss_val (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) + + # Tranform params to constrained space: + params = transform(params, bijectors, forward=True) inf_state = InferenceState(params=params, history=history) + return inf_state diff --git a/gpjax/gps.py b/gpjax/gps.py index db13f2f5..b52cd244 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -20,7 +20,7 @@ NonConjugateLikelihoodType, ) from .mean_functions import AbstractMeanFunction, Zero -from .parameters import copy_dict_structure, evaluate_priors, transform +from .parameters import copy_dict_structure, evaluate_priors from .types import Dataset from .utils import I, concat_dictionaries @@ -194,7 +194,6 @@ def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: def marginal_log_likelihood( self, train_data: Dataset, - transformations: Dict, priors: dict = None, negative: bool = False, ) -> tp.Callable[[dict], Float[Array, "1"]]: @@ -202,7 +201,6 @@ def marginal_log_likelihood( Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - transformations (Dict): A dictionary of transformations that should be applied to the training dataset to unconstrain the parameters. priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. @@ -214,8 +212,6 @@ def marginal_log_likelihood( def mll( params: dict, ): - params = transform(params=params, transform_map=transformations) - # Observation noise σ² obs_noise = params["likelihood"]["obs_noise"] μx = self.prior.mean_function(x, params["mean_function"]) @@ -305,7 +301,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: def marginal_log_likelihood( self, train_data: Dataset, - transformations: Dict, priors: dict = None, negative: bool = False, ) -> tp.Callable[[dict], Float[Array, "1"]]: @@ -313,7 +308,6 @@ def marginal_log_likelihood( Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - transformations (Dict): A dictionary of transformations that should be applied to the training dataset to unconstrain the parameters. priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. @@ -327,7 +321,6 @@ def marginal_log_likelihood( priors["latent"] = dx.Normal(loc=0.0, scale=1.0) def mll(params: dict): - params = transform(params=params, transform_map=transformations) Kxx = gram(self.prior.kernel, x, params["kernel"]) Kxx += I(n) * self.jitter Lx = jnp.linalg.cholesky(Kxx) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 889c90e7..a18343c1 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -1,6 +1,5 @@ import typing as tp import warnings -from collections import namedtuple from copy import deepcopy from warnings import warn @@ -27,11 +26,10 @@ class ParameterState: params: tp.Dict trainables: tp.Dict - constrainers: tp.Dict - unconstrainers: tp.Dict + bijectors: tp.Dict def unpack(self): - return self.params, self.trainables, self.constrainers, self.unconstrainers + return self.params, self.trainables, self.bijectors def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: @@ -44,13 +42,12 @@ def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: _validate_kwargs(kwargs, params) for k, v in kwargs.items(): params[k] = merge_dictionaries(params[k], v) - constrainers, unconstrainers = build_transforms(params) + bijectors = build_bijectors(params) trainables = build_trainables(params) state = ParameterState( params=params, trainables=trainables, - constrainers=constrainers, - unconstrainers=unconstrainers, + bijectors=bijectors, ) return state @@ -92,8 +89,6 @@ def recursive_complete(d1: tp.Dict, d2: tp.Dict) -> tp.Dict: if type(value) is dict: if key in d2.keys(): recursive_complete(value, d2[key]) - # else: - # pass else: if key in d2.keys(): d1[key] = d2[key] @@ -144,54 +139,24 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def build_transforms(params: tp.Dict) -> tp.Tuple[tp.Dict, tp.Dict]: - """Using the bijector that is associated with each parameter, construct a pair of functions from the bijector that allow the parameter to be constrained and unconstrained. - - Args: - params (tp.Dict): The parameter set for which transformations should be derived from. - - Returns: - tp.Tuple[tp.Dict, tp.Dict]: A pair of dictionaries. The first dictionary maps each parameter to a function that constrains the parameter. The second dictionary maps each parameter to a function that unconstrains the parameter. - """ - - def forward(bijector): - return bijector.forward - - def inverse(bijector): - return bijector.inverse - - bijectors = build_bijectors(params) - - constrainers = jax.tree_util.tree_map(lambda _: forward, deepcopy(params)) - unconstrainers = jax.tree_util.tree_map(lambda _: inverse, deepcopy(params)) - - constrainers = jax.tree_util.tree_map(lambda f, b: f(b), constrainers, bijectors) - unconstrainers = jax.tree_util.tree_map( - lambda f, b: f(b), unconstrainers, bijectors - ) - - return constrainers, unconstrainers - - -def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict: +def transform(params: tp.Dict, bijectors: tp.Dict, forward: bool) -> tp.Dict: """Transform the parameters according to the constraining or unconstraining function dictionary. Args: params (tp.Dict): The parameters that are to be transformed. transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - tp.Dict: A transformed parameter set.s The dictionary is equal in structure to the input params dictionary. + tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - warn( - "`transform` will be deprecated in a future release. As of v0.5.0, please use `constrain`" - " or `unconstrain` instead.", - DeprecationWarning, - stacklevel=2, - ) - return jax.tree_util.tree_map( - lambda param, trans: trans(param), params, transform_map - ) + + fwd = lambda param, trans: trans.forward(param) + inv = lambda param, trans: trans.inverse(param) + + map = fwd if forward else inv + + return jax.tree_util.tree_map(map, params, bijectors) ################################ diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index a81819c4..896c2e49 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -10,7 +10,6 @@ from .gps import AbstractPosterior from .kernels import cross_covariance, diagonal, gram from .likelihoods import Gaussian -from .parameters import transform from .quadrature import gauss_hermite_quadrature from .types import Dataset from .utils import I, concat_dictionaries @@ -41,13 +40,12 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict: @abc.abstractmethod def elbo( - self, train_data: Dataset, transformations: Dict + self, train_data: Dataset, ) -> Callable[[Dict], Float[Array, "1"]]: """Placeholder method for computing the evidence lower bound function (ELBO), given a training dataset and a set of transformations that map each parameter onto the entire real line. Args: train_data (Dataset): The training dataset for which the ELBO is to be computed. - transformations (Dict): A set of functions that unconstrain each parameter. Returns: Callable[[Array], Array]: A function that computes the ELBO given a set of parameters. @@ -65,13 +63,12 @@ def __post_init__(self): self.num_inducing = self.variational_family.num_inducing def elbo( - self, train_data: Dataset, transformations: Dict, negative: bool = False + self, train_data: Dataset, negative: bool = False ) -> Callable[[Float[Array, "N D"]], Float[Array, "1"]]: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. Args: train_data (Dataset): The training data for which we should maximise the ELBO with respect to. - transformations (Dict): The transformation set that unconstrains each parameter. negative (bool, optional): Whether or not the resultant elbo function should be negative. For gradient descent where we minimise our objective function this argument should be true as minimisation of the negative corresponds to maximisation of the ELBO. Defaults to False. Returns: @@ -80,8 +77,6 @@ def elbo( constant = jnp.array(-1.0) if negative else jnp.array(1.0) def elbo_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: - params = transform(params, transformations) - # KL[q(f(·)) || p(f(·))] kl = self.variational_family.prior_kl(params) @@ -142,13 +137,12 @@ def __post_init__(self): raise TypeError("Variational family must be CollapsedVariationalGaussian.") def elbo( - self, train_data: Dataset, transformations: Dict, negative: bool = False - ) -> Callable[[dict], Float[Array, "1"]]: + self, train_data: Dataset, negative: bool = False + ) -> Callable[[dict],Float[Array, "1"]:]: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. Args: train_data (Dataset): The training data for which we should maximise the ELBO with respect to. - transformations (Dict): The transformation set that unconstrains each parameter. negative (bool, optional): Whether or not the resultant elbo function should be negative. For gradient descent where we minimise our objective function this argument should be true as minimisation of the negative corresponds to maximisation of the ELBO. Defaults to False. Returns: @@ -161,7 +155,6 @@ def elbo( m = self.num_inducing def elbo_fn(params: Dict) -> Float[Array, "1"]: - params = transform(params, transformations) noise = params["likelihood"]["obs_noise"] z = params["variational_family"]["inducing_inputs"] Kzz = gram(self.prior.kernel, z, params["kernel"]) diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index b33127b1..b4d13596 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -4,8 +4,9 @@ import pytest import gpjax as gpx -from gpjax import RBF, Dataset, Gaussian, Prior, initialise, transform +from gpjax import RBF, Dataset, Gaussian, Prior, initialise from gpjax.abstractions import InferenceState, fit, fit_batches, get_batch +from gpjax.parameters import build_bijectors @pytest.mark.parametrize("n_iters", [10]) @@ -16,13 +17,12 @@ def test_fit(n_iters, n): y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 D = Dataset(X=x, y=y) p = Prior(kernel=RBF()) * Gaussian(num_datapoints=n) - params, trainable_status, constrainer, unconstrainer = initialise(p, key).unpack() - mll = p.marginal_log_likelihood(D, constrainer, negative=True) + params, trainables, bijectors = initialise(p, key).unpack() + mll = p.marginal_log_likelihood(D, negative=True) pre_mll_val = mll(params) optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(mll, params, trainable_status, optimiser, n_iters) + inference_state = fit(mll, params, trainables, bijectors, optimiser, n_iters) optimised_params, history = inference_state.params, inference_state.history - optimised_params = transform(optimised_params, constrainer) assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert mll(optimised_params) < pre_mll_val @@ -33,9 +33,10 @@ def test_fit(n_iters, n): def test_stop_grads(): params = {"x": jnp.array(3.0), "y": jnp.array(4.0)} trainables = {"x": True, "y": False} + bijectors = build_bijectors(params) loss_fn = lambda params: params["x"] ** 2 + params["y"] ** 2 optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(loss_fn, params, trainables, optimiser, n_iters=1) + inference_state = fit(loss_fn, params, trainables, bijectors, optimiser, n_iters=1) learned_params = inference_state.params assert isinstance(inference_state, InferenceState) assert learned_params["y"] == params["y"] @@ -58,23 +59,22 @@ def test_batch_fitting(n_iters, nb, ndata): q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z) svgp = gpx.StochasticVI(posterior=p, variational_family=q) - params, trainable_status, constrainer, unconstrainer = initialise( - svgp, key - ).unpack() - params = gpx.transform(params, unconstrainer) - objective = svgp.elbo(D, constrainer) + params, trainables, bijectors = initialise(svgp, key).unpack() + objective = svgp.elbo(D) + + pre_mll_val = objective(params, D) D = Dataset(X=x, y=y) optimiser = optax.adam(learning_rate=0.1) key = jr.PRNGKey(42) inference_state = fit_batches( - objective, params, trainable_status, D, optimiser, key, nb, n_iters + objective, params, trainables, bijectors, D, optimiser, key, nb, n_iters ) optimised_params, history = inference_state.params, inference_state.history - optimised_params = transform(optimised_params, constrainer) assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) + assert objective(optimised_params, D) < pre_mll_val assert isinstance(history, jnp.ndarray) assert history.shape[0] == n_iters diff --git a/tests/test_gp.py b/tests/test_gp.py index afede232..c393860f 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -22,7 +22,7 @@ def test_prior(num_datapoints): p = Prior(kernel=RBF()) parameter_state = initialise(p, jr.PRNGKey(123)) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() assert isinstance(p, Prior) assert isinstance(p, AbstractGP) prior_rv_fn = p(params) @@ -59,11 +59,11 @@ def test_conjugate_posterior(num_datapoints): assert isinstance(post2, AbstractGP) parameter_state = initialise(post, key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() - params = transform(params, unconstrainer) + params, _, bijectors = parameter_state.unpack() + params = transform(params, bijectors, forward=False) # Marginal likelihood - mll = post.marginal_log_likelihood(train_data=D, transformations=constrainer) + mll = post.marginal_log_likelihood(train_data=D) objective_val = mll(params) assert isinstance(objective_val, jnp.DeviceArray) assert objective_val.shape == () @@ -101,12 +101,12 @@ def test_nonconjugate_posterior(num_datapoints, likel): assert isinstance(p, AbstractGP) parameter_state = initialise(post, key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() - params = transform(params, unconstrainer) + params, _, bijectors = parameter_state.unpack() + params = transform(params, bijectors, forward=False) assert isinstance(parameter_state, ParameterState) # Marginal likelihood - mll = post.marginal_log_likelihood(train_data=D, transformations=constrainer) + mll = post.marginal_log_likelihood(train_data=D) objective_val = mll(params) assert isinstance(objective_val, jnp.DeviceArray) assert objective_val.shape == () @@ -130,7 +130,7 @@ def test_nonconjugate_posterior(num_datapoints, likel): def test_param_construction(num_datapoints, lik): p = Prior(kernel=RBF()) * lik(num_datapoints=num_datapoints) parameter_state = initialise(p, jr.PRNGKey(123)) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() if isinstance(lik, Bernoulli): assert sorted(list(params.keys())) == [ diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 7509fb46..efcc6fbe 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -36,7 +36,7 @@ def test_gram(kern, dim, fn): if dim > 1: x = jnp.hstack([x] * dim) parameter_state = initialise(kern, key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() gram_matrix = fn(kern, x, params) assert gram_matrix.shape[0] == x.shape[0] assert gram_matrix.shape[0] == gram_matrix.shape[1] @@ -50,7 +50,7 @@ def test_cross_covariance(kern, n1, n2): x1 = jnp.linspace(-1.0, 1.0, num=n1).reshape(-1, 1) x2 = jnp.linspace(-1.0, 1.0, num=n2).reshape(-1, 1) parameter_state = initialise(kern, key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() kernel_matrix = cross_covariance(kern, x1, x2, params) assert kernel_matrix.shape == (n1, n2) @@ -59,7 +59,7 @@ def test_cross_covariance(kern, n1, n2): def test_call(kernel): key = jr.PRNGKey(123) parameter_state = initialise(kernel, key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() x, y = jnp.array([[1.0]]), jnp.array([[0.5]]) point_corr = kernel(x, y, params) assert isinstance(point_corr, jnp.DeviceArray) @@ -93,7 +93,7 @@ def test_initialisation(kernel, dim): else: kern = kernel(active_dims=[i for i in range(dim)]) parameter_state = initialise(kern, key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() assert list(params.keys()) == ["lengthscale", "variance"] assert all(params["lengthscale"] == jnp.array([1.0] * dim)) assert params["variance"] == jnp.array([1.0]) @@ -107,7 +107,7 @@ def test_initialisation(kernel, dim): def test_dtype(kernel): key = jr.PRNGKey(123) parameter_state = initialise(kernel(), key) - params, trainable_status, constrainer, unconstrainer = parameter_state.unpack() + params, _, _ = parameter_state.unpack() for k, v in params.items(): assert v.dtype == jnp.float64 diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index bcb71400..57f0ea09 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -25,7 +25,7 @@ def test_initialisers(num_datapoints, lik): key = jr.PRNGKey(123) lhood = lik(num_datapoints=num_datapoints) - params, _, _, _ = initialise(lhood, key).unpack() + params, _, _ = initialise(lhood, key).unpack() assert list(params.keys()) == true_initialisation[lhood.name] assert len(list(params.values())) == len(true_initialisation[lhood.name]) @@ -37,7 +37,7 @@ def test_predictive_moment(n): fmean = jr.uniform(key=key, shape=(n,)) * -1 fvar = jr.uniform(key=key, shape=(n,)) pred_mom_fn = lhood.predictive_moment_fn - params, _, _, _ = initialise(lhood, key).unpack() + params, _, _ = initialise(lhood, key).unpack() rv = pred_mom_fn(fmean, fvar, params) mu = rv.mean() sigma = rv.variance() @@ -51,7 +51,7 @@ def test_predictive_moment(n): def test_link_fns(lik: AbstractLikelihood, n: int): key = jr.PRNGKey(123) lhood = lik(num_datapoints=n) - params, _, _, _ = initialise(lhood, key).unpack() + params, _, _ = initialise(lhood, key).unpack() link_fn = lhood.link_function assert isinstance(link_fn, tp.Callable) x = jnp.linspace(-3.0, 3.0).reshape(-1, 1) diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py index 79bddcd0..6fe6a9db 100644 --- a/tests/test_mean_functions.py +++ b/tests/test_mean_functions.py @@ -16,7 +16,7 @@ def test_shape(meanf, dim): x = jnp.linspace(-1.0, 1.0, num=10).reshape(-1, 1) if dim > 1: x = jnp.hstack([x] * dim) - params, _, _, _ = initialise(meanf, key).unpack() + params, _, _ = initialise(meanf, key).unpack() mu = meanf(x, params) assert mu.shape[0] == x.shape[0] assert mu.shape[1] == dim @@ -25,5 +25,5 @@ def test_shape(meanf, dim): @pytest.mark.parametrize("meanf", [Zero, Constant]) def test_initialisers(meanf): key = jr.PRNGKey(123) - params, _, _, _ = initialise(meanf(), key).unpack() + params, _, _ = initialise(meanf(), key).unpack() assert isinstance(params, tp.Dict) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index c0ce0c03..2ad0ace5 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -9,8 +9,8 @@ from gpjax.gps import Prior from gpjax.kernels import RBF from gpjax.likelihoods import Bernoulli, Gaussian -from gpjax.parameters import ( # build_all_transforms, - build_transforms, +from gpjax.parameters import ( + build_bijectors, copy_dict_structure, evaluate_priors, initialise, @@ -30,7 +30,7 @@ def test_initialise(lik): key = jr.PRNGKey(123) posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _, _ = initialise(posterior, key).unpack() + params, _, _ = initialise(posterior, key).unpack() assert list(sorted(params.keys())) == [ "kernel", "likelihood", @@ -40,7 +40,7 @@ def test_initialise(lik): def test_non_conjugate_initialise(): posterior = Prior(kernel=RBF()) * Bernoulli(num_datapoints=10) - params, _, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() assert list(sorted(params.keys())) == [ "kernel", "latent", @@ -64,7 +64,7 @@ def test_lpd(x): @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) def test_prior_template(lik): posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() prior_container = copy_dict_structure(params) for ( k, @@ -77,7 +77,7 @@ def test_prior_template(lik): @pytest.mark.parametrize("lik", [Gaussian, Bernoulli]) def test_recursive_complete(lik): posterior = Prior(kernel=RBF()) * lik(num_datapoints=10) - params, _, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() priors = {"kernel": {}} priors["kernel"]["lengthscale"] = tfd.HalfNormal(scale=2.0) container = copy_dict_structure(params) @@ -167,7 +167,7 @@ def test_checks(num_datapoints): def test_structure_priors(): posterior = Prior(kernel=RBF()) * Gaussian(num_datapoints=10) - params, _, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() + params, _, _ = initialise(posterior, jr.PRNGKey(123)).unpack() priors = { "kernel": { "lengthscale": tfd.Gamma(1.0, 1.0), @@ -227,21 +227,18 @@ def test_prior_checks(latent_prior): @pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) def test_output(num_datapoints, likelihood): posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - params, _, constrainer, unconstrainer = initialise( - posterior, jr.PRNGKey(123) - ).unpack() + params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() - assert isinstance(constrainer, dict) - assert isinstance(unconstrainer, dict) - for k, v1, v2 in recursive_items(constrainer, unconstrainer): - assert isinstance(v1, tp.Callable) - assert isinstance(v2, tp.Callable) + assert isinstance(bijectors, dict) + for k, v1, v2 in recursive_items(bijectors, bijectors): + assert isinstance(v1.forward, tp.Callable) + assert isinstance(v2.inverse, tp.Callable) - unconstrained_params = transform(params, unconstrainer) + unconstrained_params = transform(params, bijectors, forward=False) assert ( unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] ) - backconstrained_params = transform(unconstrained_params, constrainer) + backconstrained_params = transform(unconstrained_params, bijectors, forward=True) for k, v1, v2 in recursive_items(params, unconstrained_params): assert v1.dtype == v2.dtype @@ -250,8 +247,8 @@ def test_output(num_datapoints, likelihood): augmented_params = params augmented_params["test_param"] = jnp.array([1.0]) - a_constrainers, a_unconstrainers = build_transforms(augmented_params) - assert "test_param" in list(a_constrainers.keys()) - assert "test_param" in list(a_unconstrainers.keys()) - assert a_constrainers["test_param"](jnp.array([1.0])) == 1.0 - assert a_unconstrainers["test_param"](jnp.array([1.0])) == 1.0 + a_bijectors = build_bijectors(augmented_params) + + assert "test_param" in list(a_bijectors.keys()) + assert a_bijectors["test_param"].forward(jnp.array([1.0])) == 1.0 + assert a_bijectors["test_param"].inverse(jnp.array([1.0])) == 1.0 diff --git a/tests/test_variational_inference.py b/tests/test_variational_inference.py index 64042b99..732e1851 100644 --- a/tests/test_variational_inference.py +++ b/tests/test_variational_inference.py @@ -62,19 +62,16 @@ def test_stochastic_vi( assert svgp.posterior.prior == post.prior assert svgp.posterior.likelihood == post.likelihood - params, _, constrainer, unconstrainer = gpx.initialise( - svgp, jr.PRNGKey(123) - ).unpack() - params = gpx.transform(params, unconstrainer) + params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() assert svgp.prior == post.prior assert svgp.likelihood == post.likelihood assert svgp.num_inducing == n_inducing_points if jit_fns: - elbo_fn = jax.jit(svgp.elbo(D, constrainer)) + elbo_fn = jax.jit(svgp.elbo(D)) else: - elbo_fn = svgp.elbo(D, constrainer) + elbo_fn = svgp.elbo(D) assert isinstance(elbo_fn, tp.Callable) elbo_value = elbo_fn(params, D) assert isinstance(elbo_value, jnp.ndarray) @@ -103,19 +100,16 @@ def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): assert sgpr.posterior.prior == post.prior assert sgpr.posterior.likelihood == post.likelihood - params, _, constrainer, unconstrainer = gpx.initialise( - sgpr, jr.PRNGKey(123) - ).unpack() - params = gpx.transform(params, unconstrainer) + params, _, _ = gpx.initialise(sgpr, jr.PRNGKey(123)).unpack() assert sgpr.prior == post.prior assert sgpr.likelihood == post.likelihood assert sgpr.num_inducing == n_inducing_points if jit_fns: - elbo_fn = jax.jit(sgpr.elbo(D, constrainer)) + elbo_fn = jax.jit(sgpr.elbo(D)) else: - elbo_fn = sgpr.elbo(D, constrainer) + elbo_fn = sgpr.elbo(D) assert isinstance(elbo_fn, tp.Callable) elbo_value = elbo_fn(params) assert isinstance(elbo_value, jnp.ndarray) From 364b34ebdd5ca07aa134dd4ce600a45a85a707a0 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 13:16:54 +0100 Subject: [PATCH 02/66] Constrain, Unconstrain + Tests --- gpjax/__init__.py | 2 +- gpjax/abstractions.py | 26 ++++++++++++-------------- gpjax/parameters.py | 24 +++++++++++++++++++----- tests/test_gp.py | 6 ++---- tests/test_parameters.py | 7 ++++--- 5 files changed, 38 insertions(+), 27 deletions(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 985e7b1e..a897e610 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -19,7 +19,7 @@ ) from .likelihoods import Bernoulli, Gaussian from .mean_functions import Constant, Zero -from .parameters import copy_dict_structure, initialise, transform +from .parameters import constrain, copy_dict_structure, initialise, unconstrain from .types import Dataset from .variational_families import ( CollapsedVariationalGaussian, diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index b68f9eb8..5496962e 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -10,7 +10,7 @@ from jaxtyping import Array, Float from tqdm.auto import tqdm -from .parameters import trainable_params, transform +from .parameters import constrain, trainable_params, unconstrain from .types import Dataset, PRNGKeyType @@ -117,17 +117,18 @@ def fit( Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ - opt_state = optax_optim.init(params) def loss(params): params = trainable_params(params, trainables) - params = transform(params, bijectors, forward=True) + params = constrain(params, bijectors) return objective(params) iter_nums = jnp.arange(n_iters) # Tranform params to unconstrained space: - params = transform(params, bijectors, forward=False) + params = unconstrain(params, bijectors) + + opt_state = optax_optim.init(params) @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num): @@ -141,7 +142,7 @@ def step(carry, iter_num): (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) # Tranform params to constrained space: - params = transform(params, bijectors, forward=True) + params = constrain(params, bijectors) inf_state = InferenceState(params=params, history=history) @@ -176,18 +177,16 @@ def fit_batches( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ - opt_state = optax_optim.init(params) - def loss(params, batch): - params = transform(params, bijectors, forward=True) params = trainable_params(params, trainables) + params = constrain(params, bijectors) return objective(params, batch) - keys = jax.random.split(key, n_iters) - iter_nums = jnp.arange(n_iters) + params = unconstrain(params, bijectors) - # Tranform params to unconstrained space: - params = transform(params, bijectors, forward=False) + opt_state = optax_optim.init(params) + keys = jr.split(key, n_iters) + iter_nums = jnp.arange(n_iters) @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num__and__key): @@ -205,8 +204,7 @@ def step(carry, iter_num__and__key): (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) - # Tranform params to constrained space: - params = transform(params, bijectors, forward=True) + params = constrain(params, bijectors) inf_state = InferenceState(params=params, history=history) return inf_state diff --git a/gpjax/parameters.py b/gpjax/parameters.py index a18343c1..69c3be80 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -139,8 +139,8 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def transform(params: tp.Dict, bijectors: tp.Dict, forward: bool) -> tp.Dict: - """Transform the parameters according to the constraining or unconstraining function dictionary. +def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: + """Transform the parameters to the constrained space for corresponding bijectors. Args: params (tp.Dict): The parameters that are to be transformed. @@ -151,10 +151,24 @@ def transform(params: tp.Dict, bijectors: tp.Dict, forward: bool) -> tp.Dict: tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - fwd = lambda param, trans: trans.forward(param) - inv = lambda param, trans: trans.inverse(param) + map = lambda param, trans: trans.forward(param) - map = fwd if forward else inv + return jax.tree_util.tree_map(map, params, bijectors) + + +def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: + """Transform the parameters to the unconstrained space for corresponding bijectors. + + Args: + params (tp.Dict): The parameters that are to be transformed. + transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + + Returns: + tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + """ + + map = lambda param, trans: trans.inverse(param) return jax.tree_util.tree_map(map, params, bijectors) diff --git a/tests/test_gp.py b/tests/test_gp.py index c393860f..2ea2e780 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -5,7 +5,7 @@ import jax.random as jr import pytest -from gpjax import Dataset, initialise, transform +from gpjax import Dataset, initialise from gpjax.gps import ( AbstractGP, ConjugatePosterior, @@ -60,7 +60,6 @@ def test_conjugate_posterior(num_datapoints): parameter_state = initialise(post, key) params, _, bijectors = parameter_state.unpack() - params = transform(params, bijectors, forward=False) # Marginal likelihood mll = post.marginal_log_likelihood(train_data=D) @@ -101,8 +100,7 @@ def test_nonconjugate_posterior(num_datapoints, likel): assert isinstance(p, AbstractGP) parameter_state = initialise(post, key) - params, _, bijectors = parameter_state.unpack() - params = transform(params, bijectors, forward=False) + params, _, _ = parameter_state.unpack() assert isinstance(parameter_state, ParameterState) # Marginal likelihood diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 2ad0ace5..18fd2e3e 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -11,6 +11,7 @@ from gpjax.likelihoods import Bernoulli, Gaussian from gpjax.parameters import ( build_bijectors, + constrain, copy_dict_structure, evaluate_priors, initialise, @@ -19,7 +20,7 @@ recursive_complete, recursive_items, structure_priors, - transform, + unconstrain, ) @@ -234,11 +235,11 @@ def test_output(num_datapoints, likelihood): assert isinstance(v1.forward, tp.Callable) assert isinstance(v2.inverse, tp.Callable) - unconstrained_params = transform(params, bijectors, forward=False) + unconstrained_params = unconstrain(params, bijectors) assert ( unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] ) - backconstrained_params = transform(unconstrained_params, bijectors, forward=True) + backconstrained_params = constrain(unconstrained_params, bijectors) for k, v1, v2 in recursive_items(params, unconstrained_params): assert v1.dtype == v2.dtype From d1f4d0c1fd43994cd76ac50c3d499c2ef536c95a Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 15:22:05 +0100 Subject: [PATCH 03/66] Parameter state (See comment) All notebooks are updated, except the tensorflow probability and MCMC section of the classification notebook. --- examples/barycentres.ipynb | 29 ++++++++-------- examples/collapsed_vi.ipynb | 40 +++++++++++----------- examples/graph_kernels.ipynb | 41 +++++++++++------------ examples/haiku.ipynb | 37 ++++++++++---------- examples/kernels.ipynb | 44 ++++++++++++------------ examples/regression.ipynb | 63 ++++++++++------------------------- examples/uncollapsed_vi.ipynb | 17 ++++------ examples/yacht.ipynb | 26 +++++++++------ gpjax/abstractions.py | 21 +++++------- gpjax/parameters.py | 2 +- tests/test_abstractions.py | 23 +++++++------ 11 files changed, 158 insertions(+), 185 deletions(-) diff --git a/examples/barycentres.ipynb b/examples/barycentres.ipynb index 94a719f7..ec210616 100644 --- a/examples/barycentres.ipynb +++ b/examples/barycentres.ipynb @@ -118,24 +118,23 @@ " if y.ndim == 1:\n", " y = y.reshape(-1, 1)\n", " D = gpx.Dataset(X=x, y=y)\n", + "\n", " likelihood = gpx.Gaussian(num_datapoints=n)\n", " posterior = gpx.Prior(kernel=gpx.RBF()) * likelihood\n", - " params, trainables, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()\n", - " params = gpx.transform(params, unconstrainers)\n", "\n", - " objective = jax.jit(posterior.marginal_log_likelihood(D, constrainers, negative=True))\n", + " parameter_state = gpx.initialise(posterior, key)\n", + " negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))\n", + " optimiser = ox.adam(learning_rate=0.01)\n", "\n", - " opt = ox.adam(learning_rate=0.01)\n", - " learned_params, training_history = gpx.fit(\n", - " objective=objective,\n", - " trainables=trainables,\n", - " params=params,\n", - " optax_optim=opt,\n", + " inference_state = gpx.fit(\n", + " objective=negative_mll,\n", + " parameter_state=parameter_state,\n", + " optax_optim=optimiser,\n", " n_iters=1000,\n", - " ).unpack()\n", - " learned_params = gpx.transform(learned_params, constrainers)\n", - " return likelihood(posterior(D, learned_params)(xtest), learned_params)\n", + " )\n", "\n", + " learned_params, training_history = inference_state.unpack()\n", + " return likelihood(posterior(D, learned_params)(xtest), learned_params)\n", "\n", "posterior_preds = [fit_gp(x, i) for i in ys]" ] @@ -279,7 +278,7 @@ "encoding": "# -*- coding: utf-8 -*-" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -293,11 +292,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/examples/collapsed_vi.ipynb b/examples/collapsed_vi.ipynb index 4b7d51e8..eb0de2fc 100644 --- a/examples/collapsed_vi.ipynb +++ b/examples/collapsed_vi.ipynb @@ -186,22 +186,20 @@ "metadata": {}, "outputs": [], "source": [ - "params, trainables, constrainers, unconstrainers = gpx.initialise(sgpr, key).unpack()\n", + "parameter_state = gpx.initialise(sgpr, key)\n", "\n", - "loss_fn = jit(sgpr.elbo(D, constrainers, negative=True))\n", + "negative_elbo = jit(sgpr.elbo(D, negative=True))\n", "\n", "optimiser = ox.adam(learning_rate=0.005)\n", "\n", - "params = gpx.transform(params, unconstrainers)\n", - "\n", - "learned_params, training_history = gpx.fit(\n", - " objective = loss_fn,\n", - " params = params,\n", - " trainables = trainables,\n", - " optax_optim = optimiser,\n", + "inference_state = gpx.fit(\n", + " objective=negative_elbo,\n", + " parameter_state=parameter_state,\n", + " optax_optim=optimiser,\n", " n_iters=2000,\n", - ").unpack()\n", - "learned_params = gpx.transform(learned_params, constrainers)" + ")\n", + "\n", + "learned_params, training_history = inference_state.unpack()" ] }, { @@ -268,10 +266,10 @@ "outputs": [], "source": [ "full_rank_model = gpx.Prior(kernel = gpx.RBF()) * gpx.Gaussian(num_datapoints=D.n)\n", - "fr_params, fr_trainables, fr_constrainers, fr_unconstrainers = gpx.initialise(full_rank_model, key).unpack()\n", - "fr_params = gpx.transform(fr_params, fr_unconstrainers)\n", - "mll = jit(full_rank_model.marginal_log_likelihood(D, fr_constrainers, negative=True))\n", - "%timeit mll(fr_params).block_until_ready()" + "fr_params, *_ = gpx.initialise(full_rank_model, key).unpack()\n", + "negative_mll = jit(full_rank_model.marginal_log_likelihood(D, negative=True))\n", + "\n", + "%timeit negative_mll(fr_params).block_until_ready()" ] }, { @@ -281,8 +279,10 @@ "metadata": {}, "outputs": [], "source": [ - "sparse_elbo = jit(sgpr.elbo(D, constrainers, negative=True))\n", - "%timeit sparse_elbo(params).block_until_ready()" + "params, *_ = gpx.initialise(sgpr, key).unpack()\n", + "negative_elbo = jit(sgpr.elbo(D, negative=True))\n", + "\n", + "%timeit negative_elbo(params).block_until_ready()" ] }, { @@ -318,7 +318,7 @@ "custom_cell_magics": "kql" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -332,11 +332,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/examples/graph_kernels.ipynb b/examples/graph_kernels.ipynb index d17ed493..af85f396 100644 --- a/examples/graph_kernels.ipynb +++ b/examples/graph_kernels.ipynb @@ -111,16 +111,16 @@ "x = jnp.arange(G.number_of_nodes()).reshape(-1, 1)\n", "\n", "kernel = gpx.GraphKernel(laplacian=L)\n", - "f = gpx.Prior(kernel=kernel)\n", + "prior = gpx.Prior(kernel=kernel)\n", "\n", - "true_params = f._initialise_params(key)\n", + "true_params = prior._initialise_params(key)\n", "true_params[\"kernel\"] = {\n", " \"lengthscale\": jnp.array(2.3),\n", " \"variance\": jnp.array(3.2),\n", " \"smoothness\": jnp.array(6.1),\n", "}\n", "\n", - "fx = f(true_params)(x)\n", + "fx = prior(true_params)(x)\n", "y = fx.sample(seed=key).reshape(-1, 1)\n", "\n", "D = gpx.Dataset(X=x, y=y)" @@ -173,23 +173,21 @@ "outputs": [], "source": [ "likelihood = gpx.Gaussian(num_datapoints=y.shape[0])\n", - "posterior = f * likelihood\n", - "params, trainable, constrainer, unconstrainer = gpx.initialise(posterior, key).unpack()\n", - "params = gpx.transform(params, unconstrainer)\n", + "posterior = prior * likelihood\n", "\n", - "mll = jit(\n", - " posterior.marginal_log_likelihood(train_data=D, transformations=constrainer, negative=True)\n", - ")\n", "\n", - "opt = ox.adam(learning_rate=0.01)\n", - "learned_params, training_history = gpx.fit(\n", - " objective=mll,\n", - " params=params,\n", - " trainables=trainable,\n", - " optax_optim=opt,\n", + "parameter_state = gpx.initialise(posterior, key)\n", + "negative_mll = jit(posterior.marginal_log_likelihood(train_data=D, negative=True))\n", + "optimiser = ox.adam(learning_rate=0.01)\n", + "\n", + "inference_state = gpx.fit(\n", + " objective=negative_mll,\n", + " parameter_state=parameter_state,\n", + " optax_optim=optimiser,\n", " n_iters=1000,\n", - ").unpack()\n", - "learned_params = gpx.transform(learned_params, constrainer)" + ")\n", + "\n", + "learned_params, training_history = inference_state.unpack()" ] }, { @@ -213,7 +211,8 @@ "metadata": {}, "outputs": [], "source": [ - "initial_dist = likelihood(posterior(D, params)(x), params)\n", + "initial_params = parameter_state.params\n", + "initial_dist = likelihood(posterior(D, initial_params)(x), initial_params)\n", "predictive_dist = likelihood(posterior(D, learned_params)(x), learned_params)\n", "\n", "initial_mean = initial_dist.mean()\n", @@ -294,7 +293,7 @@ "encoding": "# -*- coding: utf-8 -*-" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -308,11 +307,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/examples/haiku.ipynb b/examples/haiku.ipynb index d02b3b89..1c9abba2 100644 --- a/examples/haiku.ipynb +++ b/examples/haiku.ipynb @@ -173,10 +173,7 @@ "kernel.initialise(x, key)\n", "prior = gpx.Prior(kernel=kernel)\n", "likelihood = gpx.Gaussian(num_datapoints=D.n)\n", - "posterior = prior * likelihood\n", - "\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()\n", - "params = gpx.transform(params, unconstrainers)" + "posterior = prior * likelihood" ] }, { @@ -200,8 +197,10 @@ "metadata": {}, "outputs": [], "source": [ - "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainers, negative=True))\n", - "mll(params)\n", + "parameter_state = gpx.initialise(posterior, key)\n", + "\n", + "negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))\n", + "negative_mll(parameter_state.params)\n", "\n", "schedule = ox.warmup_cosine_decay_schedule(\n", " init_value=0.0,\n", @@ -211,19 +210,19 @@ " end_value=0.0,\n", ")\n", "\n", - "opt = ox.chain(\n", + "optimiser = ox.chain(\n", " ox.clip(1.0),\n", " ox.adamw(learning_rate=schedule),\n", ")\n", "\n", - "final_params, training_history = gpx.fit(\n", - " mll,\n", - " params,\n", - " trainables,\n", - " opt,\n", + "inference_state = gpx.fit(\n", + " objective=negative_mll,\n", + " parameter_state=parameter_state,\n", + " optax_optim=optimiser,\n", " n_iters=5000,\n", - ").unpack()\n", - "final_params = gpx.transform(final_params, constrainers)" + ")\n", + "\n", + "learned_params, training_history = inference_state.unpack()" ] }, { @@ -243,8 +242,8 @@ "metadata": {}, "outputs": [], "source": [ - "latent_dist = posterior(D, final_params)(xtest)\n", - "predictive_dist = likelihood(latent_dist, final_params)\n", + "latent_dist = posterior(D, learned_params)(xtest)\n", + "predictive_dist = likelihood(latent_dist, learned_params)\n", "\n", "predictive_mean = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", @@ -290,7 +289,7 @@ "custom_cell_magics": "kql" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -304,11 +303,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/examples/kernels.ipynb b/examples/kernels.ipynb index b1186dde..b34a40f4 100644 --- a/examples/kernels.ipynb +++ b/examples/kernels.ipynb @@ -71,7 +71,7 @@ "\n", "for k, ax in zip(kernels, axes.ravel()):\n", " prior = gpx.Prior(kernel=k)\n", - " params, _, _, _ = gpx.initialise(prior, key).unpack()\n", + " params, *_ = gpx.initialise(prior, key).unpack()\n", " rv = prior(params)(x)\n", " y = rv.sample(sample_shape=10, seed=key)\n", "\n", @@ -303,14 +303,14 @@ " super().__init__()\n", " self.period = period\n", "```\n", - "As objects become increasingly large and complex, the conciseness of a dataclass becomes increasingly attractive. To ensure full compatability with Jax, it is crucial that the dataclass decorator is imported from Chex, not base Python's `dataclass` module. Functionally, the two objects are identical. However, unlike regular Python dataclasses, it is possilbe to apply operations such as`jit`, `vmap` and `grad` to the dataclasses given by Chex as they are registrered PyTrees. \n", + "As objects become increasingly large and complex, the conciseness of a dataclass becomes increasingly attractive. To ensure full compatability with Jax, it is crucial that the dataclass decorator is imported from Chex, not base Python's `dataclass` module. Functionally, the two objects are identical. However, unlike regular Python dataclasses, it is possilbe to apply operations such as `jit`, `vmap` and `grad` to the dataclasses given by Chex as they are registrered PyTrees. \n", "\n", "\n", "### Custom Parameter Bijection\n", "\n", "The constraint on $\\tau$ makes optimisation challenging with gradient descent. It would be much easier if we could instead parameterise $\\tau$ to be on the real line. Fortunately, this can be taken care of with GPJax's `add parameter` function, only requiring us to define the parameter's name and matching bijection (either a Distrax of TensorFlow probability bijector). Under the hood, calling this function updates a configuration object to register this parameter and its corresponding transform.\n", "\n", - "To define a bijector here we'll make use of the `Lambda` operator given in Distrax. This lets us convert any regular Jax function into a bijection. Given that we require $\\tau$ to be strictly greater than $4.$, we'll apply a [softplus transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) where the lower bound is shifted by $4.$." + "To define a bijector here we'll make use of the `Lambda` operator given in Distrax. This lets us convert any regular Jax function into a bijection. Given that we require $\\tau$ to be strictly greater than $4$, we'll apply a [softplus transformation](https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.softplus.html) where the lower bound is shifted by $4$." ] }, { @@ -322,10 +322,12 @@ "source": [ "from gpjax.config import add_parameter\n", "\n", - "bij_fn = lambda x: jax.nn.softplus(x+jnp.array(4.))\n", - "bij = dx.Lambda(bij_fn)\n", + "tau_bijector = dx.Lambda(\n", + " forward=lambda x: jnp.log(1 + jnp.exp(x + 4.0)),\n", + " inverse=lambda x: jnp.log(jnp.exp(x - 4.0) - 1.0),\n", + ")\n", "\n", - "add_parameter(\"tau\", bij)" + "add_parameter(\"tau\", tau_bijector)" ] }, { @@ -362,21 +364,21 @@ "likelihood = gpx.Gaussian(num_datapoints=n)\n", "circlular_posterior = gpx.Prior(kernel=PKern) * likelihood\n", "\n", - "# Initialise parameters and corresponding transformations\n", - "params, trainable, constrainer, unconstrainer = gpx.initialise(circlular_posterior, key).unpack()\n", + "# Initialise parameter state:\n", + "parameter_state = gpx.initialise(circlular_posterior, key)\n", "\n", "# Optimise GP's marginal log-likelihood using Adam\n", - "mll = jit(circlular_posterior.marginal_log_likelihood(D, constrainer, negative=True))\n", - "learned_params, training_history = gpx.fit(\n", - " mll,\n", - " params,\n", - " trainable,\n", - " adam(learning_rate=0.05),\n", + "negative_mll = jit(circlular_posterior.marginal_log_likelihood(D, negative=True))\n", + "optimiser = adam(learning_rate=0.05)\n", + "\n", + "inference_state = gpx.fit(\n", + " objective=negative_mll,\n", + " parameter_state=parameter_state,\n", + " optax_optim=optimiser,\n", " n_iters=1000,\n", - ").unpack()\n", + ")\n", "\n", - "# Untransform learned parameters\n", - "final_params = gpx.transform(learned_params, constrainer)" + "learned_params, training_history = inference_state.unpack()" ] }, { @@ -396,7 +398,7 @@ "metadata": {}, "outputs": [], "source": [ - "posterior_rv = likelihood(circlular_posterior(D, final_params)(angles), final_params)\n", + "posterior_rv = likelihood(circlular_posterior(D, learned_params)(angles), learned_params)\n", "mu = posterior_rv.mean()\n", "one_sigma = posterior_rv.stddev()" ] @@ -459,7 +461,7 @@ "encoding": "# -*- coding: utf-8 -*-" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -473,11 +475,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/examples/regression.ipynb b/examples/regression.ipynb index 2e9e8842..3a2e85c9 100644 --- a/examples/regression.ipynb +++ b/examples/regression.ipynb @@ -236,8 +236,7 @@ "|---|---|\n", "| `params` | Initial parameter values. |\n", "| `trainable` | Boolean dictionary that determines the training status of parameters (`True` for being trained and `False` otherwise). |\n", - "| `constrainer` | Transformations that map parameters from the _unconstrained space_ back to their original _constrained space_. |\n", - "| `unconstrainer` | Transformations that map parameters from their original _constrained space_ to an _unconstrained space_ for optimisation. |\n", + "| `bijectors` | Bijectors that can map parameters between the _unconstrained space_ and their original _constrained space_. |\n", "\n", "Further, upon calling `initialise`, we can state specific initial values for some, or all, of the parameters within our model. By default, the kernel lengthscale and variance and the likelihood's variance parameter are all initialised to 1. However, in the following cell, we'll demonstrate how the kernel lengthscale can be initialised to 0.5." ] @@ -272,7 +271,7 @@ "metadata": {}, "outputs": [], "source": [ - "params, trainable, constrainer, unconstrainer = parameter_state.unpack()\n", + "params, trainable, bijectors = parameter_state.unpack()\n", "pp.pprint(params)" ] }, @@ -281,17 +280,7 @@ "id": "28475bd7", "metadata": {}, "source": [ - "To motivate the purpose of `constrainer` and `unconstrainer` more precisely, notice that our model hyperparameters $\\{\\ell^2, \\sigma^2, \\alpha^2 \\}$ are all strictly positive. To ensure more stable optimisation, it is strongly advised to transform the parameters onto an unconstrained space first via `transform`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e8748952", - "metadata": {}, - "outputs": [], - "source": [ - "params = gpx.transform(params, unconstrainer)" + "To motivate the purpose the `bijectors` more precisely, notice that our model hyperparameters $\\{\\ell^2, \\sigma^2, \\alpha^2 \\}$ are all strictly positive, bijectors act to unconstrain these during the optimisation proceedure." ] }, { @@ -311,8 +300,8 @@ }, "outputs": [], "source": [ - "mll = jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))\n", - "mll(params)" + "negative_mll = jit(posterior.marginal_log_likelihood(D, negative=True))\n", + "negative_mll(params)" ] }, { @@ -338,12 +327,12 @@ "metadata": {}, "outputs": [], "source": [ - "opt = ox.adam(learning_rate=0.01)\n", + "optimiser = ox.adam(learning_rate=0.01)\n", + "\n", "inference_state = gpx.fit(\n", - " mll,\n", - " params,\n", - " trainable,\n", - " opt,\n", + " objective = negative_mll,\n", + " parameter_state = parameter_state,\n", + " optax_optim= optimiser,\n", " n_iters=500,\n", ")" ] @@ -363,27 +352,9 @@ "metadata": {}, "outputs": [], "source": [ - "final_params, training_history = inference_state.unpack()" - ] - }, - { - "cell_type": "markdown", - "id": "faef08b8", - "metadata": {}, - "source": [ + "learned_params, training_history = inference_state.unpack()\n", "\n", - "The exact value of our learned parameters is often useful in answering certain questions about the underlying process. To obtain these values, we untransfom our trained unconstrained parameters back to their original constrained space with `transform` and `constrainer`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "67015714", - "metadata": {}, - "outputs": [], - "source": [ - "final_params = gpx.transform(final_params, constrainer)\n", - "pp.pprint(final_params)" + "pp.pprint(learned_params)" ] }, { @@ -403,8 +374,8 @@ "metadata": {}, "outputs": [], "source": [ - "latent_dist = posterior(D, final_params)(xtest)\n", - "predictive_dist = likelihood(latent_dist, final_params)\n", + "latent_dist = posterior(D, learned_params)(xtest)\n", + "predictive_dist = likelihood(latent_dist, learned_params)\n", "\n", "predictive_mean = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()" @@ -470,7 +441,7 @@ "custom_cell_magics": "kql" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -484,11 +455,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/examples/uncollapsed_vi.ipynb b/examples/uncollapsed_vi.ipynb index 39b8c036..17b282ff 100644 --- a/examples/uncollapsed_vi.ipynb +++ b/examples/uncollapsed_vi.ipynb @@ -221,10 +221,7 @@ "metadata": {}, "outputs": [], "source": [ - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp, key).unpack()\n", - "params = gpx.transform(params, unconstrainers)\n", - "\n", - "loss_fn = jit(svgp.elbo(D, constrainers, negative=True))" + "negative_elbo = jit(svgp.elbo(D, negative=True))" ] }, { @@ -246,20 +243,20 @@ }, "outputs": [], "source": [ + "parameter_state = gpx.initialise(svgp, key)\n", "optimiser = ox.adam(learning_rate=0.01)\n", "\n", "inference_state = gpx.fit_batches(\n", - " objective = loss_fn,\n", - " params = params,\n", - " trainables = trainables,\n", + " objective = negative_elbo,\n", + " parameter_state= parameter_state,\n", " train_data = D, \n", " optax_optim = optimiser,\n", " n_iters=4000,\n", " key = jr.PRNGKey(42),\n", - " batch_size= 128\n", + " batch_size= 128,\n", ")\n", - "learned_params, training_history = inference_state.unpack()\n", - "learned_params = gpx.transform(learned_params, constrainers)" + "\n", + "learned_params, training_history = inference_state.unpack()" ] }, { diff --git a/examples/yacht.ipynb b/examples/yacht.ipynb index e4883e19..386c8d6c 100644 --- a/examples/yacht.ipynb +++ b/examples/yacht.ipynb @@ -166,10 +166,7 @@ "\n", "likelihood = gpx.Gaussian(num_datapoints=n_train)\n", "\n", - "posterior = prior * likelihood\n", - "\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()\n", - "params = gpx.transform(params, unconstrainers)" + "posterior = prior * likelihood" ] }, { @@ -189,9 +186,18 @@ "source": [ "training_data = gpx.Dataset(X = scaled_Xtr, y=scaled_ytr)\n", "\n", - "mll = jit(posterior.marginal_log_likelihood(train_data = training_data, transformations=constrainers, negative=True))\n", - "learned_params, training_history = gpx.fit(objective=mll, params=params, trainables=trainables, optax_optim=ox.adam(0.05), n_iters=1000, log_rate=50).unpack()\n", - "learned_params = gpx.transform(learned_params, constrainers)" + "parameter_state = gpx.initialise(posterior, key)\n", + "negative_mll = jit(posterior.marginal_log_likelihood(train_data = training_data, negative=True))\n", + "optimiser = ox.adam(0.05)\n", + "\n", + "inference_state = gpx.fit(objective=negative_mll, \n", + " parameter_state=parameter_state, \n", + " optax_optim=optimiser, \n", + " n_iters=1000, \n", + " log_rate=50,\n", + " )\n", + " \n", + "learned_params, training_history = inference_state.unpack()" ] }, { @@ -303,7 +309,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -317,11 +323,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 5496962e..3d86c238 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -10,7 +10,7 @@ from jaxtyping import Array, Float from tqdm.auto import tqdm -from .parameters import constrain, trainable_params, unconstrain +from .parameters import ParameterState, constrain, trainable_params, unconstrain from .types import Dataset, PRNGKeyType @@ -97,9 +97,7 @@ def wrapper_progress_bar(carry, x): def fit( objective: tp.Callable, - params: tp.Dict, - trainables: tp.Dict, - bijectors: tp.Dict, + parameter_state: ParameterState, optax_optim, n_iters: int = 100, log_rate: int = 10, @@ -108,9 +106,7 @@ def fit( Optimisers used here should originate from Optax. Args: objective (tp.Callable): The objective function that we are optimising with respect to. - params (dict): The parameters for which we would like to minimise our objective function with. - trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. - bijectors (dict): Dictionary of bijectors for each parameter. + parameter_state (ParameterState): The initial parameter state. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. @@ -118,6 +114,8 @@ def fit( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ + params, trainables, bijectors = parameter_state.unpack() + def loss(params): params = trainable_params(params, trainables) params = constrain(params, bijectors) @@ -151,9 +149,7 @@ def step(carry, iter_num): def fit_batches( objective: tp.Callable, - params: tp.Dict, - trainables: tp.Dict, - bijectors: tp.Dict, + parameter_state: ParameterState, train_data: Dataset, optax_optim, key: PRNGKeyType, @@ -165,8 +161,7 @@ def fit_batches( Optimisers used here should originate from Optax. Args: objective (tp.Callable): The objective function that we are optimising with respect to. - params (dict): The parameters for which we would like to minimise our objective function with. - trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. + parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with. train_data (Dataset): The training dataset. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. key (PRNGKeyType): The PRNG key for the mini-batch sampling. @@ -177,6 +172,8 @@ def fit_batches( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ + params, trainables, bijectors = parameter_state.unpack() + def loss(params, batch): params = trainable_params(params, trainables) params = constrain(params, bijectors) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 69c3be80..62cc40ec 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -22,7 +22,7 @@ ################################ @dataclass class ParameterState: - """The state of the model. This includes the parameter set and the functions that allow parameters to be constrained and unconstrained.""" + """The state of the model. This includes the parameter set, which parameters are to be trained and bijectors that allow parameters to be constrained and unconstrained.""" params: tp.Dict trainables: tp.Dict diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index b4d13596..ad1a6394 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -6,7 +6,7 @@ import gpjax as gpx from gpjax import RBF, Dataset, Gaussian, Prior, initialise from gpjax.abstractions import InferenceState, fit, fit_batches, get_batch -from gpjax.parameters import build_bijectors +from gpjax.parameters import ParameterState, build_bijectors @pytest.mark.parametrize("n_iters", [10]) @@ -17,12 +17,12 @@ def test_fit(n_iters, n): y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 D = Dataset(X=x, y=y) p = Prior(kernel=RBF()) * Gaussian(num_datapoints=n) - params, trainables, bijectors = initialise(p, key).unpack() + parameter_state = initialise(p, key) mll = p.marginal_log_likelihood(D, negative=True) - pre_mll_val = mll(params) + pre_mll_val = mll(parameter_state.params) optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(mll, params, trainables, bijectors, optimiser, n_iters) - optimised_params, history = inference_state.params, inference_state.history + inference_state = fit(mll, parameter_state, optimiser, n_iters) + optimised_params, history = inference_state.unpack() assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert mll(optimised_params) < pre_mll_val @@ -36,7 +36,10 @@ def test_stop_grads(): bijectors = build_bijectors(params) loss_fn = lambda params: params["x"] ** 2 + params["y"] ** 2 optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(loss_fn, params, trainables, bijectors, optimiser, n_iters=1) + parameter_state = ParameterState( + params=params, trainables=trainables, bijectors=bijectors + ) + inference_state = fit(loss_fn, parameter_state, optimiser, n_iters=1) learned_params = inference_state.params assert isinstance(inference_state, InferenceState) assert learned_params["y"] == params["y"] @@ -59,19 +62,19 @@ def test_batch_fitting(n_iters, nb, ndata): q = gpx.VariationalGaussian(prior=prior, inducing_inputs=z) svgp = gpx.StochasticVI(posterior=p, variational_family=q) - params, trainables, bijectors = initialise(svgp, key).unpack() + parameter_state = initialise(svgp, key) objective = svgp.elbo(D) - pre_mll_val = objective(params, D) + pre_mll_val = objective(parameter_state.params, D) D = Dataset(X=x, y=y) optimiser = optax.adam(learning_rate=0.1) key = jr.PRNGKey(42) inference_state = fit_batches( - objective, params, trainables, bijectors, D, optimiser, key, nb, n_iters + objective, parameter_state, D, optimiser, key, nb, n_iters ) - optimised_params, history = inference_state.params, inference_state.history + optimised_params, history = inference_state.unpack() assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert objective(optimised_params, D) < pre_mll_val From 7eca9e7caeabfd12866973ae12692d1ef35286fe Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 15:58:46 +0100 Subject: [PATCH 04/66] Update nbs. --- examples/tfp_integration.ipynb | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/examples/tfp_integration.ipynb b/examples/tfp_integration.ipynb index 09af5f7a..5daab7e5 100644 --- a/examples/tfp_integration.ipynb +++ b/examples/tfp_integration.ipynb @@ -104,7 +104,7 @@ "metadata": {}, "outputs": [], "source": [ - "params, _, constrainers, unconstrainers = gpx.initialise(posterior, key).unpack()" + "params, _, bijectors = gpx.initialise(posterior, key).unpack()" ] }, { @@ -201,9 +201,7 @@ "metadata": {}, "outputs": [], "source": [ - "mll = posterior.marginal_log_likelihood(\n", - " D, constrainers, priors=priors, negative=False\n", - ")\n", + "mll = posterior.marginal_log_likelihood(D, priors=priors, negative=False)\n", "mll(params)" ] }, @@ -225,6 +223,7 @@ "def build_log_pi(target, mapper_fn):\n", " def array_mll(parameter_array):\n", " parameter_dict = mapper_fn([jnp.array(i) for i in parameter_array])\n", + " gpx.constrain(parameter_dict, bijectors)\n", " return target(parameter_dict)\n", "\n", " return array_mll\n", @@ -252,7 +251,6 @@ "source": [ "n_samples = 500\n", "\n", - "\n", "def run_chain(key, state):\n", " kernel = tfp.mcmc.NoUTurnSampler(mll_array_form, 1e-1)\n", " return tfp.mcmc.sample_chain(\n", @@ -279,7 +277,8 @@ "metadata": {}, "outputs": [], "source": [ - "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params)))" + "unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))" ] }, { @@ -305,7 +304,7 @@ "\n", "samples = [states[burn_in:, i, :][::thin_factor] for i in range(n_params)]\n", "sample_dict = array_to_dict(samples)\n", - "constrained_samples = gpx.transform(sample_dict, constrainers)\n", + "constrained_samples = gpx.constrain(sample_dict, bijectors)\n", "constrained_sample_list = dict_to_array(constrained_samples)" ] }, @@ -400,7 +399,7 @@ "id": "bca69c91", "metadata": {}, "source": [ - "Since things look good, this concludes our tutorial on interfacing TensorFlow Probability with GPJax. \n", + "This concludes our tutorial on interfacing TensorFlow Probability with GPJax. \n", "The workflow demonstrated here only scratches the surface regarding the inference possible with a large number of samplers available in TensorFlow probability." ] }, @@ -429,7 +428,7 @@ "custom_cell_magics": "kql" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -443,11 +442,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, From e1d43ee216a651c85fbb4f9ca4a7cd9e114a730d Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 1 Sep 2022 08:17:48 +0000 Subject: [PATCH 05/66] Add functionality to transform ParameterState --- gpjax/parameters.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 62cc40ec..5aa2b31f 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -173,6 +173,26 @@ def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: return jax.tree_util.tree_map(map, params, bijectors) +def transform_state( + state: ParameterState, projection_fn: tp.Callable[[tp.Dict, tp.Dict], tp.Dict] +) -> ParameterState: + """Transfrom the parameters of a `ParameterState` object using the corresponding set of bijectors. + The projection function accepts the parameters and corresponding bijectors as a dictionary input and + returns a dictionary of transformed parameters. + + Args: + state (ParameterState): The ParameterState object to be transformed. + projection_fn (tp.Callable[[tp.Dict, tp.Dict], tp.Dict]): The projection function that transforms the parameters. + + Returns: + ParameterState: A transformed ParameterState object. + """ + params = projection_fn(state.params, state.bijectors) + return ParameterState( + params=params, trainables=state.trainables, bijectors=state.bijectors + ) + + ################################ # Priors ################################ From da52be4f39388ee2485b649d8bb544efba398547 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 1 Sep 2022 08:29:27 +0000 Subject: [PATCH 06/66] Undo change --- gpjax/parameters.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 5aa2b31f..62cc40ec 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -173,26 +173,6 @@ def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: return jax.tree_util.tree_map(map, params, bijectors) -def transform_state( - state: ParameterState, projection_fn: tp.Callable[[tp.Dict, tp.Dict], tp.Dict] -) -> ParameterState: - """Transfrom the parameters of a `ParameterState` object using the corresponding set of bijectors. - The projection function accepts the parameters and corresponding bijectors as a dictionary input and - returns a dictionary of transformed parameters. - - Args: - state (ParameterState): The ParameterState object to be transformed. - projection_fn (tp.Callable[[tp.Dict, tp.Dict], tp.Dict]): The projection function that transforms the parameters. - - Returns: - ParameterState: A transformed ParameterState object. - """ - params = projection_fn(state.params, state.bijectors) - return ParameterState( - params=params, trainables=state.trainables, bijectors=state.bijectors - ) - - ################################ # Priors ################################ From 979546e7865d042bc72ead6d312cbfbaa3bfdf08 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 1 Sep 2022 09:55:24 +0000 Subject: [PATCH 07/66] WIP for constrainers on state --- gpjax/abstractions.py | 33 ++++++++++++++-------------- gpjax/parameters.py | 47 ++++++++++++++++++++++++---------------- tests/test_parameters.py | 9 +++++--- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 3d86c238..7903a0e9 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -114,19 +114,18 @@ def fit( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ - params, trainables, bijectors = parameter_state.unpack() - def loss(params): - params = trainable_params(params, trainables) - params = constrain(params, bijectors) + parameter_state = trainable_params(parameter_state) + parameter_state = constrain(parameter_state) + params = parameter_state.params return objective(params) iter_nums = jnp.arange(n_iters) # Tranform params to unconstrained space: - params = unconstrain(params, bijectors) + parameter_state = unconstrain(parameter_state) - opt_state = optax_optim.init(params) + opt_state = optax_optim.init(parameter_state.params) @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num): @@ -140,9 +139,9 @@ def step(carry, iter_num): (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) # Tranform params to constrained space: - params = constrain(params, bijectors) - - inf_state = InferenceState(params=params, history=history) + parameter_state.params = params + params = constrain(parameter_state) + inf_state = InferenceState(params=parameter_state.params, history=history) return inf_state @@ -172,15 +171,14 @@ def fit_batches( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ - params, trainables, bijectors = parameter_state.unpack() - def loss(params, batch): - params = trainable_params(params, trainables) - params = constrain(params, bijectors) + parameter_state = trainable_params(parameter_state) + parameter_state = constrain(parameter_state) + params = parameter_state.params return objective(params, batch) - params = unconstrain(params, bijectors) - + parameter_state = unconstrain(parameter_state) + params = parameter_state.params opt_state = optax_optim.init(params) keys = jr.split(key, n_iters) iter_nums = jnp.arange(n_iters) @@ -201,8 +199,9 @@ def step(carry, iter_num__and__key): (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) - params = constrain(params, bijectors) - inf_state = InferenceState(params=params, history=history) + parameter_state.params = params + parameter_state = constrain(parameter_state) + inf_state = InferenceState(params=parameter_state.params, history=history) return inf_state diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 62cc40ec..33b30468 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -139,38 +139,43 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: - """Transform the parameters to the constrained space for corresponding bijectors. +def constrain(state: ParameterState) -> ParameterState: + """Transform the parameters to a constrained space using the corresponding set of bijectors. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. - foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. Returns: - tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been constrained. """ + params, bijectors = state.params, state.bijectors map = lambda param, trans: trans.forward(param) - - return jax.tree_util.tree_map(map, params, bijectors) + transformed_params = jax.tree_util.tree_map(map, params, bijectors) + return ParameterState( + params=transformed_params, + trainables=state.trainables, + bijectors=bijectors, + ) -def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: - """Transform the parameters to the unconstrained space for corresponding bijectors. +def unconstrain(state: ParameterState) -> ParameterState: + """Transform the parameters to a unconstrained space using the corresponding set of bijectors. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. - foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. Returns: - tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been unconstrained. """ - + params, bijectors = state.params, state.bijectors map = lambda param, trans: trans.inverse(param) - - return jax.tree_util.tree_map(map, params, bijectors) + transformed_params = jax.tree_util.tree_map(map, params, bijectors) + return ParameterState( + params=transformed_params, + trainables=state.trainables, + bijectors=bijectors, + ) ################################ @@ -270,8 +275,12 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict): return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) -def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: +def trainable_params(state: ParameterState) -> ParameterState: """Stop the gradients flowing through parameters whose trainable status is False""" - return jax.tree_util.tree_map( + params, trainables = state.params, state.trainables + trainable_params = jax.tree_util.tree_map( lambda param, trainable: stop_grad(param, trainable), params, trainables ) + return ParameterState( + params=trainable_params, trainables=trainables, bijectors=state.bijectors + ) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 18fd2e3e..b1c5c84e 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -228,18 +228,21 @@ def test_prior_checks(latent_prior): @pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) def test_output(num_datapoints, likelihood): posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() + state = initialise(posterior, jr.PRNGKey(123)) + params, _, bijectors = state.unpack() assert isinstance(bijectors, dict) for k, v1, v2 in recursive_items(bijectors, bijectors): assert isinstance(v1.forward, tp.Callable) assert isinstance(v2.inverse, tp.Callable) - unconstrained_params = unconstrain(params, bijectors) + unconstrained_state = unconstrain(state) + unconstrained_params = unconstrained_state.params assert ( unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] ) - backconstrained_params = constrain(unconstrained_params, bijectors) + backconstrained_state = constrain(unconstrained_state) + backconstrained_params = backconstrained_state.params for k, v1, v2 in recursive_items(params, unconstrained_params): assert v1.dtype == v2.dtype From 8fcb1141071a9235080c7a36e3475c19916c21e9 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 2 Sep 2022 19:32:42 +0100 Subject: [PATCH 08/66] Revert "WIP for constrainers on state" This reverts commit 7d9ed4d8a392025c53a3fe1f7348eefe19504994. --- gpjax/abstractions.py | 33 ++++++++++++++-------------- gpjax/parameters.py | 47 ++++++++++++++++------------------------ tests/test_parameters.py | 9 +++----- 3 files changed, 39 insertions(+), 50 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 7903a0e9..3d86c238 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -114,18 +114,19 @@ def fit( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ + params, trainables, bijectors = parameter_state.unpack() + def loss(params): - parameter_state = trainable_params(parameter_state) - parameter_state = constrain(parameter_state) - params = parameter_state.params + params = trainable_params(params, trainables) + params = constrain(params, bijectors) return objective(params) iter_nums = jnp.arange(n_iters) # Tranform params to unconstrained space: - parameter_state = unconstrain(parameter_state) + params = unconstrain(params, bijectors) - opt_state = optax_optim.init(parameter_state.params) + opt_state = optax_optim.init(params) @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num): @@ -139,9 +140,9 @@ def step(carry, iter_num): (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) # Tranform params to constrained space: - parameter_state.params = params - params = constrain(parameter_state) - inf_state = InferenceState(params=parameter_state.params, history=history) + params = constrain(params, bijectors) + + inf_state = InferenceState(params=params, history=history) return inf_state @@ -171,14 +172,15 @@ def fit_batches( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ + params, trainables, bijectors = parameter_state.unpack() + def loss(params, batch): - parameter_state = trainable_params(parameter_state) - parameter_state = constrain(parameter_state) - params = parameter_state.params + params = trainable_params(params, trainables) + params = constrain(params, bijectors) return objective(params, batch) - parameter_state = unconstrain(parameter_state) - params = parameter_state.params + params = unconstrain(params, bijectors) + opt_state = optax_optim.init(params) keys = jr.split(key, n_iters) iter_nums = jnp.arange(n_iters) @@ -199,9 +201,8 @@ def step(carry, iter_num__and__key): (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) - parameter_state.params = params - parameter_state = constrain(parameter_state) - inf_state = InferenceState(params=parameter_state.params, history=history) + params = constrain(params, bijectors) + inf_state = InferenceState(params=params, history=history) return inf_state diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 33b30468..62cc40ec 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -139,43 +139,38 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def constrain(state: ParameterState) -> ParameterState: - """Transform the parameters to a constrained space using the corresponding set of bijectors. +def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: + """Transform the parameters to the constrained space for corresponding bijectors. Args: - state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. + params (tp.Dict): The parameters that are to be transformed. + transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been constrained. + tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - params, bijectors = state.params, state.bijectors map = lambda param, trans: trans.forward(param) - transformed_params = jax.tree_util.tree_map(map, params, bijectors) - return ParameterState( - params=transformed_params, - trainables=state.trainables, - bijectors=bijectors, - ) + return jax.tree_util.tree_map(map, params, bijectors) -def unconstrain(state: ParameterState) -> ParameterState: - """Transform the parameters to a unconstrained space using the corresponding set of bijectors. + +def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: + """Transform the parameters to the unconstrained space for corresponding bijectors. Args: - state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. + params (tp.Dict): The parameters that are to be transformed. + transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been unconstrained. + tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - params, bijectors = state.params, state.bijectors + map = lambda param, trans: trans.inverse(param) - transformed_params = jax.tree_util.tree_map(map, params, bijectors) - return ParameterState( - params=transformed_params, - trainables=state.trainables, - bijectors=bijectors, - ) + + return jax.tree_util.tree_map(map, params, bijectors) ################################ @@ -275,12 +270,8 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict): return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) -def trainable_params(state: ParameterState) -> ParameterState: +def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: """Stop the gradients flowing through parameters whose trainable status is False""" - params, trainables = state.params, state.trainables - trainable_params = jax.tree_util.tree_map( + return jax.tree_util.tree_map( lambda param, trainable: stop_grad(param, trainable), params, trainables ) - return ParameterState( - params=trainable_params, trainables=trainables, bijectors=state.bijectors - ) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index b1c5c84e..18fd2e3e 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -228,21 +228,18 @@ def test_prior_checks(latent_prior): @pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) def test_output(num_datapoints, likelihood): posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - state = initialise(posterior, jr.PRNGKey(123)) - params, _, bijectors = state.unpack() + params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() assert isinstance(bijectors, dict) for k, v1, v2 in recursive_items(bijectors, bijectors): assert isinstance(v1.forward, tp.Callable) assert isinstance(v2.inverse, tp.Callable) - unconstrained_state = unconstrain(state) - unconstrained_params = unconstrained_state.params + unconstrained_params = unconstrain(params, bijectors) assert ( unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] ) - backconstrained_state = constrain(unconstrained_state) - backconstrained_params = backconstrained_state.params + backconstrained_params = constrain(unconstrained_params, bijectors) for k, v1, v2 in recursive_items(params, unconstrained_params): assert v1.dtype == v2.dtype From 1fb92ee190c1560da481f238b6fbf54109ff1289 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 4 Sep 2022 13:19:34 +0100 Subject: [PATCH 09/66] Test MCMC docs. --- examples/tfp_integration.ipynb | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/tfp_integration.ipynb b/examples/tfp_integration.ipynb index 5daab7e5..a9ede14c 100644 --- a/examples/tfp_integration.ipynb +++ b/examples/tfp_integration.ipynb @@ -277,8 +277,9 @@ "metadata": {}, "outputs": [], "source": [ - "unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))" + "#unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "#states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", + "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params)))" ] }, { From fe1cf9085d470729097b27a034c666ae7210bea1 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 7 Sep 2022 08:42:08 +0100 Subject: [PATCH 10/66] Fix MCMC? --- examples/tfp_integration.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/tfp_integration.ipynb b/examples/tfp_integration.ipynb index a9ede14c..0c417b7a 100644 --- a/examples/tfp_integration.ipynb +++ b/examples/tfp_integration.ipynb @@ -223,7 +223,7 @@ "def build_log_pi(target, mapper_fn):\n", " def array_mll(parameter_array):\n", " parameter_dict = mapper_fn([jnp.array(i) for i in parameter_array])\n", - " gpx.constrain(parameter_dict, bijectors)\n", + " parameter_dict = gpx.constrain(parameter_dict, bijectors)\n", " return target(parameter_dict)\n", "\n", " return array_mll\n", @@ -277,8 +277,8 @@ "metadata": {}, "outputs": [], "source": [ - "#unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "#states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", + "unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params)))" ] }, From 5a01e63264d90c6a04d1d97251dff131c8db7b67 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 7 Sep 2022 09:43:20 +0100 Subject: [PATCH 11/66] Update classification.ipynb --- examples/classification.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 32cf9fa1..749cfb92 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -421,7 +421,7 @@ "\n", "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))\n", "\n", - "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", + "adapt = blackjax.window_adaptation(blackjax.nuts, unconstrained_mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", "last_state, kernel, _ = adapt.run(key, params)\n", From ff39d46e833e58eaf817ef527949ed90919e5047 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 16 Sep 2022 14:17:55 +0100 Subject: [PATCH 12/66] Update variational_inference.py --- gpjax/variational_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 896c2e49..00e3740c 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -40,7 +40,8 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict: @abc.abstractmethod def elbo( - self, train_data: Dataset, + self, + train_data: Dataset, ) -> Callable[[Dict], Float[Array, "1"]]: """Placeholder method for computing the evidence lower bound function (ELBO), given a training dataset and a set of transformations that map each parameter onto the entire real line. @@ -138,7 +139,7 @@ def __post_init__(self): def elbo( self, train_data: Dataset, negative: bool = False - ) -> Callable[[dict],Float[Array, "1"]:]: + ) -> Callable[[dict], Float[Array, "1"]]: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. Args: From ce79de4f6a1fcaaa3411bd5a53014a4bfaf6964a Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 16 Sep 2022 14:50:34 +0100 Subject: [PATCH 13/66] Update classification.ipynb --- examples/classification.ipynb | 206 ++++++++++++++++------------------ 1 file changed, 99 insertions(+), 107 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 749cfb92..7ff76266 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -39,9 +39,9 @@ "source": [ "## Dataset\n", "\n", - "With the necessary modules imported, we simulate a dataset $\\mathcal{D} = (\\boldsymbol{x}, \\boldsymbol{y}) = \\{(x_i, y_i)\\}_{i=1}^{100}$ with inputs $\\boldsymbol{x}$ sampled uniformly on $(-1., 1)$ and corresponding binary outputs\n", + "With the necessary modules imported, we simulate a dataset $\\mathcal{D} = (, \\boldsymbol{y}) = \\{(x_i, y_i)\\}_{i=1}^{100}$ with inputs $\\boldsymbol{x}$ sampled uniformly on $(-1., 1)$ and corresponding binary outputs\n", "\n", - "$$\\boldsymbol{y} = 0.5 * \\text{sign}(\\cos(2 * \\boldsymbol{x} + \\boldsymbol{\\epsilon})) + 0.5, \\quad \\boldsymbol{\\epsilon} \\sim \\mathcal{N} \\left(\\textbf{0}, \\textbf{I} * (0.05)^{2} \\right).$$\n", + "$$\\boldsymbol{y} = 0.5 * \\text{sign}(\\cos(2 * + \\boldsymbol{\\epsilon})) + 0.5, \\quad \\boldsymbol{\\epsilon} \\sim \\mathcal{N} \\left(\\textbf{0}, \\textbf{I} * (0.05)^{2} \\right).$$\n", "\n", "We store our data $\\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later." ] @@ -130,7 +130,7 @@ "lines_to_next_cell": 0 }, "source": [ - "To begin we obtain a set of initial parameter values through the `initialise` callable, and transform these to the unconstrained space via `transform` (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We also define the negative marginal log-likelihood, and JIT compile this to accelerate training." + "To begin we obtain an initial parameter state through the `initialise` callable (see the [regression notebook](https://gpjax.readthedocs.io/en/latest/nbs/regression.html)). We can obtain a MAP estimate by optimising the marginal log-likelihood with Optax's optimisers." ] }, { @@ -141,43 +141,18 @@ "outputs": [], "source": [ "parameter_state = gpx.initialise(posterior)\n", - "params, trainable, constrainer, unconstrainer = parameter_state.unpack()\n", - "params = gpx.transform(params, unconstrainer)\n", + "negative_mll = jax.jit(posterior.marginal_log_likelihood(D, negative=True))\n", "\n", - "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=True))" - ] - }, - { - "cell_type": "markdown", - "id": "e7d24f78", - "metadata": { - "lines_to_next_cell": 0 - }, - "source": [ - "We can obtain a MAP estimate by optimising the marginal log-likelihood with Obtax's optimisers." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "62001a7d", - "metadata": { - "lines_to_next_cell": 0 - }, - "outputs": [], - "source": [ - "opt = ox.adam(learning_rate=0.01)\n", - "unconstrained_params, training_history = gpx.fit(\n", - " mll,\n", - " params,\n", - " trainable,\n", - " opt,\n", - " n_iters=500,\n", - ").unpack()\n", + "optimiser = ox.adam(learning_rate=0.01)\n", "\n", - "negative_Hessian = jax.jacfwd(jax.jacrev(mll))(unconstrained_params)[\"latent\"][\"latent\"][:,0,:,0]\n", + "inference_state = gpx.fit(\n", + " objective=negative_mll,\n", + " parameter_state=parameter_state,\n", + " optax_optim=optimiser,\n", + " n_iters=1000,\n", + ")\n", "\n", - "map_estimate = gpx.transform(unconstrained_params, constrainer)" + "map_estimate , training_history = inference_state.unpack()" ] }, { @@ -197,11 +172,12 @@ "metadata": {}, "outputs": [], "source": [ - "latent_dist = posterior(D, map_estimate)(xtest)\n", + "map_latent_dist = posterior(D, map_estimate)(xtest)\n", "\n", - "predictive_dist = likelihood(latent_dist, map_estimate)\n", + "predictive_dist = likelihood(map_latent_dist, map_estimate)\n", "\n", "predictive_mean = predictive_dist.mean()\n", + "a = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", @@ -221,6 +197,18 @@ "ax.legend()" ] }, + { + "cell_type": "markdown", + "id": "0814299a", + "metadata": {}, + "source": [ + "Here we projected the map estimates $\\hat{\\boldsymbol{f}}$ for the function values $\\boldsymbol{f}$ at the data points $\\boldsymbol{x}$ to get predictions over the whole domain,\n", + "\n", + "\\begin{align}\n", + "p(f(\\cdot)| \\mathcal{D}) \\approx q_{map}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) \\delta(\\boldsymbol{f} - \\hat{\\boldsymbol{f}}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", + "\\end{align}" + ] + }, { "cell_type": "markdown", "id": "219d35b2", @@ -239,8 +227,18 @@ }, "source": [ "## Laplace approximation\n", - "The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate.\n", - "Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." + "The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. Writing $\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = p(\\boldsymbol{y}|\\boldsymbol{f}) p(\\boldsymbol{f})$ as the unormalised posterior for function values $\\boldsymbol{f}$ at the datapoints $\\boldsymbol{x}$, we can expand the log of this about the posterior mode $\\hat{\\boldsymbol{f}}$ via a Taylor expansion. This gives:\n", + "\n", + "\\begin{align}\n", + "\\log\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) + \\left[\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})|_{\\hat{\\boldsymbol{f}}}\\right]^{T} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\mathcal{O}(\\lVert \\boldsymbol{f} - \\hat{\\boldsymbol{f}} \\rVert^3).\n", + "\\end{align}\n", + "\n", + "Now since $\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})$ is zero at the mode, this suggests the following approximation\n", + "\\begin{align}\n", + "\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) \\approx \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) \\exp\\left\\{ \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) \\right\\}\n", + "\\end{align},\n", + "\n", + "that we identify as a Gaussian distribution, $p(\\boldsymbol{f}| \\mathcal{D}) \\approx q(\\boldsymbol{f}) := \\mathcal{N}(\\hat{\\boldsymbol{f}}, [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." ] }, { @@ -250,18 +248,27 @@ "metadata": {}, "outputs": [], "source": [ - "f_map_estimate = posterior(D, map_estimate)(x).mean()\n", - "\n", + "from gpjax.kernels import gram, cross_covariance\n", "jitter = 1e-6\n", "\n", + "# Compute (latent) function value map estimates at training points:\n", + "Kxx = gram(prior.kernel, x, map_estimate[\"kernel\"])\n", + "Kxx += I(D.n) * jitter\n", + "Lx = jnp.linalg.cholesky(Kxx)\n", + "f_hat = jnp.matmul(Lx, map_estimate[\"latent\"])\n", + "\n", + "# Negative Hessian, H = -∇²p_tilde(y|f):\n", + "H = jax.jacfwd(jax.jacrev(negative_mll))(map_estimate)[\"latent\"][\"latent\"][:,0,:,0]\n", + "\n", "# LLᵀ = H\n", - "L = jnp.linalg.cholesky(negative_Hessian + I(D.n) * jitter)\n", + "L = jnp.linalg.cholesky(H + I(D.n) * jitter)\n", "\n", "# H⁻¹ = H⁻¹ I = (LLᵀ)⁻¹ I = L⁻ᵀL⁻¹ I\n", "L_inv = jsp.linalg.solve_triangular(L, I(D.n), lower=True)\n", "H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)\n", "\n", - "laplace_approximation = dx.MultivariateNormalFullCovariance(f_map_estimate, H_inv)" + "# p(f|D) ≈ N(f_hat, H⁻¹) \n", + "laplace_approximation = dx.MultivariateNormalFullCovariance(jnp.atleast_1d(f_hat.squeeze()), H_inv)" ] }, { @@ -271,70 +278,44 @@ "lines_to_next_cell": 0 }, "source": [ - "For novel inputs, we must interpolate the above distribution, which can be achived via the function defined below." + "For novel inputs, we must project the above approximating distribution through the Gaussian conditional distribution $p(f(\\cdot)| \\boldsymbol{f})$,\n", + "\n", + "\\begin{align}\n", + "p(f(\\cdot)| \\mathcal{D}) \\approx q_{Laplace}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) q(\\boldsymbol{f}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} (\\mathbf{K}_{\\boldsymbol{xx}} - [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1}) \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", + "\\end{align}\n", + "\n", + "This is the same approximate distribution $q_{map}(f(\\cdot))$, but we have pertubed the covariance by a curvature term of $\\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\\cdot))$." ] }, { "cell_type": "code", "execution_count": null, - "id": "8b6cb53e", - "metadata": { - "lines_to_next_cell": 0 - }, + "id": "0f6c2bcc", + "metadata": {}, "outputs": [], "source": [ - "from gpjax.types import Dataset\n", - "from gpjax.kernels import gram, cross_covariance\n", - "\n", - "\n", - "def predict(laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: Float[Array, \"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n", - " \"\"\"Compute the predictive distribution of the Laplace approximation at novel inputs.\n", - "\n", - " Args:\n", - " laplace_at_data (dict): The Laplace approximation at the datapoints.\n", - "\n", - " Returns:\n", - " dx.Distribution: The Laplace approximation at novel inputs.\n", - " \"\"\"\n", - " x, n = train_data.X, train_data.n\n", - "\n", - " t = test_inputs\n", - " n_test = t.shape[0]\n", - "\n", - " mu = laplace_at_data.mean().reshape(-1, 1)\n", - " cov = laplace_at_data.covariance()\n", - "\n", - " Ktt = gram(prior.kernel, t, params[\"kernel\"])\n", - " Kxx = gram(prior.kernel, x, params[\"kernel\"])\n", - " Kxt = cross_covariance(prior.kernel, x, t, params[\"kernel\"])\n", - " μt = prior.mean_function(t, params[\"mean_function\"])\n", - " μx = prior.mean_function(x, params[\"mean_function\"])\n", - "\n", - " # Lx Lxᵀ = Kxx\n", - " Lx = jnp.linalg.cholesky(Kxx + I(n) * jitter)\n", + "def construct_laplace(test_inputs: Float[Array, \"N D\"]) -> dx.MultivariateNormalFullCovariance:\n", + " \n", + " map_latent_dist = posterior(D, map_estimate)(test_inputs)\n", "\n", - " # sqrt sqrtᵀ = Σ\n", - " sqrt = jnp.linalg.cholesky(cov + I(n) * jitter)\n", + " Kxt = cross_covariance(prior.kernel, x, test_inputs, map_estimate[\"kernel\"])\n", + " Kxx = gram(prior.kernel, x, map_estimate[\"kernel\"])\n", + " Kxx += I(D.n) * jitter\n", + " Lx = jnp.linalg.cholesky(Kxx)\n", "\n", - " # Lz⁻¹ Kxt\n", - " Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)\n", + " # Lx⁻¹ Kxt\n", + " Lx_inv_Ktx = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)\n", "\n", " # Kxx⁻¹ Kxt\n", - " Kxx_inv_Kxt = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Kxt, lower=False)\n", + " Kxx_inv_Ktx = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Ktx, lower=False)\n", "\n", - " # Ktx Kxx⁻¹ sqrt\n", - " Ktx_Kxx_inv_sqrt = jnp.matmul(Kxx_inv_Kxt.T, sqrt)\n", - " \n", - " # μt + Ktx Kxx⁻¹ (μ - μx)\n", - " mean = μt + jnp.matmul(Kxx_inv_Kxt.T, mu - μx)\n", + " # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt\n", + " laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)\n", "\n", - " # Ktt - Ktx Kxx⁻¹ Kxt + Ktx Kxx⁻¹ S Kxx⁻¹ Kxt\n", - " covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) + jnp.matmul(Ktx_Kxx_inv_sqrt, Ktx_Kxx_inv_sqrt.T)\n", - " covariance += I(n_test) * jitter\n", + " mean = map_latent_dist.mean()\n", + " covariance = map_latent_dist.covariance() + laplace_cov_term\n", "\n", - " return dx.MultivariateNormalFullCovariance(\n", - " jnp.atleast_1d(mean.squeeze()), covariance\n", - " )" + " return dx.MultivariateNormalFullCovariance(jnp.atleast_1d(mean.squeeze()), covariance)" ] }, { @@ -356,9 +337,8 @@ }, "outputs": [], "source": [ - "latent_dist = predict(laplace_approximation, D, xtest)\n", - "\n", - "predictive_dist = likelihood(latent_dist, map_estimate)\n", + "laplace_latent_dist = construct_laplace(xtest)\n", + "predictive_dist = likelihood(laplace_latent_dist, map_estimate)\n", "\n", "predictive_mean = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", @@ -376,7 +356,6 @@ ")\n", "ax.plot(xtest, predictive_mean - predictive_std, color=\"tab:blue\", linestyle=\"--\", linewidth=1)\n", "ax.plot(xtest, predictive_mean + predictive_std, color=\"tab:blue\", linestyle=\"--\", linewidth=1)\n", - "\n", "ax.legend()" ] }, @@ -411,7 +390,17 @@ { "cell_type": "code", "execution_count": null, - "id": "43dbf9b6", + "id": "a0bd8213", + "metadata": {}, + "outputs": [], + "source": [ + "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a990f322", "metadata": {}, "outputs": [], "source": [ @@ -419,13 +408,16 @@ "num_adapt = 500\n", "num_samples = 500\n", "\n", - "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))\n", + "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()\n", + "mll = posterior.marginal_log_likelihood(D, negative=False)\n", + "\n", + "unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n", "\n", "adapt = blackjax.window_adaptation(blackjax.nuts, unconstrained_mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", - "last_state, kernel, _ = adapt.run(key, params)\n", - "\n", + "unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", "\n", "def inference_loop(rng_key, kernel, initial_state, num_samples):\n", " def one_step(state, rng_key):\n", @@ -520,7 +512,7 @@ " ps[\"kernel\"][\"lengthscale\"] = states.position[\"kernel\"][\"lengthscale\"][i]\n", " ps[\"kernel\"][\"variance\"] = states.position[\"kernel\"][\"variance\"][i]\n", " ps[\"latent\"] = states.position[\"latent\"][i, :, :]\n", - " ps = gpx.transform(ps, constrainer)\n", + " ps = gpx.constrain(ps, bijectors)\n", "\n", " latent_dist = posterior(D, ps)(xtest)\n", " predictive_dist = likelihood(latent_dist, ps)\n", @@ -586,7 +578,7 @@ "custom_cell_magics": "kql" }, "kernelspec": { - "display_name": "Python 3.9.7 ('gpjax')", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -600,11 +592,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "920091140e6b97de16b405af485d142952a229f5dad61a888f46227f5acb94cf" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, From 97f9d9ff7c18ac932eb115bfcf95e0d791d276e2 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 12:24:15 +0100 Subject: [PATCH 14/66] Initial commit. --- gpjax/abstractions.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 3d86c238..d26a54b4 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -11,6 +11,7 @@ from tqdm.auto import tqdm from .parameters import ParameterState, constrain, trainable_params, unconstrain +from .parameters import trainable_params, transform from .types import Dataset, PRNGKeyType @@ -175,6 +176,7 @@ def fit_batches( params, trainables, bijectors = parameter_state.unpack() def loss(params, batch): + params = transform(params, bijectors, forward=True) params = trainable_params(params, trainables) params = constrain(params, bijectors) return objective(params, batch) @@ -185,6 +187,9 @@ def loss(params, batch): keys = jr.split(key, n_iters) iter_nums = jnp.arange(n_iters) + # Tranform params to unconstrained space: + params = transform(params, bijectors, forward=False) + @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num__and__key): iter_num, key = iter_num__and__key From 2b617da3b2c505b77205d9a71c085e83a85b3924 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 13:16:54 +0100 Subject: [PATCH 15/66] Constrain, Unconstrain + Tests --- gpjax/abstractions.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index d26a54b4..0a9522c5 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -176,7 +176,6 @@ def fit_batches( params, trainables, bijectors = parameter_state.unpack() def loss(params, batch): - params = transform(params, bijectors, forward=True) params = trainable_params(params, trainables) params = constrain(params, bijectors) return objective(params, batch) @@ -187,9 +186,6 @@ def loss(params, batch): keys = jr.split(key, n_iters) iter_nums = jnp.arange(n_iters) - # Tranform params to unconstrained space: - params = transform(params, bijectors, forward=False) - @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num__and__key): iter_num, key = iter_num__and__key From 6a17542188857fcd78ea3f422ad77d74b02b48b8 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 15:22:05 +0100 Subject: [PATCH 16/66] Parameter state (See comment) All notebooks are updated, except the tensorflow probability and MCMC section of the classification notebook. --- examples/classification.ipynb | 142 ++++++++++++++++------------------ 1 file changed, 68 insertions(+), 74 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 7ff76266..c82b98d7 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -27,7 +27,7 @@ "import distrax as dx\n", "from gpjax.utils import I\n", "import jax.scipy as jsp\n", - "from jaxtyping import Float, Array\n", + "from jaxtyping import f64\n", "\n", "key = jr.PRNGKey(123)" ] @@ -172,12 +172,11 @@ "metadata": {}, "outputs": [], "source": [ - "map_latent_dist = posterior(D, map_estimate)(xtest)\n", + "latent_dist = posterior(D, map_estimate)(xtest)\n", "\n", - "predictive_dist = likelihood(map_latent_dist, map_estimate)\n", + "predictive_dist = likelihood(latent_dist, map_estimate)\n", "\n", "predictive_mean = predictive_dist.mean()\n", - "a = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", @@ -197,18 +196,6 @@ "ax.legend()" ] }, - { - "cell_type": "markdown", - "id": "0814299a", - "metadata": {}, - "source": [ - "Here we projected the map estimates $\\hat{\\boldsymbol{f}}$ for the function values $\\boldsymbol{f}$ at the data points $\\boldsymbol{x}$ to get predictions over the whole domain,\n", - "\n", - "\\begin{align}\n", - "p(f(\\cdot)| \\mathcal{D}) \\approx q_{map}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) \\delta(\\boldsymbol{f} - \\hat{\\boldsymbol{f}}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", - "\\end{align}" - ] - }, { "cell_type": "markdown", "id": "219d35b2", @@ -227,18 +214,18 @@ }, "source": [ "## Laplace approximation\n", - "The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. Writing $\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = p(\\boldsymbol{y}|\\boldsymbol{f}) p(\\boldsymbol{f})$ as the unormalised posterior for function values $\\boldsymbol{f}$ at the datapoints $\\boldsymbol{x}$, we can expand the log of this about the posterior mode $\\hat{\\boldsymbol{f}}$ via a Taylor expansion. This gives:\n", + "The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. Writing $\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = p(\\boldsymbol{y}|\\boldsymbol{f}) p(\\boldsymbol{f})$ as the unormalised posterior for function values $\\boldsymbol{f}$ at the datapoints $\\boldsymbol{x}$, we can expand the log of this about the posterior mode $\\hat{\\boldsymbol{f}}$ via a Taylor expansion. This gives\n", "\n", "\\begin{align}\n", - "\\log\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) + \\left[\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})|_{\\hat{\\boldsymbol{f}}}\\right]^{T} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\mathcal{O}(\\lVert \\boldsymbol{f} - \\hat{\\boldsymbol{f}} \\rVert^3).\n", + "\\log\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) + \\left[\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})|_{\\hat{\\boldsymbol{f}}}\\right]^{T} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\mathcal{O}(\\lVert \\boldsymbol{f} - \\hat{\\boldsymbol{f}} \\rVert^3)\n", "\\end{align}\n", "\n", "Now since $\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})$ is zero at the mode, this suggests the following approximation\n", "\\begin{align}\n", "\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) \\approx \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) \\exp\\left\\{ \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) \\right\\}\n", - "\\end{align},\n", + "\\end{align}\n", "\n", - "that we identify as a Gaussian distribution, $p(\\boldsymbol{f}| \\mathcal{D}) \\approx q(\\boldsymbol{f}) := \\mathcal{N}(\\hat{\\boldsymbol{f}}, [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." + "that we identify as a Gaussian distribution, $p(\\boldsymbol{f}| \\mathcal{D}) \\approx \\mathcal{N}(\\hat{\\boldsymbol{f}}, [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." ] }, { @@ -248,27 +235,20 @@ "metadata": {}, "outputs": [], "source": [ - "from gpjax.kernels import gram, cross_covariance\n", - "jitter = 1e-6\n", + "f_map_estimate = posterior(D, map_estimate)(x).mean()\n", "\n", - "# Compute (latent) function value map estimates at training points:\n", - "Kxx = gram(prior.kernel, x, map_estimate[\"kernel\"])\n", - "Kxx += I(D.n) * jitter\n", - "Lx = jnp.linalg.cholesky(Kxx)\n", - "f_hat = jnp.matmul(Lx, map_estimate[\"latent\"])\n", - "\n", - "# Negative Hessian, H = -∇²p_tilde(y|f):\n", + "# Negative Hessian:\n", "H = jax.jacfwd(jax.jacrev(negative_mll))(map_estimate)[\"latent\"][\"latent\"][:,0,:,0]\n", "\n", "# LLᵀ = H\n", + "jitter = 1e-6\n", "L = jnp.linalg.cholesky(H + I(D.n) * jitter)\n", "\n", "# H⁻¹ = H⁻¹ I = (LLᵀ)⁻¹ I = L⁻ᵀL⁻¹ I\n", "L_inv = jsp.linalg.solve_triangular(L, I(D.n), lower=True)\n", "H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)\n", "\n", - "# p(f|D) ≈ N(f_hat, H⁻¹) \n", - "laplace_approximation = dx.MultivariateNormalFullCovariance(jnp.atleast_1d(f_hat.squeeze()), H_inv)" + "laplace_approximation = dx.MultivariateNormalFullCovariance(f_map_estimate, H_inv)" ] }, { @@ -278,44 +258,70 @@ "lines_to_next_cell": 0 }, "source": [ - "For novel inputs, we must project the above approximating distribution through the Gaussian conditional distribution $p(f(\\cdot)| \\boldsymbol{f})$,\n", - "\n", - "\\begin{align}\n", - "p(f(\\cdot)| \\mathcal{D}) \\approx q_{Laplace}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) q(\\boldsymbol{f}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} (\\mathbf{K}_{\\boldsymbol{xx}} - [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1}) \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", - "\\end{align}\n", - "\n", - "This is the same approximate distribution $q_{map}(f(\\cdot))$, but we have pertubed the covariance by a curvature term of $\\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\\cdot))$." + "For novel inputs, we must project the above distribution, which can be achived via the function defined below." ] }, { "cell_type": "code", "execution_count": null, - "id": "0f6c2bcc", - "metadata": {}, + "id": "8b6cb53e", + "metadata": { + "lines_to_next_cell": 0 + }, "outputs": [], "source": [ - "def construct_laplace(test_inputs: Float[Array, \"N D\"]) -> dx.MultivariateNormalFullCovariance:\n", - " \n", - " map_latent_dist = posterior(D, map_estimate)(test_inputs)\n", + "from gpjax.types import Dataset\n", + "from gpjax.kernels import gram, cross_covariance\n", + "\n", + "\n", + "def predict(map_estimate: dict, laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: f64[\"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n", + " \"\"\"Compute the predictive distribution of the Laplace approximation at novel inputs.\n", + "\n", + " Args:\n", + " laplace_at_data (dict): The Laplace approximation at the datapoints.\n", + "\n", + " Returns:\n", + " dx.Distribution: The Laplace approximation at novel inputs.\n", + " \"\"\"\n", + " x, n = train_data.X, train_data.n\n", + "\n", + " t = test_inputs\n", + " n_test = t.shape[0]\n", "\n", - " Kxt = cross_covariance(prior.kernel, x, test_inputs, map_estimate[\"kernel\"])\n", + " mu = laplace_at_data.mean().reshape(-1, 1)\n", + " cov = laplace_at_data.covariance()\n", + "\n", + " Ktt = gram(prior.kernel, t, map_estimate[\"kernel\"])\n", " Kxx = gram(prior.kernel, x, map_estimate[\"kernel\"])\n", - " Kxx += I(D.n) * jitter\n", - " Lx = jnp.linalg.cholesky(Kxx)\n", + " Kxt = cross_covariance(prior.kernel, x, t, map_estimate[\"kernel\"])\n", + " μt = prior.mean_function(t, map_estimate[\"mean_function\"])\n", + " μx = prior.mean_function(x, map_estimate[\"mean_function\"])\n", + "\n", + " # Lx Lxᵀ = Kxx\n", + " Lx = jnp.linalg.cholesky(Kxx + I(n) * jitter)\n", + "\n", + " # sqrt sqrtᵀ = Σ\n", + " sqrt = jnp.linalg.cholesky(cov + I(n) * jitter)\n", "\n", - " # Lx⁻¹ Kxt\n", - " Lx_inv_Ktx = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)\n", + " # Lz⁻¹ Kxt\n", + " Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)\n", "\n", " # Kxx⁻¹ Kxt\n", - " Kxx_inv_Ktx = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Ktx, lower=False)\n", + " Kxx_inv_Kxt = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Kxt, lower=False)\n", "\n", - " # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt\n", - " laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)\n", + " # Ktx Kxx⁻¹ sqrt\n", + " Ktx_Kxx_inv_sqrt = jnp.matmul(Kxx_inv_Kxt.T, sqrt)\n", + " \n", + " # μt + Ktx Kxx⁻¹ (μ - μx)\n", + " mean = μt + jnp.matmul(Kxx_inv_Kxt.T, mu - μx)\n", "\n", - " mean = map_latent_dist.mean()\n", - " covariance = map_latent_dist.covariance() + laplace_cov_term\n", + " # Ktt - Ktx Kxx⁻¹ Kxt + Ktx Kxx⁻¹ S Kxx⁻¹ Kxt\n", + " covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) + jnp.matmul(Ktx_Kxx_inv_sqrt, Ktx_Kxx_inv_sqrt.T)\n", + " covariance += I(n_test) * jitter\n", "\n", - " return dx.MultivariateNormalFullCovariance(jnp.atleast_1d(mean.squeeze()), covariance)" + " return dx.MultivariateNormalFullCovariance(\n", + " jnp.atleast_1d(mean.squeeze()), covariance\n", + " )" ] }, { @@ -337,8 +343,9 @@ }, "outputs": [], "source": [ - "laplace_latent_dist = construct_laplace(xtest)\n", - "predictive_dist = likelihood(laplace_latent_dist, map_estimate)\n", + "latent_dist = predict(map_estimate, laplace_approximation, D, xtest)\n", + "\n", + "predictive_dist = likelihood(latent_dist, map_estimate)\n", "\n", "predictive_mean = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", @@ -387,16 +394,6 @@ "We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary." ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "a0bd8213", - "metadata": {}, - "outputs": [], - "source": [ - "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -408,16 +405,13 @@ "num_adapt = 500\n", "num_samples = 500\n", "\n", - "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()\n", - "mll = posterior.marginal_log_likelihood(D, negative=False)\n", + "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))\n", "\n", - "unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n", - "\n", - "adapt = blackjax.window_adaptation(blackjax.nuts, unconstrained_mll, num_adapt, target_acceptance_rate=0.65)\n", + "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", - "unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", + "last_state, kernel, _ = adapt.run(key, params)\n", + "\n", "\n", "def inference_loop(rng_key, kernel, initial_state, num_samples):\n", " def one_step(state, rng_key):\n", @@ -512,7 +506,7 @@ " ps[\"kernel\"][\"lengthscale\"] = states.position[\"kernel\"][\"lengthscale\"][i]\n", " ps[\"kernel\"][\"variance\"] = states.position[\"kernel\"][\"variance\"][i]\n", " ps[\"latent\"] = states.position[\"latent\"][i, :, :]\n", - " ps = gpx.constrain(ps, bijectors)\n", + " ps = gpx.transform(ps, constrainer)\n", "\n", " latent_dist = posterior(D, ps)(xtest)\n", " predictive_dist = likelihood(latent_dist, ps)\n", From faca79a83d8d02df5181295d10c184326b66345d Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 31 Aug 2022 15:58:46 +0100 Subject: [PATCH 17/66] Update nbs. --- examples/classification.ipynb | 16 +++++++++------- examples/tfp_integration.ipynb | 5 ++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index c82b98d7..355d6878 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -214,16 +214,16 @@ }, "source": [ "## Laplace approximation\n", - "The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. Writing $\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = p(\\boldsymbol{y}|\\boldsymbol{f}) p(\\boldsymbol{f})$ as the unormalised posterior for function values $\\boldsymbol{f}$ at the datapoints $\\boldsymbol{x}$, we can expand the log of this about the posterior mode $\\hat{\\boldsymbol{f}}$ via a Taylor expansion. This gives\n", + "The Laplace approximation improves uncertainty quantification by incorporating curvature induced by the marginal log-likelihood's Hessian to construct an approximate Gaussian distribution centered on the MAP estimate. Writing $\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = p(\\boldsymbol{y}|\\boldsymbol{f}) p(\\boldsymbol{f})$ as the unormalised posterior for function values $\\boldsymbol{f}$ at the datapoints $\\boldsymbol{x}$, we can expand the log of this about the posterior mode $\\hat{\\boldsymbol{f}}$ via a Taylor expansion. This gives:\n", "\n", "\\begin{align}\n", - "\\log\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) + \\left[\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})|_{\\hat{\\boldsymbol{f}}}\\right]^{T} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\mathcal{O}(\\lVert \\boldsymbol{f} - \\hat{\\boldsymbol{f}} \\rVert^3)\n", + "\\log\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) = \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) + \\left[\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})|_{\\hat{\\boldsymbol{f}}}\\right]^{T} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) + \\mathcal{O}(\\lVert \\boldsymbol{f} - \\hat{\\boldsymbol{f}} \\rVert^3).\n", "\\end{align}\n", "\n", "Now since $\\nabla \\log\\tilde{p}({\\boldsymbol{f}}|\\mathcal{D})$ is zero at the mode, this suggests the following approximation\n", "\\begin{align}\n", "\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) \\approx \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) \\exp\\left\\{ \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) \\right\\}\n", - "\\end{align}\n", + "\\end{align},\n", "\n", "that we identify as a Gaussian distribution, $p(\\boldsymbol{f}| \\mathcal{D}) \\approx \\mathcal{N}(\\hat{\\boldsymbol{f}}, [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." ] @@ -405,13 +405,15 @@ "num_adapt = 500\n", "num_samples = 500\n", "\n", - "mll = jax.jit(posterior.marginal_log_likelihood(D, constrainer, negative=False))\n", + "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()\n", + "mll = posterior.marginal_log_likelihood(D, negative=False)\n", + "unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n", "\n", "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", - "last_state, kernel, _ = adapt.run(key, params)\n", - "\n", + "unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", "\n", "def inference_loop(rng_key, kernel, initial_state, num_samples):\n", " def one_step(state, rng_key):\n", @@ -506,7 +508,7 @@ " ps[\"kernel\"][\"lengthscale\"] = states.position[\"kernel\"][\"lengthscale\"][i]\n", " ps[\"kernel\"][\"variance\"] = states.position[\"kernel\"][\"variance\"][i]\n", " ps[\"latent\"] = states.position[\"latent\"][i, :, :]\n", - " ps = gpx.transform(ps, constrainer)\n", + " ps = gpx.constrain(ps, bijectors)\n", "\n", " latent_dist = posterior(D, ps)(xtest)\n", " predictive_dist = likelihood(latent_dist, ps)\n", diff --git a/examples/tfp_integration.ipynb b/examples/tfp_integration.ipynb index 0c417b7a..5daab7e5 100644 --- a/examples/tfp_integration.ipynb +++ b/examples/tfp_integration.ipynb @@ -223,7 +223,7 @@ "def build_log_pi(target, mapper_fn):\n", " def array_mll(parameter_array):\n", " parameter_dict = mapper_fn([jnp.array(i) for i in parameter_array])\n", - " parameter_dict = gpx.constrain(parameter_dict, bijectors)\n", + " gpx.constrain(parameter_dict, bijectors)\n", " return target(parameter_dict)\n", "\n", " return array_mll\n", @@ -278,8 +278,7 @@ "outputs": [], "source": [ "unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", - "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params)))" + "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))" ] }, { From dea684a31ad174af8b2002280aaa38624b6addd3 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 1 Sep 2022 09:55:24 +0000 Subject: [PATCH 18/66] WIP for constrainers on state --- gpjax/abstractions.py | 33 ++++++++++++++-------------- gpjax/parameters.py | 47 ++++++++++++++++++++++++---------------- tests/test_parameters.py | 9 +++++--- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 0a9522c5..69d82061 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -115,19 +115,18 @@ def fit( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ - params, trainables, bijectors = parameter_state.unpack() - def loss(params): - params = trainable_params(params, trainables) - params = constrain(params, bijectors) + parameter_state = trainable_params(parameter_state) + parameter_state = constrain(parameter_state) + params = parameter_state.params return objective(params) iter_nums = jnp.arange(n_iters) # Tranform params to unconstrained space: - params = unconstrain(params, bijectors) + parameter_state = unconstrain(parameter_state) - opt_state = optax_optim.init(params) + opt_state = optax_optim.init(parameter_state.params) @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num): @@ -141,9 +140,9 @@ def step(carry, iter_num): (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) # Tranform params to constrained space: - params = constrain(params, bijectors) - - inf_state = InferenceState(params=params, history=history) + parameter_state.params = params + params = constrain(parameter_state) + inf_state = InferenceState(params=parameter_state.params, history=history) return inf_state @@ -173,15 +172,14 @@ def fit_batches( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ - params, trainables, bijectors = parameter_state.unpack() - def loss(params, batch): - params = trainable_params(params, trainables) - params = constrain(params, bijectors) + parameter_state = trainable_params(parameter_state) + parameter_state = constrain(parameter_state) + params = parameter_state.params return objective(params, batch) - params = unconstrain(params, bijectors) - + parameter_state = unconstrain(parameter_state) + params = parameter_state.params opt_state = optax_optim.init(params) keys = jr.split(key, n_iters) iter_nums = jnp.arange(n_iters) @@ -202,8 +200,9 @@ def step(carry, iter_num__and__key): (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) - params = constrain(params, bijectors) - inf_state = InferenceState(params=params, history=history) + parameter_state.params = params + parameter_state = constrain(parameter_state) + inf_state = InferenceState(params=parameter_state.params, history=history) return inf_state diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 62cc40ec..33b30468 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -139,38 +139,43 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: - """Transform the parameters to the constrained space for corresponding bijectors. +def constrain(state: ParameterState) -> ParameterState: + """Transform the parameters to a constrained space using the corresponding set of bijectors. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. - foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. Returns: - tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been constrained. """ + params, bijectors = state.params, state.bijectors map = lambda param, trans: trans.forward(param) - - return jax.tree_util.tree_map(map, params, bijectors) + transformed_params = jax.tree_util.tree_map(map, params, bijectors) + return ParameterState( + params=transformed_params, + trainables=state.trainables, + bijectors=bijectors, + ) -def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: - """Transform the parameters to the unconstrained space for corresponding bijectors. +def unconstrain(state: ParameterState) -> ParameterState: + """Transform the parameters to a unconstrained space using the corresponding set of bijectors. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. - foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. Returns: - tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been unconstrained. """ - + params, bijectors = state.params, state.bijectors map = lambda param, trans: trans.inverse(param) - - return jax.tree_util.tree_map(map, params, bijectors) + transformed_params = jax.tree_util.tree_map(map, params, bijectors) + return ParameterState( + params=transformed_params, + trainables=state.trainables, + bijectors=bijectors, + ) ################################ @@ -270,8 +275,12 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict): return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) -def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: +def trainable_params(state: ParameterState) -> ParameterState: """Stop the gradients flowing through parameters whose trainable status is False""" - return jax.tree_util.tree_map( + params, trainables = state.params, state.trainables + trainable_params = jax.tree_util.tree_map( lambda param, trainable: stop_grad(param, trainable), params, trainables ) + return ParameterState( + params=trainable_params, trainables=trainables, bijectors=state.bijectors + ) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 18fd2e3e..b1c5c84e 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -228,18 +228,21 @@ def test_prior_checks(latent_prior): @pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) def test_output(num_datapoints, likelihood): posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() + state = initialise(posterior, jr.PRNGKey(123)) + params, _, bijectors = state.unpack() assert isinstance(bijectors, dict) for k, v1, v2 in recursive_items(bijectors, bijectors): assert isinstance(v1.forward, tp.Callable) assert isinstance(v2.inverse, tp.Callable) - unconstrained_params = unconstrain(params, bijectors) + unconstrained_state = unconstrain(state) + unconstrained_params = unconstrained_state.params assert ( unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] ) - backconstrained_params = constrain(unconstrained_params, bijectors) + backconstrained_state = constrain(unconstrained_state) + backconstrained_params = backconstrained_state.params for k, v1, v2 in recursive_items(params, unconstrained_params): assert v1.dtype == v2.dtype From d042a7beae24dc4c0ca8d16478f14cdd84fc79d9 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 2 Sep 2022 19:32:42 +0100 Subject: [PATCH 19/66] Revert "WIP for constrainers on state" This reverts commit 7d9ed4d8a392025c53a3fe1f7348eefe19504994. --- gpjax/abstractions.py | 33 ++++++++++++++-------------- gpjax/parameters.py | 47 ++++++++++++++++------------------------ tests/test_parameters.py | 9 +++----- 3 files changed, 39 insertions(+), 50 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 69d82061..0a9522c5 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -115,18 +115,19 @@ def fit( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ + params, trainables, bijectors = parameter_state.unpack() + def loss(params): - parameter_state = trainable_params(parameter_state) - parameter_state = constrain(parameter_state) - params = parameter_state.params + params = trainable_params(params, trainables) + params = constrain(params, bijectors) return objective(params) iter_nums = jnp.arange(n_iters) # Tranform params to unconstrained space: - parameter_state = unconstrain(parameter_state) + params = unconstrain(params, bijectors) - opt_state = optax_optim.init(parameter_state.params) + opt_state = optax_optim.init(params) @progress_bar_scan(n_iters, log_rate) def step(carry, iter_num): @@ -140,9 +141,9 @@ def step(carry, iter_num): (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) # Tranform params to constrained space: - parameter_state.params = params - params = constrain(parameter_state) - inf_state = InferenceState(params=parameter_state.params, history=history) + params = constrain(params, bijectors) + + inf_state = InferenceState(params=params, history=history) return inf_state @@ -172,14 +173,15 @@ def fit_batches( InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ + params, trainables, bijectors = parameter_state.unpack() + def loss(params, batch): - parameter_state = trainable_params(parameter_state) - parameter_state = constrain(parameter_state) - params = parameter_state.params + params = trainable_params(params, trainables) + params = constrain(params, bijectors) return objective(params, batch) - parameter_state = unconstrain(parameter_state) - params = parameter_state.params + params = unconstrain(params, bijectors) + opt_state = optax_optim.init(params) keys = jr.split(key, n_iters) iter_nums = jnp.arange(n_iters) @@ -200,9 +202,8 @@ def step(carry, iter_num__and__key): (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) - parameter_state.params = params - parameter_state = constrain(parameter_state) - inf_state = InferenceState(params=parameter_state.params, history=history) + params = constrain(params, bijectors) + inf_state = InferenceState(params=params, history=history) return inf_state diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 33b30468..62cc40ec 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -139,43 +139,38 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def constrain(state: ParameterState) -> ParameterState: - """Transform the parameters to a constrained space using the corresponding set of bijectors. +def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: + """Transform the parameters to the constrained space for corresponding bijectors. Args: - state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. + params (tp.Dict): The parameters that are to be transformed. + transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been constrained. + tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - params, bijectors = state.params, state.bijectors map = lambda param, trans: trans.forward(param) - transformed_params = jax.tree_util.tree_map(map, params, bijectors) - return ParameterState( - params=transformed_params, - trainables=state.trainables, - bijectors=bijectors, - ) + return jax.tree_util.tree_map(map, params, bijectors) -def unconstrain(state: ParameterState) -> ParameterState: - """Transform the parameters to a unconstrained space using the corresponding set of bijectors. + +def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: + """Transform the parameters to the unconstrained space for corresponding bijectors. Args: - state (ParameterState): The state object containing the parameters and corresponding bijectors that are to be transformed. + params (tp.Dict): The parameters that are to be transformed. + transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - ParameterState: A transformed parameter set. The state object is equal in structure to the input state, the only difference being that the parameters have now been unconstrained. + tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ - params, bijectors = state.params, state.bijectors + map = lambda param, trans: trans.inverse(param) - transformed_params = jax.tree_util.tree_map(map, params, bijectors) - return ParameterState( - params=transformed_params, - trainables=state.trainables, - bijectors=bijectors, - ) + + return jax.tree_util.tree_map(map, params, bijectors) ################################ @@ -275,12 +270,8 @@ def stop_grad(param: tp.Dict, trainable: tp.Dict): return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) -def trainable_params(state: ParameterState) -> ParameterState: +def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: """Stop the gradients flowing through parameters whose trainable status is False""" - params, trainables = state.params, state.trainables - trainable_params = jax.tree_util.tree_map( + return jax.tree_util.tree_map( lambda param, trainable: stop_grad(param, trainable), params, trainables ) - return ParameterState( - params=trainable_params, trainables=trainables, bijectors=state.bijectors - ) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index b1c5c84e..18fd2e3e 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -228,21 +228,18 @@ def test_prior_checks(latent_prior): @pytest.mark.parametrize("likelihood", [Gaussian, Bernoulli]) def test_output(num_datapoints, likelihood): posterior = Prior(kernel=RBF()) * likelihood(num_datapoints=num_datapoints) - state = initialise(posterior, jr.PRNGKey(123)) - params, _, bijectors = state.unpack() + params, _, bijectors = initialise(posterior, jr.PRNGKey(123)).unpack() assert isinstance(bijectors, dict) for k, v1, v2 in recursive_items(bijectors, bijectors): assert isinstance(v1.forward, tp.Callable) assert isinstance(v2.inverse, tp.Callable) - unconstrained_state = unconstrain(state) - unconstrained_params = unconstrained_state.params + unconstrained_params = unconstrain(params, bijectors) assert ( unconstrained_params["kernel"]["lengthscale"] != params["kernel"]["lengthscale"] ) - backconstrained_state = constrain(unconstrained_state) - backconstrained_params = backconstrained_state.params + backconstrained_params = constrain(unconstrained_params, bijectors) for k, v1, v2 in recursive_items(params, unconstrained_params): assert v1.dtype == v2.dtype From feb95c946a5811806303db3607cc8dfe4dd46ea7 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 4 Sep 2022 13:19:34 +0100 Subject: [PATCH 20/66] Test MCMC docs. --- examples/classification.ipynb | 5 +++-- examples/tfp_integration.ipynb | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 355d6878..7fc6ec71 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -412,8 +412,9 @@ "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", - "unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", + "#unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "#last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", + "last_state, kernel, _ = adapt.run(key, params)\n", "\n", "def inference_loop(rng_key, kernel, initial_state, num_samples):\n", " def one_step(state, rng_key):\n", diff --git a/examples/tfp_integration.ipynb b/examples/tfp_integration.ipynb index 5daab7e5..a9ede14c 100644 --- a/examples/tfp_integration.ipynb +++ b/examples/tfp_integration.ipynb @@ -277,8 +277,9 @@ "metadata": {}, "outputs": [], "source": [ - "unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))" + "#unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "#states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", + "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params)))" ] }, { From d5e989b03bfc063494a67141541630da6d2d9705 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Mon, 5 Sep 2022 11:02:40 +0100 Subject: [PATCH 21/66] Update classification.ipynb --- examples/classification.ipynb | 128 ++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 7fc6ec71..ca8d2eaf 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -172,11 +172,12 @@ "metadata": {}, "outputs": [], "source": [ - "latent_dist = posterior(D, map_estimate)(xtest)\n", + "map_latent_dist = posterior(D, map_estimate)(xtest)\n", "\n", - "predictive_dist = likelihood(latent_dist, map_estimate)\n", + "predictive_dist = likelihood(map_latent_dist, map_estimate)\n", "\n", "predictive_mean = predictive_dist.mean()\n", + "a = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", @@ -196,6 +197,18 @@ "ax.legend()" ] }, + { + "cell_type": "markdown", + "id": "0814299a", + "metadata": {}, + "source": [ + "Here we projected the map estimates $\\hat{\\boldsymbol{f}}$ for the function values $\\boldsymbol{f}$ at the data points $\\boldsymbol{x}$ to get predictions over the whole domain,\n", + "\n", + "\\begin{align}\n", + "p(f(\\cdot)| \\mathcal{D}) \\approx q_{map}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) \\delta(\\boldsymbol{f} - \\hat{\\boldsymbol{f}}) d \\boldsymbol{f} = \\mathcal{N}(\\bold{K}_{\\boldsymbol{(\\cdot)x}} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\bold{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\bold{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\bold{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", + "\\end{align}" + ] + }, { "cell_type": "markdown", "id": "219d35b2", @@ -225,7 +238,7 @@ "\\tilde{p}(\\boldsymbol{f}|\\mathcal{D}) \\approx \\log\\tilde{p}(\\hat{\\boldsymbol{f}}|\\mathcal{D}) \\exp\\left\\{ \\frac{1}{2} (\\boldsymbol{f}-\\hat{\\boldsymbol{f}})^{T} \\left[-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} \\right] (\\boldsymbol{f}-\\hat{\\boldsymbol{f}}) \\right\\}\n", "\\end{align},\n", "\n", - "that we identify as a Gaussian distribution, $p(\\boldsymbol{f}| \\mathcal{D}) \\approx \\mathcal{N}(\\hat{\\boldsymbol{f}}, [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." + "that we identify as a Gaussian distribution, $p(\\boldsymbol{f}| \\mathcal{D}) \\approx q(\\boldsymbol{f}) := \\mathcal{N}(\\hat{\\boldsymbol{f}}, [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} )$. Since the negative Hessian is positive definite, we can use the Cholesky decomposition to obtain the covariance matrix of the Laplace approximation at the datapoints below." ] }, { @@ -235,20 +248,27 @@ "metadata": {}, "outputs": [], "source": [ - "f_map_estimate = posterior(D, map_estimate)(x).mean()\n", + "from gpjax.kernels import gram, cross_covariance\n", + "jitter = 1e-6\n", + "\n", + "# Compute (latent) function value map estimates at training points:\n", + "Kxx = gram(prior.kernel, x, map_estimate[\"kernel\"])\n", + "Kxx += I(D.n) * jitter\n", + "Lx = jnp.linalg.cholesky(Kxx)\n", + "f_hat = jnp.matmul(Lx, map_estimate[\"latent\"])\n", "\n", - "# Negative Hessian:\n", + "# Negative Hessian, H = -∇²p_tilde(y|f):\n", "H = jax.jacfwd(jax.jacrev(negative_mll))(map_estimate)[\"latent\"][\"latent\"][:,0,:,0]\n", "\n", "# LLᵀ = H\n", - "jitter = 1e-6\n", "L = jnp.linalg.cholesky(H + I(D.n) * jitter)\n", "\n", "# H⁻¹ = H⁻¹ I = (LLᵀ)⁻¹ I = L⁻ᵀL⁻¹ I\n", "L_inv = jsp.linalg.solve_triangular(L, I(D.n), lower=True)\n", "H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)\n", "\n", - "laplace_approximation = dx.MultivariateNormalFullCovariance(f_map_estimate, H_inv)" + "# p(f|D) ≈ N(f_hat, H⁻¹) \n", + "laplace_approximation = dx.MultivariateNormalFullCovariance(jnp.atleast_1d(f_hat.squeeze()), H_inv)" ] }, { @@ -258,70 +278,49 @@ "lines_to_next_cell": 0 }, "source": [ - "For novel inputs, we must project the above distribution, which can be achived via the function defined below." + "For novel inputs, we must project the above approximating distribution through the Gaussian conditional distribution $p(f(\\cdot)| \\boldsymbol{f})$,\n", + "\n", + "\\begin{align}\n", + "\n", + "p(f(\\cdot)| \\mathcal{D}) \\approx q_{Laplace}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) q(\\boldsymbol{f}) d \\boldsymbol{f} = \\mathcal{N}(\\bold{K}_{\\boldsymbol{(\\cdot)x}} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\bold{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\bold{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\bold{K}_{\\boldsymbol{xx}}^{-1} (\\bold{K}_{\\boldsymbol{xx}} - [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1}) \\bold{K}_{\\boldsymbol{xx}}^{-1} \\bold{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", + "\n", + "\\end{align}\n", + "\n", + "\n", + "\n", + "\n", + "This is the same approximate distribution $q_{map}(f(\\cdot))$, but we have pertubed the covariance by a curvature term of $\\bold{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\bold{K}_{\\boldsymbol{xx}}^{-1} [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\bold{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\\cdot))$." ] }, { "cell_type": "code", "execution_count": null, - "id": "8b6cb53e", - "metadata": { - "lines_to_next_cell": 0 - }, + "id": "0f6c2bcc", + "metadata": {}, "outputs": [], "source": [ - "from gpjax.types import Dataset\n", - "from gpjax.kernels import gram, cross_covariance\n", - "\n", - "\n", - "def predict(map_estimate: dict, laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: f64[\"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n", - " \"\"\"Compute the predictive distribution of the Laplace approximation at novel inputs.\n", - "\n", - " Args:\n", - " laplace_at_data (dict): The Laplace approximation at the datapoints.\n", - "\n", - " Returns:\n", - " dx.Distribution: The Laplace approximation at novel inputs.\n", - " \"\"\"\n", - " x, n = train_data.X, train_data.n\n", - "\n", - " t = test_inputs\n", - " n_test = t.shape[0]\n", - "\n", - " mu = laplace_at_data.mean().reshape(-1, 1)\n", - " cov = laplace_at_data.covariance()\n", + "def construct_laplace(test_inputs: f64[\"N D\"]) -> dx.MultivariateNormalFullCovariance:\n", + " \n", + " map_latent_dist = posterior(D, map_estimate)(test_inputs)\n", "\n", - " Ktt = gram(prior.kernel, t, map_estimate[\"kernel\"])\n", + " Kxt = cross_covariance(prior.kernel, x, test_inputs, map_estimate[\"kernel\"])\n", " Kxx = gram(prior.kernel, x, map_estimate[\"kernel\"])\n", - " Kxt = cross_covariance(prior.kernel, x, t, map_estimate[\"kernel\"])\n", - " μt = prior.mean_function(t, map_estimate[\"mean_function\"])\n", - " μx = prior.mean_function(x, map_estimate[\"mean_function\"])\n", - "\n", - " # Lx Lxᵀ = Kxx\n", - " Lx = jnp.linalg.cholesky(Kxx + I(n) * jitter)\n", + " Kxx += I(D.n) * jitter\n", + " Lx = jnp.linalg.cholesky(Kxx)\n", "\n", - " # sqrt sqrtᵀ = Σ\n", - " sqrt = jnp.linalg.cholesky(cov + I(n) * jitter)\n", - "\n", - " # Lz⁻¹ Kxt\n", - " Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)\n", + " # Lx⁻¹ Kxt\n", + " Lx_inv_Ktx = jsp.linalg.solve_triangular(Lx, Kxt, lower=True)\n", "\n", " # Kxx⁻¹ Kxt\n", - " Kxx_inv_Kxt = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Kxt, lower=False)\n", + " Kxx_inv_Ktx = jsp.linalg.solve_triangular(Lx.T, Lx_inv_Ktx, lower=False)\n", "\n", - " # Ktx Kxx⁻¹ sqrt\n", - " Ktx_Kxx_inv_sqrt = jnp.matmul(Kxx_inv_Kxt.T, sqrt)\n", - " \n", - " # μt + Ktx Kxx⁻¹ (μ - μx)\n", - " mean = μt + jnp.matmul(Kxx_inv_Kxt.T, mu - μx)\n", + " # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt\n", + " laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)\n", "\n", - " # Ktt - Ktx Kxx⁻¹ Kxt + Ktx Kxx⁻¹ S Kxx⁻¹ Kxt\n", - " covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) + jnp.matmul(Ktx_Kxx_inv_sqrt, Ktx_Kxx_inv_sqrt.T)\n", - " covariance += I(n_test) * jitter\n", + " mean = map_latent_dist.mean()\n", + " covariance = map_latent_dist.covariance() + laplace_cov_term\n", "\n", - " return dx.MultivariateNormalFullCovariance(\n", - " jnp.atleast_1d(mean.squeeze()), covariance\n", - " )" + " return dx.MultivariateNormalFullCovariance(jnp.atleast_1d(mean.squeeze()), covariance)" ] }, { @@ -343,9 +342,8 @@ }, "outputs": [], "source": [ - "latent_dist = predict(map_estimate, laplace_approximation, D, xtest)\n", - "\n", - "predictive_dist = likelihood(latent_dist, map_estimate)\n", + "laplace_latent_dist = construct_laplace(xtest)\n", + "predictive_dist = likelihood(laplace_latent_dist, map_estimate)\n", "\n", "predictive_mean = predictive_dist.mean()\n", "predictive_std = predictive_dist.stddev()\n", @@ -394,6 +392,16 @@ "We begin by generating _sensible_ initial positions for our sampler before defining an inference loop and sampling 500 values from our Markov chain. In practice, drawing more samples will be necessary." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "a0bd8213", + "metadata": {}, + "outputs": [], + "source": [ + "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -407,7 +415,7 @@ "\n", "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()\n", "mll = posterior.marginal_log_likelihood(D, negative=False)\n", - "unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n", + "unconstrained_mll = jax.jit(lambda params: mll(gpx.unconstrain(params, bijectors)))\n", "\n", "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", From c358fc1ff36f1ab9b30a0d6af0998e6a0ae7d21d Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 7 Sep 2022 08:33:58 +0100 Subject: [PATCH 22/66] Update classification.ipynb --- examples/classification.ipynb | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index ca8d2eaf..6377702e 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -205,7 +205,7 @@ "Here we projected the map estimates $\\hat{\\boldsymbol{f}}$ for the function values $\\boldsymbol{f}$ at the data points $\\boldsymbol{x}$ to get predictions over the whole domain,\n", "\n", "\\begin{align}\n", - "p(f(\\cdot)| \\mathcal{D}) \\approx q_{map}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) \\delta(\\boldsymbol{f} - \\hat{\\boldsymbol{f}}) d \\boldsymbol{f} = \\mathcal{N}(\\bold{K}_{\\boldsymbol{(\\cdot)x}} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\bold{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\bold{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\bold{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", + "p(f(\\cdot)| \\mathcal{D}) \\approx q_{map}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) \\delta(\\boldsymbol{f} - \\hat{\\boldsymbol{f}}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", "\\end{align}" ] }, @@ -282,14 +282,14 @@ "\n", "\\begin{align}\n", "\n", - "p(f(\\cdot)| \\mathcal{D}) \\approx q_{Laplace}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) q(\\boldsymbol{f}) d \\boldsymbol{f} = \\mathcal{N}(\\bold{K}_{\\boldsymbol{(\\cdot)x}} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\bold{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\bold{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\bold{K}_{\\boldsymbol{xx}}^{-1} (\\bold{K}_{\\boldsymbol{xx}} - [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1}) \\bold{K}_{\\boldsymbol{xx}}^{-1} \\bold{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", + "p(f(\\cdot)| \\mathcal{D}) \\approx q_{Laplace}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) q(\\boldsymbol{f}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} (\\mathbf{K}_{\\boldsymbol{xx}} - [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1}) \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", "\n", "\\end{align}\n", "\n", "\n", "\n", "\n", - "This is the same approximate distribution $q_{map}(f(\\cdot))$, but we have pertubed the covariance by a curvature term of $\\bold{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\bold{K}_{\\boldsymbol{xx}}^{-1} [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} \\bold{K}_{\\boldsymbol{xx}}^{-1} \\bold{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\\cdot))$." + "This is the same approximate distribution $q_{map}(f(\\cdot))$, but we have pertubed the covariance by a curvature term of $\\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\\cdot))$." ] }, { From 3f5465b50f2dff420e8f97eeea5c1fdae812d2b0 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 7 Sep 2022 08:42:08 +0100 Subject: [PATCH 23/66] Fix MCMC? --- examples/classification.ipynb | 12 +++--------- examples/tfp_integration.ipynb | 6 +++--- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 6377702e..23a3a907 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -281,14 +281,9 @@ "For novel inputs, we must project the above approximating distribution through the Gaussian conditional distribution $p(f(\\cdot)| \\boldsymbol{f})$,\n", "\n", "\\begin{align}\n", - "\n", "p(f(\\cdot)| \\mathcal{D}) \\approx q_{Laplace}(f(\\cdot)) := \\int p(f(\\cdot)| \\boldsymbol{f}) q(\\boldsymbol{f}) d \\boldsymbol{f} = \\mathcal{N}(\\mathbf{K}_{\\boldsymbol{(\\cdot)x}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\hat{\\boldsymbol{f}}, \\mathbf{K}_{\\boldsymbol{(\\cdot, \\cdot)}} - \\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} (\\mathbf{K}_{\\boldsymbol{xx}} - [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1}) \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}).\n", - "\n", "\\end{align}\n", "\n", - "\n", - "\n", - "\n", "This is the same approximate distribution $q_{map}(f(\\cdot))$, but we have pertubed the covariance by a curvature term of $\\mathbf{K}_{\\boldsymbol{(\\cdot)\\boldsymbol{x}}} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} [-\\nabla^2 \\tilde{p}(\\boldsymbol{y}|\\boldsymbol{f})|_{\\hat{\\boldsymbol{f}}} ]^{-1} \\mathbf{K}_{\\boldsymbol{xx}}^{-1} \\mathbf{K}_{\\boldsymbol{\\boldsymbol{x}(\\cdot)}}$. We take the latent distribution computed in the previous section and add this term to the covariance to construct $q_{Laplace}(f(\\cdot))$." ] }, @@ -415,14 +410,13 @@ "\n", "params, trainables, bijectors = gpx.initialise(posterior, key).unpack()\n", "mll = posterior.marginal_log_likelihood(D, negative=False)\n", - "unconstrained_mll = jax.jit(lambda params: mll(gpx.unconstrain(params, bijectors)))\n", + "unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n", "\n", "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", - "#unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "#last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", - "last_state, kernel, _ = adapt.run(key, params)\n", + "unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "last_state, kernel, _ = adapt.run(key, unconstrained_params)\n", "\n", "def inference_loop(rng_key, kernel, initial_state, num_samples):\n", " def one_step(state, rng_key):\n", diff --git a/examples/tfp_integration.ipynb b/examples/tfp_integration.ipynb index a9ede14c..0c417b7a 100644 --- a/examples/tfp_integration.ipynb +++ b/examples/tfp_integration.ipynb @@ -223,7 +223,7 @@ "def build_log_pi(target, mapper_fn):\n", " def array_mll(parameter_array):\n", " parameter_dict = mapper_fn([jnp.array(i) for i in parameter_array])\n", - " gpx.constrain(parameter_dict, bijectors)\n", + " parameter_dict = gpx.constrain(parameter_dict, bijectors)\n", " return target(parameter_dict)\n", "\n", " return array_mll\n", @@ -277,8 +277,8 @@ "metadata": {}, "outputs": [], "source": [ - "#unconstrained_params = gpx.unconstrain(params, bijectors)\n", - "#states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", + "unconstrained_params = gpx.unconstrain(params, bijectors)\n", + "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(unconstrained_params)))\n", "states, log_probs = jax.jit(run_chain)(key, jnp.array(dict_to_array(params)))" ] }, From 8dbfc44969b53a62387f75bf1af516ce36f266a1 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 7 Sep 2022 09:43:20 +0100 Subject: [PATCH 24/66] Update classification.ipynb --- examples/classification.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 23a3a907..94cf9c28 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -412,7 +412,7 @@ "mll = posterior.marginal_log_likelihood(D, negative=False)\n", "unconstrained_mll = jax.jit(lambda params: mll(gpx.constrain(params, bijectors)))\n", "\n", - "adapt = blackjax.window_adaptation(blackjax.nuts, mll, num_adapt, target_acceptance_rate=0.65)\n", + "adapt = blackjax.window_adaptation(blackjax.nuts, unconstrained_mll, num_adapt, target_acceptance_rate=0.65)\n", "\n", "# Initialise the chain\n", "unconstrained_params = gpx.unconstrain(params, bijectors)\n", From dda0c6dcf2b944ea791d7a2d35c9907608d579da Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 1 Jun 2022 14:57:05 +0100 Subject: [PATCH 25/66] Natural gradients. Added Natural Variational Gaussian Family. NEED TO WRITE TESTS. The variational tests for variational families are a mess --- much of the testing is done in sparse_gps.py. In order to write these, it might be worth fixing these too. --- gpjax/__init__.py | 2 + gpjax/variational_families.py | 122 ++++++++++++++++++++++++++++- tests/test_variational_families.py | 30 +++++++ 3 files changed, 151 insertions(+), 3 deletions(-) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index a897e610..807d7701 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -25,6 +25,8 @@ CollapsedVariationalGaussian, VariationalGaussian, WhitenedVariationalGaussian, + NaturalVariationalGaussian, + ExpectationVariationalGaussian, ) from .variational_inference import CollapsedVI, StochasticVI diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 1ed4682c..5de441a8 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -257,6 +257,124 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: ) return predict_fn + + +@dataclass +class NaturalVariationalGaussian(AbstractVariationalFamily): + """The variational Gaussian family of probability distributions.""" + prior: Prior + inducing_inputs: Array + name: str = "Natural Gaussian" + natural_vector: Optional[Array] = None + natural_matrix: Optional[Array] = None + jitter: Optional[float] = DEFAULT_JITTER + + def __post_init__(self): + """Initialise the variational Gaussian distribution.""" + self.num_inducing = self.inducing_inputs.shape[0] + add_parameter("inducing_inputs", Identity) + + m = self.num_inducing + + if self.natural_vector is None: + self.natural_vector = jnp.zeros((m, 1)) + add_parameter("natural_vector", Identity) + + if self.natural_matrix is None: + self.natural_matrix = -.5 * I(m) + add_parameter("natural_matrix", Identity) + + @property + def params(self) -> Dict: + """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" + return concat_dictionaries( + self.prior.params, { + "variational_family": { + "inducing_inputs": self.inducing_inputs, + "natural_vector": self.natural_vector, + "natural_matrix": self.natural_matrix} + } + ) + + def prior_kl(self, params: Dict) -> Array: + """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. + + Args: + params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. + + Returns: + Array: The KL-divergence between our variational approximation and the GP prior. + """ + natural_vector = params["variational_family"]["natural_vector"] + natural_covariance = params["variational_family"]["natural_covariance"] + z = params["variational_family"]["inducing_inputs"] + m = self.num_inducing + + S_inv = -2 * natural_covariance + S_inv += I(m) * self.jitter + L_inv = jnp.linalg.cholesky(S_inv) + L = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + + S = jnp.matmul(L, L.T) + mu = jnp.matmul(S, natural_vector) + + μz = self.prior.mean_function(z, params["mean_function"]) + Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz += I(m) * self.jitter + Lz = jnp.linalg.cholesky(Kzz) + + qu = dx.MultivariateNormalTri(mu.squeeze(), L) + pu = dx.MultivariateNormalTri(μz.squeeze(), Lz) + + return qu.kl_divergence(pu) + + def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: + """Compute the predictive distribution of the GP at the test inputs. + + Args: + params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + + Returns: + Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + """ + natural_vector = params["variational_family"]["natural_vector"] + natural_covariance = params["variational_family"]["natural_covariance"] + z = params["variational_family"]["inducing_inputs"] + m = self.num_inducing + + S_inv = -2 * natural_covariance + S_inv += I(m) * self.jitter + L_inv = jnp.linalg.cholesky(S_inv) + L = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + + S = jnp.matmul(L, L.T) + mu = jnp.matmul(S, natural_vector) + + Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz += I(m) * self.jitter + Lz = jnp.linalg.cholesky(Kzz) + μz = self.prior.mean_function(z, params["mean_function"]) + + def predict_fn(test_inputs: Array) -> dx.Distribution: + t = test_inputs + Ktt = gram(self.prior.kernel, t, params["kernel"]) + Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) + μt = self.prior.mean_function(t, params["mean_function"]) + A = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + B = jsp.linalg.solve_triangular(Lz.T, A, lower=False) + V = jnp.matmul(B.T, L) + + mean = μt + jnp.matmul(B.T, mu - μz) + covariance = Ktt - jnp.matmul(A.T, A) + jnp.matmul(V, V.T) + + return dx.MultivariateNormalFullCovariance( + jnp.atleast_1d(mean.squeeze()), covariance + ) + + return predict_fn + + + @dataclass @@ -295,10 +413,8 @@ def predict( self, train_data: Dataset, params: dict ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. - Args: params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. - Returns: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ @@ -364,4 +480,4 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: jnp.atleast_1d(mean.squeeze()), covariance ) - return predict_fn + return predict_fn \ No newline at end of file diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 21c3cfde..7f265b0a 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -1,6 +1,7 @@ import typing as tp import distrax as dx + import jax.numpy as jnp import jax.random as jr import pytest @@ -163,3 +164,32 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ assert isinstance(sigma, jnp.ndarray) assert mu.shape == (n_test,) assert sigma.shape == (n_test, n_test) +@pytest.mark.parametrize("n_inducing", [1, 10, 20]) +def test_natural_variational_gaussian_params(n_inducing): + prior = gpx.Prior(kernel=gpx.RBF()) + inducing_points = jnp.linspace(-3.0, 3.0, n_inducing).reshape(-1, 1) + variational_family = gpx.variational_families.NaturalVariationalGaussian( + prior=prior, + inducing_inputs=inducing_points + ) + + params = variational_family.params + assert isinstance(params, dict) + assert "inducing_inputs" in params["variational_family"].keys() + assert "natural_vector" in params["variational_family"].keys() + assert "natural_matrix" in params["variational_family"].keys() + + assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) + assert params["variational_family"]["natural_vector"].shape == (n_inducing, 1) + assert params["variational_family"]["natural_matrix"].shape == (n_inducing, n_inducing) + + assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["natural_vector"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["natural_matrix"], jnp.DeviceArray) + + params = gpx.config.get_defaults() + assert "natural_vector" in params["transformations"].keys() + assert "natural_matrix" in params["transformations"].keys() + + assert (variational_family.natural_matrix == -.5 * jnp.eye(n_inducing)).all() + assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() From dbaf04b0d4fce580e624f23bdf2a4801b4652848 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 1 Jun 2022 19:57:31 +0100 Subject: [PATCH 26/66] Update tests + fix minor bugs. --- gpjax/variational_families.py | 12 +++---- tests/test_variational_families.py | 51 ++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 5de441a8..e5740997 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -306,11 +306,11 @@ def prior_kl(self, params: Dict) -> Array: Array: The KL-divergence between our variational approximation and the GP prior. """ natural_vector = params["variational_family"]["natural_vector"] - natural_covariance = params["variational_family"]["natural_covariance"] + natural_matrix = params["variational_family"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - S_inv = -2 * natural_covariance + S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) L = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) @@ -323,8 +323,8 @@ def prior_kl(self, params: Dict) -> Array: Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) - qu = dx.MultivariateNormalTri(mu.squeeze(), L) - pu = dx.MultivariateNormalTri(μz.squeeze(), Lz) + qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), L) + pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) return qu.kl_divergence(pu) @@ -338,11 +338,11 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ natural_vector = params["variational_family"]["natural_vector"] - natural_covariance = params["variational_family"]["natural_covariance"] + natural_matrix = params["variational_family"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - S_inv = -2 * natural_covariance + S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) L = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 7f265b0a..8c61d372 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -165,14 +165,38 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ assert mu.shape == (n_test,) assert sigma.shape == (n_test, n_test) @pytest.mark.parametrize("n_inducing", [1, 10, 20]) -def test_natural_variational_gaussian_params(n_inducing): +def test_natural_variational_gaussian(n_inducing, n_test): prior = gpx.Prior(kernel=gpx.RBF()) - inducing_points = jnp.linspace(-3.0, 3.0, n_inducing).reshape(-1, 1) + + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) + test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) + + variational_family = gpx.variational_families.NaturalVariationalGaussian( prior=prior, - inducing_inputs=inducing_points + inducing_inputs=inducing_inputs + ) + + # Test init + assert variational_family.num_inducing == n_inducing + + assert jnp.sum(variational_family.natural_vector) == 0.0 + assert variational_family.natural_vector.shape == (n_inducing, 1) + + assert variational_family.natural_matrix.shape == ( + n_inducing, + n_inducing, ) + assert jnp.all(jnp.diag(variational_family.natural_matrix) == -.5) + + params = gpx.config.get_defaults() + assert "variational_root_covariance" in params["transformations"].keys() + assert "variational_mean" in params["transformations"].keys() + assert (variational_family.natural_matrix == -.5 * jnp.eye(n_inducing)).all() + assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() + + # params params = variational_family.params assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() @@ -193,3 +217,24 @@ def test_natural_variational_gaussian_params(n_inducing): assert (variational_family.natural_matrix == -.5 * jnp.eye(n_inducing)).all() assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() + + + #Test KL + params = variational_family.params + kl = variational_family.prior_kl(params) + assert isinstance(kl, jnp.ndarray) + + # Test predictions + predictive_dist_fn = variational_family(params) + assert isinstance(predictive_dist_fn, tp.Callable) + + predictive_dist = predictive_dist_fn(test_inputs) + assert isinstance(predictive_dist, dx.Distribution) + + mu = predictive_dist.mean() + sigma = predictive_dist.covariance() + + assert isinstance(mu, jnp.ndarray) + assert isinstance(sigma, jnp.ndarray) + assert mu.shape == (n_test,) + assert sigma.shape == (n_test, n_test) From 048b0cc3f3f8e8c406bbb37e40a8014bf1861766 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 1 Jun 2022 20:13:45 +0100 Subject: [PATCH 27/66] Update variational_families.py --- gpjax/variational_families.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e5740997..a19f6261 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -313,10 +313,13 @@ def prior_kl(self, params: Dict) -> Array: S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) - L = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + B = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) - S = jnp.matmul(L, L.T) + S = jnp.matmul(B.T, B) mu = jnp.matmul(S, natural_vector) + + S += I(m) * self.jitter + L = jnp.linalg.cholesky(S) μz = self.prior.mean_function(z, params["mean_function"]) Kzz = gram(self.prior.kernel, z, params["kernel"]) @@ -345,11 +348,14 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) - L = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + B = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) - S = jnp.matmul(L, L.T) + S = jnp.matmul(B.T, B) mu = jnp.matmul(S, natural_vector) - + + S += I(m) * self.jitter + L = jnp.linalg.cholesky(S) + Kzz = gram(self.prior.kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) From 40f9dd2026fe321f4ee1c5492d7aa2ae817e6382 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 1 Jun 2022 20:42:08 +0100 Subject: [PATCH 28/66] Update variational_families.py --- gpjax/variational_families.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index a19f6261..fabb0777 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -319,14 +319,14 @@ def prior_kl(self, params: Dict) -> Array: mu = jnp.matmul(S, natural_vector) S += I(m) * self.jitter - L = jnp.linalg.cholesky(S) + sqrt = jnp.linalg.cholesky(S) μz = self.prior.mean_function(z, params["mean_function"]) Kzz = gram(self.prior.kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) - qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), L) + qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) return qu.kl_divergence(pu) @@ -348,14 +348,11 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) - B = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + C = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) - S = jnp.matmul(B.T, B) + S = jnp.matmul(C.T, C) mu = jnp.matmul(S, natural_vector) - S += I(m) * self.jitter - L = jnp.linalg.cholesky(S) - Kzz = gram(self.prior.kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = jnp.linalg.cholesky(Kzz) @@ -368,7 +365,7 @@ def predict_fn(test_inputs: Array) -> dx.Distribution: μt = self.prior.mean_function(t, params["mean_function"]) A = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) B = jsp.linalg.solve_triangular(Lz.T, A, lower=False) - V = jnp.matmul(B.T, L) + V = jnp.matmul(B.T, C.T) mean = μt + jnp.matmul(B.T, mu - μz) covariance = Ktt - jnp.matmul(A.T, A) + jnp.matmul(V, V.T) From a0b5ab27b9b1fddc3bf52126608efc306c01974b Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 1 Jun 2022 20:43:18 +0100 Subject: [PATCH 29/66] Update variational_families.py --- gpjax/variational_families.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index fabb0777..e566fef6 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -313,9 +313,9 @@ def prior_kl(self, params: Dict) -> Array: S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) - B = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + C = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) - S = jnp.matmul(B.T, B) + S = jnp.matmul(C.T, C) mu = jnp.matmul(S, natural_vector) S += I(m) * self.jitter From 426e6cb626e56ef187d67e7e92fb61b0576f8e95 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 8 Jun 2022 15:10:30 +0100 Subject: [PATCH 30/66] Update variational_families.py Add maths description and add expectation parameteriation. --- gpjax/variational_families.py | 156 ++++++++++++++++++++++++++++++++-- 1 file changed, 148 insertions(+), 8 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e566fef6..d47a04c5 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -261,7 +261,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: @dataclass class NaturalVariationalGaussian(AbstractVariationalFamily): - """The variational Gaussian family of probability distributions.""" + """The natural variational Gaussian family of probability distributions.""" prior: Prior inducing_inputs: Array name: str = "Natural Gaussian" @@ -345,12 +345,20 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: z = params["variational_family"]["inducing_inputs"] m = self.num_inducing + # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter - L_inv = jnp.linalg.cholesky(S_inv) - C = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) + + # S⁻¹ = LLᵀ + L = jnp.linalg.cholesky(S_inv) + + # C = L⁻¹I + C = jsp.linalg.solve_triangular(L, I(m), lower=True) + # S = CᵀC S = jnp.matmul(C.T, C) + + # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) Kzz = gram(self.prior.kernel, z, params["kernel"]) @@ -363,12 +371,144 @@ def predict_fn(test_inputs: Array) -> dx.Distribution: Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) μt = self.prior.mean_function(t, params["mean_function"]) - A = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) - B = jsp.linalg.solve_triangular(Lz.T, A, lower=False) - V = jnp.matmul(B.T, C.T) + + # Lz⁻¹ Kzt + Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + + # Kzz⁻¹ Kzt + Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False) + + # Ktz Kzz⁻¹ Cᵀ + Ktz_Kzz_inv_CT = jnp.matmul(Kzz_inv_Kzt.T, C.T) - mean = μt + jnp.matmul(B.T, mu - μz) - covariance = Ktt - jnp.matmul(A.T, A) + jnp.matmul(V, V.T) + # μt + Ktz Kzz⁻¹ (μ - μz) + mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) + + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = CᵀC] + covariance = Ktt - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_CT, Ktz_Kzz_inv_CT.T) + + return dx.MultivariateNormalFullCovariance( + jnp.atleast_1d(mean.squeeze()), covariance + ) + + return predict_fn + + +@dataclass +class ExpectationVariationalGaussian(AbstractVariationalFamily): + """The variational Gaussian family of probability distributions.""" + prior: Prior + inducing_inputs: Array + name: str = "Natural Gaussian" + expectation_vector: Optional[Array] = None + expectation_matrix: Optional[Array] = None + jitter: Optional[float] = DEFAULT_JITTER + + def __post_init__(self): + """Initialise the variational Gaussian distribution.""" + self.num_inducing = self.inducing_inputs.shape[0] + add_parameter("inducing_inputs", Identity) + + m = self.num_inducing + + if self.expectation_vector is None: + self.expectation_vector = jnp.zeros((m, 1)) + add_parameter("natural_vector", Identity) + + if self.expectation_matrix is None: + self.expectation_matrix = I(m) + add_parameter("natural_matrix", Identity) + + @property + def params(self) -> Dict: + """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" + return concat_dictionaries( + self.prior.params, { + "variational_family": { + "inducing_inputs": self.inducing_inputs, + "natural_vector": self.natural_vector, + "natural_matrix": self.natural_matrix} + } + ) + + def prior_kl(self, params: Dict) -> Array: + """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. + + Args: + params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. + + Returns: + Array: The KL-divergence between our variational approximation and the GP prior. + """ + natural_vector = params["variational_family"]["natural_vector"] + natural_matrix = params["variational_family"]["natural_matrix"] + z = params["variational_family"]["inducing_inputs"] + m = self.num_inducing + + mu = natural_vector + S = natural_matrix - jnp.matmul(mu, mu.T) + S += I(m) * self.jitter + sqrt = jnp.linalg.cholesky(S) + + μz = self.prior.mean_function(z, params["mean_function"]) + Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz += I(m) * self.jitter + Lz = jnp.linalg.cholesky(Kzz) + + qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) + pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + + return qu.kl_divergence(pu) + + def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: + """Compute the predictive distribution of the GP at the test inputs. + + Args: + params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + + Returns: + Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + """ + natural_vector = params["variational_family"]["natural_vector"] + natural_matrix = params["variational_family"]["natural_matrix"] + z = params["variational_family"]["inducing_inputs"] + m = self.num_inducing + + # μ = η₁ + mu = natural_vector + + # S = η₂ - η₁ η₁ᵀ + S = natural_matrix - jnp.matmul(mu, mu.T) + S += I(m) * self.jitter + + # S = sqrt sqrtᵀ + sqrt = jnp.linalg.cholesky(S) + + Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz += I(m) * self.jitter + Lz = jnp.linalg.cholesky(Kzz) + μz = self.prior.mean_function(z, params["mean_function"]) + + def predict_fn(test_inputs: Array) -> dx.Distribution: + t = test_inputs + Ktt = gram(self.prior.kernel, t, params["kernel"]) + Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) + μt = self.prior.mean_function(t, params["mean_function"]) + + # Lz⁻¹ Kzt + Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) + + # Kzz⁻¹ Kzt + Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt , lower=False) + + # Ktz Kzz⁻¹ sqrt + Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) + + # μt + Ktz Kzz⁻¹ (μ - μz) + mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) + + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] + covariance = Ktt - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt ) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance From 0604cddb352d4440e6391d062dd7b72fd25c538d Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 8 Jun 2022 15:17:34 +0100 Subject: [PATCH 31/66] Update variational_families.py --- gpjax/variational_families.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index d47a04c5..e5d86233 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -399,7 +399,7 @@ class ExpectationVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" prior: Prior inducing_inputs: Array - name: str = "Natural Gaussian" + name: str = "Expectation Gaussian" expectation_vector: Optional[Array] = None expectation_matrix: Optional[Array] = None jitter: Optional[float] = DEFAULT_JITTER @@ -413,21 +413,21 @@ def __post_init__(self): if self.expectation_vector is None: self.expectation_vector = jnp.zeros((m, 1)) - add_parameter("natural_vector", Identity) + add_parameter("expectation_vector", Identity) if self.expectation_matrix is None: self.expectation_matrix = I(m) - add_parameter("natural_matrix", Identity) + add_parameter("expectation_matrix", Identity) @property def params(self) -> Dict: - """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" + """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" return concat_dictionaries( self.prior.params, { "variational_family": { "inducing_inputs": self.inducing_inputs, - "natural_vector": self.natural_vector, - "natural_matrix": self.natural_matrix} + "expectation_vector": self.expectation_vector, + "expectation_matrix": self.expectation_matrix} } ) @@ -440,13 +440,13 @@ def prior_kl(self, params: Dict) -> Array: Returns: Array: The KL-divergence between our variational approximation and the GP prior. """ - natural_vector = params["variational_family"]["natural_vector"] - natural_matrix = params["variational_family"]["natural_matrix"] + expectation_vector = params["variational_family"]["expectation_vector"] + expectation_matrix = params["variational_family"]["expectation_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - mu = natural_vector - S = natural_matrix - jnp.matmul(mu, mu.T) + mu = expectation_vector + S = expectation_matrix - jnp.matmul(mu, mu.T) S += I(m) * self.jitter sqrt = jnp.linalg.cholesky(S) @@ -469,16 +469,16 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: Returns: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ - natural_vector = params["variational_family"]["natural_vector"] - natural_matrix = params["variational_family"]["natural_matrix"] + expectation_vector = params["variational_family"]["expectation_vector"] + expectation_matrix = params["variational_family"]["expectation_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing # μ = η₁ - mu = natural_vector + mu = expectation_vector # S = η₂ - η₁ η₁ᵀ - S = natural_matrix - jnp.matmul(mu, mu.T) + S = expectation_matrix - jnp.matmul(mu, mu.T) S += I(m) * self.jitter # S = sqrt sqrtᵀ @@ -517,9 +517,6 @@ def predict_fn(test_inputs: Array) -> dx.Distribution: return predict_fn - - - @dataclass class CollapsedVariationalGaussian(AbstractVariationalFamily): """Collapsed variational Gaussian family of probability distributions. @@ -623,4 +620,4 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: jnp.atleast_1d(mean.squeeze()), covariance ) - return predict_fn \ No newline at end of file + return predict_fn From 700acfb19a4fb45a515f6be5bef7f7b1bf83eb47 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 8 Jun 2022 15:54:36 +0100 Subject: [PATCH 32/66] Update test_variational_families.py Add test for expectation parameterisation. --- tests/test_variational_families.py | 78 ++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 8c61d372..b2856259 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -238,3 +238,81 @@ def test_natural_variational_gaussian(n_inducing, n_test): assert isinstance(sigma, jnp.ndarray) assert mu.shape == (n_test,) assert sigma.shape == (n_test, n_test) + + +@pytest.mark.parametrize("n_test", [1, 10]) +@pytest.mark.parametrize("n_inducing", [1, 10, 20]) +def test_expectation_variational_gaussian(n_inducing, n_test): + prior = gpx.Prior(kernel=gpx.RBF()) + + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) + test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) + + + variational_family = gpx.variational_families.ExpectationVariationalGaussian( + prior=prior, + inducing_inputs=inducing_inputs + ) + + # Test init + assert variational_family.num_inducing == n_inducing + + assert jnp.sum(variational_family.expectation_vector) == 0.0 + assert variational_family.expectation_vector.shape == (n_inducing, 1) + + assert variational_family.expectation_matrix.shape == ( + n_inducing, + n_inducing, + ) + assert jnp.all(jnp.diag(variational_family.expectation_matrix) == 1.0) + + params = gpx.config.get_defaults() + assert "variational_root_covariance" in params["transformations"].keys() + assert "variational_mean" in params["transformations"].keys() + + assert (variational_family.expectation_matrix == jnp.eye(n_inducing)).all() + assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() + + # params + params = variational_family.params + assert isinstance(params, dict) + assert "inducing_inputs" in params["variational_family"].keys() + assert "expectation_vector" in params["variational_family"].keys() + assert "expectation_matrix" in params["variational_family"].keys() + + assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) + assert params["variational_family"]["expectation_vector"].shape == (n_inducing, 1) + assert params["variational_family"]["expectation_matrix"].shape == (n_inducing, n_inducing) + + assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["expectation_vector"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["expectation_matrix"], jnp.DeviceArray) + + params = gpx.config.get_defaults() + assert "expectation_vector" in params["transformations"].keys() + assert "expectation_matrix" in params["transformations"].keys() + + assert (variational_family.expectation_matrix == jnp.eye(n_inducing)).all() + assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() + + + #Test KL + params = variational_family.params + kl = variational_family.prior_kl(params) + assert isinstance(kl, jnp.ndarray) + + # Test predictions + predictive_dist_fn = variational_family(params) + assert isinstance(predictive_dist_fn, tp.Callable) + + predictive_dist = predictive_dist_fn(test_inputs) + assert isinstance(predictive_dist, dx.Distribution) + + mu = predictive_dist.mean() + sigma = predictive_dist.covariance() + + assert isinstance(mu, jnp.ndarray) + assert isinstance(sigma, jnp.ndarray) + assert mu.shape == (n_test,) + assert sigma.shape == (n_test, n_test) + From a96281f84609adb38e2fc8c3fe7fc74205807ac0 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 8 Jun 2022 20:17:18 +0100 Subject: [PATCH 33/66] Natural gradients sketch. A rough initial schematic of implementing natural gradients. --- gpjax/natural_gradients.py | 118 +++++++++++++++++++++++++++++++++++++ gpjax/parameters.py | 13 ++++ 2 files changed, 131 insertions(+) create mode 100644 gpjax/natural_gradients.py diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py new file mode 100644 index 00000000..ef688958 --- /dev/null +++ b/gpjax/natural_gradients.py @@ -0,0 +1,118 @@ +from cmath import exp +import typing as tp +import jax.numpy as jnp +import jax.scipy as jsp +from jax import jacobian + +from .config import get_defaults +from .variational_families import AbstractVariationalFamily, ExpectationVariationalGaussian +from .variational_inference import StochasticVI +from .utils import I +from .gps import AbstractPosterior +from .types import Dataset +from .parameters import build_identity, transform + +DEFAULT_JITTER = get_defaults()["jitter"] + +# CURRENTLY THIS FILE IS A FIRST SKETCH OF NATURAL GRADIENTS in GPJax. + +# Below is correct, but it might be better to pass in params (i.e., all svgp params) and return a dictionary that gives svgp params +def natural_to_expectation(natural_params: dict, jitter: float = DEFAULT_JITTER): + """ + Converts natural parameters to expectation parameters. + Args: + natural_params: A dictionary of natural parameters. + jitter: A small value to prevent numerical instability. + Returns: + A dictionary of expectation parameters. + """ + + natural_matrix = natural_params["natural_matrix"] + natural_vector = natural_params["natural_vector"] + m = natural_vector.shape[0] + + # S⁻¹ = -2θ₂ + S_inv = -2 * natural_matrix + S_inv += I(m) * jitter + + # S⁻¹ = LLᵀ + L = jnp.linalg.cholesky(S_inv) + + # C = L⁻¹I + C = jsp.linalg.solve_triangular(L, I(m), lower=True) + + # S = CᵀC + S = jnp.matmul(C.T, C) + + # μ = Sθ₁ + mu = jnp.matmul(S, natural_vector) + + # η₁ = μ + expectation_vector = mu + + # η₂ = S + η₁ η₁ᵀ + expectation_matrix = S + jnp.matmul(mu, mu.T) + + return {"expectation_vector": expectation_vector, "expectation_matrix": expectation_matrix} + + +# This is a function that you create before training. This can be used to get the elbo for the nexpectation parameterisation. +# Here it is assumed that the parameters have already been transformed prior to being passed to the returned function. +def get_expectation_elbo(posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + train_data: Dataset, + ): + """ + Computes evidence lower bound (ELBO) for the expectation parameterisation. + Args: + posterior: An instance of AbstractPosterior. + variational_family: An instance of AbstractVariationalFamily. + Returns: + Callable: A function that computes ELBO. + """ + q = variational_family + expectaction_q = ExpectationVariationalGaussian(prior=q.prior, inducing_inputs = q.inducing_inputs) + svgp = StochasticVI(posterior=posterior, variational_family=expectaction_q) + transformations = build_identity(svgp.params) + + return svgp.elbo(train_data, transformations) + + +def natural_gradients(params: dict, + transformations: dict, + expectation_elbo: tp.Callable, + nat_to_xi: tp.Callable, + xi_to_nat: tp.Callable, + batch, +) -> dict: + """ + Computes natural gradients for a variational family. + Args: + params (tp.Dict): A dictionary of parameters. + variational_family: A variational family. + nat_to_xi: A function that converts natural parameters to variational parameters xi. + xi_to_nat: A function that converts variational parameters xi to natural parameters. + transformations (tp.Dict): A dictionary of transformations. + Returns: + tp.Dict: Dictionary of natural gradients. + """ + # Transform the parameters. + params = transform(params, transformations) + + # Need to stop gradients for hyperparameters. + + natural_params = xi_to_nat(params) + + # Gradient function ∂ξ/∂θ: + dxi_dnat = jacobian(nat_to_xi)(natural_params) + + expectation_params = natural_to_expectation(natural_params) + expectation_elbo = expectation_elbo(expectation_params, batch) + + # Compute gradient ∂L/∂η: + dL_dnat = jacobian(expectation_elbo)(expectation_params) + + # Compute natural gradient: + nat_grads = jnp.matmul(dxi_dnat, dL_dnat.T) #<- Some pytree operations are needed here. + + return nat_grads diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 62cc40ec..4ade47e9 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -172,6 +172,19 @@ def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: return jax.tree_util.tree_map(map, params, bijectors) +def build_identity(params: tp.Dict) -> tp.Dict: + """" + Args: + params (tp.Dict): The parameter set for which trainable statuses should be derived from. + + Returns: + tp.Dict: A dictionary of identity forward/backward bijectors. The dictionary is equal in structure to the input params dictionary. + """ + # Copy dictionary structure + prior_container = deepcopy(params) + + return jax.tree_map(lambda _: Identity.forward, prior_container) + ################################ # Priors From 35173f8f9bcf8188e5c53feabba85d6230bd9c05 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 15 Jun 2022 17:14:04 +0100 Subject: [PATCH 34/66] Add notion of "moments" Add notion of "moments" as a moment parameterisation for a variational Gaussian. --- gpjax/variational_families.py | 52 ++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e5d86233..e18bafe1 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -82,14 +82,13 @@ def __post_init__(self): def _initialise_params(self, key: jnp.DeviceArray) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( - self.prior._initialise_params(key), - { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "variational_mean": self.variational_mean, - "variational_root_covariance": self.variational_root_covariance, + self.prior._initialise_params(key), { + "variational_family": { + "inducing_inputs": self.inducing_inputs, + "moments": {"variational_mean": self.variational_mean, + "variational_root_covariance": self.variational_root_covariance} } - }, + } ) def prior_kl(self, params: Dict) -> Float[Array, "1"]: @@ -104,8 +103,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Array: The KL-divergence between our variational approximation and the GP prior. """ - mu = params["variational_family"]["variational_mean"] - sqrt = params["variational_family"]["variational_root_covariance"] + mu = params["variational_family"]["moments"]["variational_mean"] + sqrt = params["variational_family"]["moments"]["variational_root_covariance"] m = self.num_inducing z = params["variational_family"]["inducing_inputs"] μz = self.prior.mean_function(z, params["mean_function"]) @@ -131,8 +130,8 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Returns: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ - mu = params["variational_family"]["variational_mean"] - sqrt = params["variational_family"]["variational_root_covariance"] + mu = params["variational_family"]["moments"]["variational_mean"] + sqrt = params["variational_family"]["moments"]["variational_root_covariance"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing @@ -197,8 +196,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Array: The KL-divergence between our variational approximation and the GP prior. """ - mu = params["variational_family"]["variational_mean"] - sqrt = params["variational_family"]["variational_root_covariance"] + mu = params["variational_family"]["moments"]["variational_mean"] + sqrt = params["variational_family"]["moments"]["variational_root_covariance"] m = self.num_inducing qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) @@ -219,8 +218,8 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi Returns: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ - mu = params["variational_family"]["variational_mean"] - sqrt = params["variational_family"]["variational_root_covariance"] + mu = params["variational_family"]["moments"]["variational_mean"] + sqrt = params["variational_family"]["moments"]["variational_root_covariance"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing @@ -291,9 +290,11 @@ def params(self) -> Dict: self.prior.params, { "variational_family": { "inducing_inputs": self.inducing_inputs, - "natural_vector": self.natural_vector, + "moments": {"natural_vector": self.natural_vector, "natural_matrix": self.natural_matrix} } + } + ) def prior_kl(self, params: Dict) -> Array: @@ -305,8 +306,8 @@ def prior_kl(self, params: Dict) -> Array: Returns: Array: The KL-divergence between our variational approximation and the GP prior. """ - natural_vector = params["variational_family"]["natural_vector"] - natural_matrix = params["variational_family"]["natural_matrix"] + natural_vector = params["variational_family"]["moments"]["natural_vector"] + natural_matrix = params["variational_family"]["moments"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing @@ -340,8 +341,8 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: Returns: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ - natural_vector = params["variational_family"]["natural_vector"] - natural_matrix = params["variational_family"]["natural_matrix"] + natural_vector = params["variational_family"]["moments"]["natural_vector"] + natural_matrix = params["variational_family"]["moments"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing @@ -426,9 +427,10 @@ def params(self) -> Dict: self.prior.params, { "variational_family": { "inducing_inputs": self.inducing_inputs, - "expectation_vector": self.expectation_vector, + "moments": {"expectation_vector": self.expectation_vector, "expectation_matrix": self.expectation_matrix} } + } ) def prior_kl(self, params: Dict) -> Array: @@ -440,8 +442,8 @@ def prior_kl(self, params: Dict) -> Array: Returns: Array: The KL-divergence between our variational approximation and the GP prior. """ - expectation_vector = params["variational_family"]["expectation_vector"] - expectation_matrix = params["variational_family"]["expectation_matrix"] + expectation_vector = params["variational_family"]["moments"]["expectation_vector"] + expectation_matrix = params["variational_family"]["moments"]["expectation_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing @@ -469,8 +471,8 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: Returns: Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ - expectation_vector = params["variational_family"]["expectation_vector"] - expectation_matrix = params["variational_family"]["expectation_matrix"] + expectation_vector = params["variational_family"]["moments"]["expectation_vector"] + expectation_matrix = params["variational_family"]["moments"]["expectation_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing From 218ebf456e4ded45b2a28ff62e3ab5989668c26b Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 15 Jun 2022 17:19:44 +0100 Subject: [PATCH 35/66] Add AbstractVariationalGaussian class. --- gpjax/variational_families.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e18bafe1..fd8a16e7 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -43,23 +43,26 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Predict the GP's output given the input.""" raise NotImplementedError +@dataclass +class AbstractVariationalGaussian(AbstractVariationalFamily): + """The variational Gaussian family of probability distributions.""" + prior: Prior + inducing_inputs: Array + name: str = "Gaussian" + jitter: Optional[float] = DEFAULT_JITTER @dataclass -class VariationalGaussian(AbstractVariationalFamily): +class VariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions. The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing inputs z and the distribution over the inducing inputs is q(u) = N(μ, S). We parameterise this over μ and sqrt with S = sqrt sqrtᵀ. """ - - prior: Prior - inducing_inputs: Float[Array, "M D"] - name: str = "Variational Gaussian" - variational_mean: Optional[Float[Array, "M Q"]] = None - variational_root_covariance: Optional[Float[Array, "M M"]] = None + variational_mean: Optional[Array] = None + variational_root_covariance: Optional[Array] = None diag: Optional[bool] = False - jitter: Optional[float] = DEFAULT_JITTER + def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -259,14 +262,11 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: @dataclass -class NaturalVariationalGaussian(AbstractVariationalFamily): +class NaturalVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions.""" - prior: Prior - inducing_inputs: Array name: str = "Natural Gaussian" natural_vector: Optional[Array] = None natural_matrix: Optional[Array] = None - jitter: Optional[float] = DEFAULT_JITTER def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -396,14 +396,11 @@ def predict_fn(test_inputs: Array) -> dx.Distribution: @dataclass -class ExpectationVariationalGaussian(AbstractVariationalFamily): +class ExpectationVariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions.""" - prior: Prior - inducing_inputs: Array name: str = "Expectation Gaussian" expectation_vector: Optional[Array] = None expectation_matrix: Optional[Array] = None - jitter: Optional[float] = DEFAULT_JITTER def __post_init__(self): """Initialise the variational Gaussian distribution.""" From 642cc73a8c82685cbe81a3f4e264e78444948381 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 15 Jun 2022 17:24:55 +0100 Subject: [PATCH 36/66] Update test_variational_families.py --- tests/test_variational_families.py | 41 +++++++++++++----------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index b2856259..44e98360 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -56,21 +56,16 @@ def test_variational_gaussian(diag, n_inducing, n_test, whiten): params = variational_family._initialise_params(jr.PRNGKey(123)) assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() - assert "variational_mean" in params["variational_family"].keys() - assert "variational_root_covariance" in params["variational_family"].keys() + assert "variational_mean" in params["variational_family"]["moments"].keys() + assert "variational_root_covariance" in params["variational_family"]["moments"].keys() assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["variational_mean"].shape == (n_inducing, 1) - assert params["variational_family"]["variational_root_covariance"].shape == ( - n_inducing, - n_inducing, - ) + assert params["variational_family"]["moments"]["variational_mean"].shape == (n_inducing, 1) + assert params["variational_family"]["moments"]["variational_root_covariance"].shape == (n_inducing, n_inducing) assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["variational_mean"], jnp.DeviceArray) - assert isinstance( - params["variational_family"]["variational_root_covariance"], jnp.DeviceArray - ) + assert isinstance(params["variational_family"]["moments"]["variational_mean"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["moments"]["variational_root_covariance"], jnp.DeviceArray) params = gpx.config.get_defaults() assert "variational_root_covariance" in params["transformations"].keys() @@ -200,16 +195,16 @@ def test_natural_variational_gaussian(n_inducing, n_test): params = variational_family.params assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() - assert "natural_vector" in params["variational_family"].keys() - assert "natural_matrix" in params["variational_family"].keys() + assert "natural_vector" in params["variational_family"]["moments"].keys() + assert "natural_matrix" in params["variational_family"]["moments"].keys() assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["natural_vector"].shape == (n_inducing, 1) - assert params["variational_family"]["natural_matrix"].shape == (n_inducing, n_inducing) + assert params["variational_family"]["moments"]["natural_vector"].shape == (n_inducing, 1) + assert params["variational_family"]["moments"]["natural_matrix"].shape == (n_inducing, n_inducing) assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["natural_vector"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["natural_matrix"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["moments"]["natural_vector"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["moments"]["natural_matrix"], jnp.DeviceArray) params = gpx.config.get_defaults() assert "natural_vector" in params["transformations"].keys() @@ -277,16 +272,16 @@ def test_expectation_variational_gaussian(n_inducing, n_test): params = variational_family.params assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() - assert "expectation_vector" in params["variational_family"].keys() - assert "expectation_matrix" in params["variational_family"].keys() + assert "expectation_vector" in params["variational_family"]["moments"].keys() + assert "expectation_matrix" in params["variational_family"]["moments"].keys() assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["expectation_vector"].shape == (n_inducing, 1) - assert params["variational_family"]["expectation_matrix"].shape == (n_inducing, n_inducing) + assert params["variational_family"]["moments"]["expectation_vector"].shape == (n_inducing, 1) + assert params["variational_family"]["moments"]["expectation_matrix"].shape == (n_inducing, n_inducing) assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["expectation_vector"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["expectation_matrix"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["moments"]["expectation_vector"], jnp.DeviceArray) + assert isinstance(params["variational_family"]["moments"]["expectation_matrix"], jnp.DeviceArray) params = gpx.config.get_defaults() assert "expectation_vector" in params["transformations"].keys() From 347be8b37e4d8f3fcd3dcfdbe819cec4d86fb7b2 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 15 Jun 2022 17:42:26 +0100 Subject: [PATCH 37/66] Update natural_gradients.py --- gpjax/natural_gradients.py | 88 ++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index ef688958..6b5619e3 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,8 +1,9 @@ -from cmath import exp import typing as tp import jax.numpy as jnp import jax.scipy as jsp from jax import jacobian +import distrax as dx +from jax import lax from .config import get_defaults from .variational_families import AbstractVariationalFamily, ExpectationVariationalGaussian @@ -14,21 +15,19 @@ DEFAULT_JITTER = get_defaults()["jitter"] -# CURRENTLY THIS FILE IS A FIRST SKETCH OF NATURAL GRADIENTS in GPJax. -# Below is correct, but it might be better to pass in params (i.e., all svgp params) and return a dictionary that gives svgp params -def natural_to_expectation(natural_params: dict, jitter: float = DEFAULT_JITTER): +def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER): """ Converts natural parameters to expectation parameters. Args: - natural_params: A dictionary of natural parameters. - jitter: A small value to prevent numerical instability. + natural_moments: A dictionary of natural parameters. + jitter (float): A small value to prevent numerical instability. Returns: - A dictionary of expectation parameters. + tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ - natural_matrix = natural_params["natural_matrix"] - natural_vector = natural_params["natural_vector"] + natural_matrix = natural_moments["natural_matrix"] + natural_vector = natural_moments["natural_vector"] m = natural_vector.shape[0] # S⁻¹ = -2θ₂ @@ -56,63 +55,76 @@ def natural_to_expectation(natural_params: dict, jitter: float = DEFAULT_JITTER) return {"expectation_vector": expectation_vector, "expectation_matrix": expectation_matrix} -# This is a function that you create before training. This can be used to get the elbo for the nexpectation parameterisation. -# Here it is assumed that the parameters have already been transformed prior to being passed to the returned function. -def get_expectation_elbo(posterior: AbstractPosterior, +def _expectation_elbo(posterior: AbstractPosterior, variational_family: AbstractVariationalFamily, train_data: Dataset, ): """ - Computes evidence lower bound (ELBO) for the expectation parameterisation. + Construct evidence lower bound (ELBO) for varational Gaussian under the expectation parameterisation. Args: posterior: An instance of AbstractPosterior. variational_family: An instance of AbstractVariationalFamily. Returns: Callable: A function that computes ELBO. """ - q = variational_family - expectaction_q = ExpectationVariationalGaussian(prior=q.prior, inducing_inputs = q.inducing_inputs) - svgp = StochasticVI(posterior=posterior, variational_family=expectaction_q) - transformations = build_identity(svgp.params) - - return svgp.elbo(train_data, transformations) - - -def natural_gradients(params: dict, - transformations: dict, - expectation_elbo: tp.Callable, - nat_to_xi: tp.Callable, - xi_to_nat: tp.Callable, - batch, + evg = ExpectationVariationalGaussian(prior= variational_family.prior, + inducing_inputs = variational_family.inducing_inputs, + ) + svgp = StochasticVI(posterior=posterior, variational_family=evg) + + return svgp.elbo(train_data, build_identity(svgp.params)) + + +def natural_gradients( + posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + train_data: Dataset, + params: dict, + transformations: dict, + nat_to_moments: dx.Bijector, + batch, ) -> dict: """ Computes natural gradients for a variational family. Args: - params (tp.Dict): A dictionary of parameters. - variational_family: A variational family. - nat_to_xi: A function that converts natural parameters to variational parameters xi. - xi_to_nat: A function that converts variational parameters xi to natural parameters. - transformations (tp.Dict): A dictionary of transformations. + posterior (AbstractPosterior): An instance of AbstractPosterior. + variational_family(AbstractVariationalFamily): An instance of AbstractVariationalFamily. + train_data (Dataset): Training Dataset. + params (tp.Dict): A dictionary of model parameters. + transformations (tp.Dict): A dictionary of parameter transformations. + nat_to_moments (dx.Bijector): A bijector between natural and the chosen parameterisations of the Gaussian variational moments. Returns: tp.Dict: Dictionary of natural gradients. """ # Transform the parameters. params = transform(params, transformations) - # Need to stop gradients for hyperparameters. + # Get moments and stop gradients for non-moment parameters. + moments = params["variational_family"]["moments"] - natural_params = xi_to_nat(params) + other_var_params = {k:v for k,v in params["variational_family"].items() if k!="moments"} + other_params = lax.stop_gradient({**{k:v for k,v in params.items() if k!="variational_family"}, **other_var_params}) + + # Convert moments to natural parameterisation. + natural_moments = nat_to_moments.inverse(moments) # Gradient function ∂ξ/∂θ: - dxi_dnat = jacobian(nat_to_xi)(natural_params) + dxi_dnat = jacobian(nat_to_moments.forward)(natural_moments) + + # Convert natural moments to expectation moments. + expectation_moments = natural_to_expectation(natural_moments) + + # Create dictionary of all parameters for the ELBO under the expectation parameterisation. + expectation_params = other_params + expectation_params["variational_family"]["moments"] = expectation_moments - expectation_params = natural_to_expectation(natural_params) - expectation_elbo = expectation_elbo(expectation_params, batch) + # Compute ELBO. + expectation_elbo = _expectation_elbo(posterior, variational_family, train_data)(expectation_params, batch) # Compute gradient ∂L/∂η: dL_dnat = jacobian(expectation_elbo)(expectation_params) # Compute natural gradient: - nat_grads = jnp.matmul(dxi_dnat, dL_dnat.T) #<- Some pytree operations are needed here. + nat_grads = jnp.matmul(dxi_dnat, dL_dnat.T) #<---- PSUEDO CODE - TO DO - Pytree operations needed here. return nat_grads From 37bedd6f543df13eb000f52649dd29c9876503c7 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 15 Jun 2022 17:48:41 +0100 Subject: [PATCH 38/66] Update natural_gradients.py --- gpjax/natural_gradients.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 6b5619e3..b391b46c 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -102,6 +102,7 @@ def natural_gradients( # Get moments and stop gradients for non-moment parameters. moments = params["variational_family"]["moments"] + # TO DO -> CAN WE WRITE BELOW AS A ONE LINER? other_var_params = {k:v for k,v in params["variational_family"].items() if k!="moments"} other_params = lax.stop_gradient({**{k:v for k,v in params.items() if k!="variational_family"}, **other_var_params}) From fa6ca2b151bd1baf9b6686c6548d46a007abe504 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 15 Jun 2022 19:11:38 +0100 Subject: [PATCH 39/66] Update. --- gpjax/natural_gradients.py | 21 +++++++++++++-------- gpjax/parameters.py | 2 +- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index b391b46c..3685dfd5 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -4,6 +4,7 @@ from jax import jacobian import distrax as dx from jax import lax +import jax from .config import get_defaults from .variational_families import AbstractVariationalFamily, ExpectationVariationalGaussian @@ -75,15 +76,17 @@ def _expectation_elbo(posterior: AbstractPosterior, return svgp.elbo(train_data, build_identity(svgp.params)) + +#DOES NOT WORK YET. def natural_gradients( posterior: AbstractPosterior, variational_family: AbstractVariationalFamily, train_data: Dataset, params: dict, transformations: dict, - nat_to_moments: dx.Bijector, batch, -) -> dict: + nat_to_moments: tp.Optional[dx.Bijector] = Identity, + ) -> dict: """ Computes natural gradients for a variational family. Args: @@ -104,8 +107,9 @@ def natural_gradients( # TO DO -> CAN WE WRITE BELOW AS A ONE LINER? other_var_params = {k:v for k,v in params["variational_family"].items() if k!="moments"} - other_params = lax.stop_gradient({**{k:v for k,v in params.items() if k!="variational_family"}, **other_var_params}) - + other_params = lax.stop_gradient({**{k:v for k,v in params.items() if k!="variational_family"}, **{"variational_family": other_var_params}}) + + # Convert moments to natural parameterisation. natural_moments = nat_to_moments.inverse(moments) @@ -120,12 +124,13 @@ def natural_gradients( expectation_params["variational_family"]["moments"] = expectation_moments # Compute ELBO. - expectation_elbo = _expectation_elbo(posterior, variational_family, train_data)(expectation_params, batch) + expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) # Compute gradient ∂L/∂η: - dL_dnat = jacobian(expectation_elbo)(expectation_params) + dL_dnat = jacobian(expectation_elbo)(expectation_params, batch) # Compute natural gradient: - nat_grads = jnp.matmul(dxi_dnat, dL_dnat.T) #<---- PSUEDO CODE - TO DO - Pytree operations needed here. - return nat_grads + nat_grads = jax.tree_multimap(lambda x, y: jnp.matmul(x.T, y), dxi_dnat, dL_dnat) + + return nat_grads \ No newline at end of file diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 4ade47e9..0b6243b9 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -14,7 +14,7 @@ from .types import PRNGKeyType from .utils import merge_dictionaries -Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) +Identity = dx.Lambda(forward = lambda x: x, inverse = lambda x: x) ################################ From 02f068b941117111d54ad2e6e8f0a7100c8bef36 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Thu, 16 Jun 2022 15:07:04 +0100 Subject: [PATCH 40/66] Minimal working natural gradient functions for NAT PARAMETERISATION. --- gpjax/natural_gradients.py | 185 +++++++++++++++++++++++++++---------- gpjax/parameters.py | 19 +++- 2 files changed, 154 insertions(+), 50 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 3685dfd5..0fadf437 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,23 +1,23 @@ +from copy import deepcopy +from multiprocessing.dummy import Array import typing as tp import jax.numpy as jnp import jax.scipy as jsp -from jax import jacobian import distrax as dx -from jax import lax -import jax +from jax import lax, value_and_grad from .config import get_defaults -from .variational_families import AbstractVariationalFamily, ExpectationVariationalGaussian +from .variational_families import AbstractVariationalFamily, ExpectationVariationalGaussian, NaturalVariationalGaussian from .variational_inference import StochasticVI from .utils import I from .gps import AbstractPosterior from .types import Dataset -from .parameters import build_identity, transform +from .parameters import Identity, build_identity, transform, build_trainables_false, build_trainables_true, trainable_params DEFAULT_JITTER = get_defaults()["jitter"] -def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER): +def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER) -> dict: """ Converts natural parameters to expectation parameters. Args: @@ -59,7 +59,7 @@ def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER def _expectation_elbo(posterior: AbstractPosterior, variational_family: AbstractVariationalFamily, train_data: Dataset, - ): + ) -> tp.Callable[[dict, Dataset], float]: """ Construct evidence lower bound (ELBO) for varational Gaussian under the expectation parameterisation. Args: @@ -76,61 +76,150 @@ def _expectation_elbo(posterior: AbstractPosterior, return svgp.elbo(train_data, build_identity(svgp.params)) +def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: + """ + Stops gradients for non-moment parameters. + Args: + params: A dictionary of parameters. + Returns: + tp.Dict: A dictionary of parameters with stopped gradients. + """ + trainables = build_trainables_false(params) + moment_trainables = build_trainables_true(params["variational_family"]["moments"]) + trainables["variational_family"]["moments"] = moment_trainables + params = trainable_params(params, trainables) + return params + +def _stop_gradients_moments(params: tp.Dict) -> tp.Dict: + """ + Stops gradients for moment parameters. + Args: + params: A dictionary of parameters. + Returns: + tp.Dict: A dictionary of parameters with stopped gradients. + """ + trainables = build_trainables_true(params) + moment_trainables = build_trainables_false(params["variational_family"]["moments"]) + trainables["variational_family"]["moments"] = moment_trainables + params = trainable_params(params, trainables) + return params + -#DOES NOT WORK YET. def natural_gradients( - posterior: AbstractPosterior, - variational_family: AbstractVariationalFamily, + stochastic_vi: StochasticVI, train_data: Dataset, - params: dict, transformations: dict, - batch, - nat_to_moments: tp.Optional[dx.Bijector] = Identity, - ) -> dict: + #bijector = tp.Optional[dx.Bijector] = Identity, #bijector: A bijector to convert between the user chosen parameterisation and the natural parameters. + ) -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]: """ - Computes natural gradients for a variational family. + Computes natural gradients for variational Gaussian. Args: - posterior (AbstractPosterior): An instance of AbstractPosterior. - variational_family(AbstractVariationalFamily): An instance of AbstractVariationalFamily. - train_data (Dataset): Training Dataset. - params (tp.Dict): A dictionary of model parameters. - transformations (tp.Dict): A dictionary of parameter transformations. - nat_to_moments (dx.Bijector): A bijector between natural and the chosen parameterisations of the Gaussian variational moments. + posterior: An instance of AbstractPosterior. + variational_family: An instance of AbstractVariationalFamily. + train_data: A Dataset. + transformations: A dictionary of transformations. Returns: - tp.Dict: Dictionary of natural gradients. + Tuple[tp.Callable[[dict, Dataset], dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. """ - # Transform the parameters. - params = transform(params, transformations) + posterior = stochastic_vi.posterior + variational_family = stochastic_vi.variational_family - # Get moments and stop gradients for non-moment parameters. - moments = params["variational_family"]["moments"] + # The ELBO under the user chosen parameterisation xi. + xi_elbo = stochastic_vi.elbo(train_data, transformations) + + # The ELBO under the expectation parameterisation, L(η). + expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) - # TO DO -> CAN WE WRITE BELOW AS A ONE LINER? - other_var_params = {k:v for k,v in params["variational_family"].items() if k!="moments"} - other_params = lax.stop_gradient({**{k:v for k,v in params.items() if k!="variational_family"}, **{"variational_family": other_var_params}}) + if isinstance(variational_family, NaturalVariationalGaussian): + def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + """ + Computes the natural gradients of the ELBO. + """ + # Transform parameters to constrained space. + params = transform(params, transformations) + + # Get natural moments θ. + natural_moments = params["variational_family"]["moments"] + # Get expectation moments η. + expectation_moments = natural_to_expectation(natural_moments) - # Convert moments to natural parameterisation. - natural_moments = nat_to_moments.inverse(moments) + # Full params with expectation moments. + expectation_params = deepcopy(params) + expectation_params["variational_family"]["moments"] = expectation_moments - # Gradient function ∂ξ/∂θ: - dxi_dnat = jacobian(nat_to_moments.forward)(natural_moments) + # Compute gradient ∂L/∂η: + def loss_fn(params: dict, batch: Dataset) -> Array: + # Determine hyperparameters that should be trained. + trainables["variational_family"]["moments"] = build_trainables_true(params["variational_family"]["moments"]) + params = trainable_params(params, trainables) + + # Stop gradients for non-moment parameters. + params = _stop_gradients_nonmoments(params) - # Convert natural moments to expectation moments. - expectation_moments = natural_to_expectation(natural_moments) + return expectation_elbo(params, batch) - # Create dictionary of all parameters for the ELBO under the expectation parameterisation. - expectation_params = other_params - expectation_params["variational_family"]["moments"] = expectation_moments + value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch) - # Compute ELBO. - expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) + return value, dL_dnat - # Compute gradient ∂L/∂η: - dL_dnat = jacobian(expectation_elbo)(expectation_params, batch) - - # Compute natural gradient: - - nat_grads = jax.tree_multimap(lambda x, y: jnp.matmul(x.T, y), dxi_dnat, dL_dnat) - - return nat_grads \ No newline at end of file + else: + #To Do: (DD) add general parameterisation case. + raise NotImplementedError + + # BELOW is (almost working) PSUEDO CODE of what this will look like. + + # def nat_grads_fn(params: dict, batch: Dataset) -> dict: + # # Transform parameters to constrained space. + # params = transform(params, transformations) + + # # Stop gradients for non-moment parameters. + # params = _stop_gradients_nonmoments(params) + + # # Get natural moments θ. + # natural_moments = bijector.inverse(params["variational_family"]["moments"]) + + # # Get expectation moments η. + # expectation_moments = natural_to_expectation(natural_moments) + + # # Gradient function ∂ξ/∂θ: + # #### NEED TO STOP GRADIENTS FOR NON MOMENTS HERE!#### + # dxi_dnat = jacobian(nat_to_moments.forward)(natural_moments) + + # # Full params with expectation moments. + # expectation_params = deepcopy(params) + # expectation_params["variational_family"]["moments"] = expectation_moments + + # # Compute gradient ∂L/∂η: + # def loss_fn(params: dict, batch: Dataset) -> Array: + # # Determine hyperparameters that should be trained. + # params = trainable_params(params, trainables) + + # # Stop gradients for non-moment parameters. + # params = _stop_gradients_nonmoments(params) + + # return expectation_elbo(expectation_params, batch) + + # value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch) + + # # ∂ξ/∂θ ∂L/∂η + # nat_grads = jax.tree_multimap(lambda x, y: jnp.matmul(x.T, y), dxi_dnat, dL_dnat) + + # return value, nat_grads + + def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + + def loss_fn(params: dict, batch: Dataset) -> Array: + # Determine hyperparameters that should be trained. + params = trainable_params(params, trainables) + + # Stop gradients for the moment parameters. + params = _stop_gradients_moments(params) + + return xi_elbo(params, batch) + + value, dL_dhyper = value_and_grad(loss_fn)(params, batch) + + return value, dL_dhyper + + return nat_grads_fn, hyper_grads_fn \ No newline at end of file diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 0b6243b9..f09e2479 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -51,7 +51,6 @@ def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: ) return state - def _validate_kwargs(kwargs, params): for k, v in kwargs.items(): if k not in params.keys(): @@ -262,7 +261,7 @@ def prior_checks(priors: dict) -> dict: return priors -def build_trainables(params: tp.Dict) -> tp.Dict: +def build_trainables_true(params: tp.Dict) -> tp.Dict: """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is trainable. Args: @@ -278,6 +277,22 @@ def build_trainables(params: tp.Dict) -> tp.Dict: return prior_container +def build_trainables_false(params: tp.Dict) -> tp.Dict: + """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is NOT trainable. + + Args: + params (tp.Dict): The parameter set for which trainable statuses should be derived from. + + Returns: + tp.Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. + """ + # Copy dictionary structure + prior_container = deepcopy(params) + # Set all values to zero + prior_container = jax.tree_map(lambda _: False, prior_container) + return prior_container + + def stop_grad(param: tp.Dict, trainable: tp.Dict): """When taking a gradient, we want to stop the gradient from flowing through a parameter if it is not trainable. This is achieved using the model's dictionary of parameters and the corresponding trainability status.""" return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) From 612f2c700cb9378c906e8b0924e69b0855608913 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Thu, 16 Jun 2022 17:31:13 +0100 Subject: [PATCH 41/66] Add tests. --- gpjax/config.py | 3 +- gpjax/natural_gradients.py | 27 +++++-- tests/test_natural_gradients.py | 136 ++++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 tests/test_natural_gradients.py diff --git a/gpjax/config.py b/gpjax/config.py index 13695d9e..c330a3da 100644 --- a/gpjax/config.py +++ b/gpjax/config.py @@ -5,7 +5,6 @@ __config = None -Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) Softplus = dx.Lambda( forward=lambda x: jnp.log(1 + jnp.exp(x)), inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), @@ -31,6 +30,8 @@ # logdet = 1 / (1 - jnp.exp(-y)) # return x, logdet +Identity = dx.Lambda(forward = lambda x: x, inverse = lambda x: x) + def get_defaults() -> ConfigDict: """Construct and globally register the config file used within GPJax. diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 0fadf437..a233f02f 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,5 +1,4 @@ from copy import deepcopy -from multiprocessing.dummy import Array import typing as tp import jax.numpy as jnp import jax.scipy as jsp @@ -11,8 +10,8 @@ from .variational_inference import StochasticVI from .utils import I from .gps import AbstractPosterior -from .types import Dataset -from .parameters import Identity, build_identity, transform, build_trainables_false, build_trainables_true, trainable_params +from .types import Dataset, Array +from .parameters import build_identity, transform, build_trainables_false, build_trainables_true, trainable_params DEFAULT_JITTER = get_defaults()["jitter"] @@ -134,6 +133,12 @@ def natural_gradients( def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: """ Computes the natural gradients of the ELBO. + Args: + params: A dictionary of parameters. + trainables: A dictionary of trainables. + batch: A Dataset. + Returns: + dict: A dictionary of natural gradients. """ # Transform parameters to constrained space. params = transform(params, transformations) @@ -151,8 +156,9 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: # Compute gradient ∂L/∂η: def loss_fn(params: dict, batch: Dataset) -> Array: # Determine hyperparameters that should be trained. - trainables["variational_family"]["moments"] = build_trainables_true(params["variational_family"]["moments"]) - params = trainable_params(params, trainables) + trains = deepcopy(trainables) + trains["variational_family"]["moments"] = build_trainables_true(params["variational_family"]["moments"]) + params = trainable_params(params, trains) # Stop gradients for non-moment parameters. params = _stop_gradients_nonmoments(params) @@ -208,6 +214,17 @@ def loss_fn(params: dict, batch: Dataset) -> Array: # return value, nat_grads def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + """ + Computes the hyperparameter gradients of the ELBO. + Args: + params: A dictionary of parameters. + trainables: A dictionary of trainables. + batch: A Dataset. + Returns: + dict: A dictionary of hyperparameter gradients. + """ + # Transform parameters to constrained space. + params = transform(params, transformations) def loss_fn(params: dict, batch: Dataset) -> Array: # Determine hyperparameters that should be trained. diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py new file mode 100644 index 00000000..4fb98d8b --- /dev/null +++ b/tests/test_natural_gradients.py @@ -0,0 +1,136 @@ +import pytest +import jax +import jax.numpy as jnp +from gpjax.natural_gradients import natural_to_expectation, _stop_gradients_nonmoments, _stop_gradients_moments, _expectation_elbo, natural_gradients +import typing as tp +import gpjax as gpx +import jax.random as jr +from gpjax.parameters import recursive_items +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt +from jax import jit +import optax as ox + +import gpjax as gpx +import tensorflow as tf + +tf.random.set_seed(42) +key = jr.PRNGKey(123) + +@pytest.mark.parametrize("dim", [1, 2, 3]) +def test_natural_to_expectation(dim): + """ + Converts natural parameters to expectation parameters. + Args: + natural_moments: A dictionary of natural parameters. + jitter (float): A small value to prevent numerical instability. + Returns: + tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. + """ + + natural_matrix = -.5 * jnp.eye(dim) + natural_vector = jnp.zeros((dim, 1)) + + natural_moments = {"natural_matrix": natural_matrix, "natural_vector": natural_vector} + + expectation_moments = natural_to_expectation(natural_moments, jitter=1e-6) + + assert "expectation_vector" in expectation_moments.keys() + assert "expectation_matrix" in expectation_moments.keys() + assert expectation_moments["expectation_vector"].shape == natural_moments["natural_vector"].shape + assert expectation_moments["expectation_matrix"].shape == natural_moments["natural_matrix"].shape + + +def get_data_and_gp(n_datapoints): + x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) + y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 + D = gpx.Dataset(X=x, y=y) + + p = gpx.Prior(kernel=gpx.RBF()) + lik = gpx.Gaussian(num_datapoints=n_datapoints) + post = p * lik + return D, post, p + +@pytest.mark.parametrize("jit_fns", [True, False]) +def test_expectation_elbo(jit_fns): + """ + Tests the expectation ELBO. + """ + D, posterior, prior = get_data_and_gp(10) + + z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) + variational_family = gpx.variational_families.ExpectationVariationalGaussian(prior = prior, inducing_inputs=z) + + svgp = gpx.StochasticVI(posterior=posterior, variational_family=variational_family) + + params, _, constrainer, unconstrainer = gpx.initialise(svgp) + + expectation_elbo = _expectation_elbo(posterior, variational_family, D) + + if jit_fns: + elbo_fn = jax.jit(expectation_elbo) + else: + elbo_fn = expectation_elbo + + assert isinstance(elbo_fn, tp.Callable) + elbo_value = elbo_fn(params, D) + assert isinstance(elbo_value, jnp.ndarray) + + # Test gradients + grads = jax.grad(elbo_fn, argnums=0)(params, D) + assert isinstance(grads, tp.Dict) + assert len(grads) == len(params) + + + +# def test_stop_gradients_nonmoments(): +# pass + + +# def test_stop_gradients_moments(): +# pass + +def test_natural_gradients(): + """ + Tests the expectation ELBO. + """ + D, p, prior = get_data_and_gp(10) + + z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) + + Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1) + + likelihood = gpx.Gaussian(num_datapoints=D.n) + prior = gpx.Prior(kernel=gpx.RBF()) + q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) + + svgp = gpx.StochasticVI(posterior=p, variational_family=q) + + params, trainables, constrainers, unconstrainers = gpx.initialise(svgp) + params = gpx.transform(params, unconstrainers) + + batcher = Dbatched.get_batcher() + batch = batcher() + + nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers) + + assert isinstance(nat_grads_fn, tp.Callable) + assert isinstance(hyper_grads_fn, tp.Callable) + + val, nat_grads = nat_grads_fn(params, trainables, batch) + val, hyper_grads = hyper_grads_fn(params, trainables, batch) + + assert isinstance(val, jnp.ndarray) + assert isinstance(nat_grads, tp.Dict) + assert isinstance(hyper_grads, tp.Dict) + + # Need to check moments are zero in hyper_grads: + assert jnp.array([ (v == 0.).all() for v in hyper_grads["variational_family"]["moments"].values()]).all() + + # Check non-moments are zero in nat_grads: + d = jax.tree_map(lambda x: (x==0.).all(), nat_grads) + d["variational_family"]["moments"] = True + + assert jnp.array([v1 == True for k, v1, v2 in recursive_items(d, d)]).all() + From 25a78933d898f80435c20fe649cae8cb31a41322 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Thu, 16 Jun 2022 17:39:09 +0100 Subject: [PATCH 42/66] Update natural_gradients.py --- gpjax/natural_gradients.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index a233f02f..616bbf53 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,4 +1,5 @@ from copy import deepcopy +from this import d import typing as tp import jax.numpy as jnp import jax.scipy as jsp @@ -167,7 +168,14 @@ def loss_fn(params: dict, batch: Dataset) -> Array: value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch) - return value, dL_dnat + + + # This is a renaming of the gradient components to match the natural parameterisation pytree. + nat_grad = dL_dnat + nat_grad["variational_family"]["moments"] = {"natural_vector": dL_dnat["variational_family"]["moments"]["expectation_vector"], + "natural_matrix": dL_dnat["variational_family"]["moments"]["expectation_matrix"]} + + return value, nat_grad else: #To Do: (DD) add general parameterisation case. From a5499d0f242112750867175441bd627fb00f0253 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Fri, 17 Jun 2022 15:28:15 +0100 Subject: [PATCH 43/66] Update natural_gradients.py Fix negative, and add sketch optimisation loop. --- gpjax/natural_gradients.py | 61 +++++++++++++++++++++++++++++++++++--- 1 file changed, 57 insertions(+), 4 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 616bbf53..a830da64 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -50,7 +50,7 @@ def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER # η₁ = μ expectation_vector = mu - # η₂ = S + η₁ η₁ᵀ + # η₂ = S + μ μᵀ expectation_matrix = S + jnp.matmul(mu, mu.T) return {"expectation_vector": expectation_vector, "expectation_matrix": expectation_matrix} @@ -73,7 +73,7 @@ def _expectation_elbo(posterior: AbstractPosterior, ) svgp = StochasticVI(posterior=posterior, variational_family=evg) - return svgp.elbo(train_data, build_identity(svgp.params)) + return svgp.elbo(train_data, build_identity(svgp.params), negative=True) def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: @@ -125,7 +125,7 @@ def natural_gradients( variational_family = stochastic_vi.variational_family # The ELBO under the user chosen parameterisation xi. - xi_elbo = stochastic_vi.elbo(train_data, transformations) + xi_elbo = stochastic_vi.elbo(train_data, transformations, negative=True) # The ELBO under the expectation parameterisation, L(η). expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) @@ -247,4 +247,57 @@ def loss_fn(params: dict, batch: Dataset) -> Array: return value, dL_dhyper - return nat_grads_fn, hyper_grads_fn \ No newline at end of file + return nat_grads_fn, hyper_grads_fn + + +from gpjax.abstractions import progress_bar_scan +import optax as ox +import jax + +adam = ox.adam(1e-3) +sgd = ox.sgd(1.) + +def fit_natgrads( + stochastic_vi: StochasticVI, + params: tp.Dict, + trainables: tp.Dict, + transformations: tp.Dict, + train_data: Dataset, + moment_opt = ox.sgd(1e-3), + hyper_opt = ox.adam(1e-3), + n_iters: tp.Optional[int] = 100, + log_rate: tp.Optional[int] = 10, +) -> tp.Dict: + + hyper_state = hyper_opt.init(params) + moment_state = moment_opt.init(params) + + nat_grads_fn, hyper_grads_fn = natural_gradients(stochastic_vi, train_data, transformations) + + next_batch = train_data.get_batcher() + + @progress_bar_scan(n_iters, log_rate) + def step(params_opt_state, i): + params, moment_state, hyper_state = params_opt_state + batch = next_batch() + + # Natural gradients update: + loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) + updates, moment_state = moment_opt.update(loss_gradient, moment_state, params) + params = ox.apply_updates(params, updates) + + + # Hyper-parameters update: + loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) + updates, hyper_state = hyper_opt.update(loss_gradient, hyper_state, params) + params = ox.apply_updates(params, updates) + + + params_opt_state = params, moment_state, hyper_state + + + return params_opt_state, loss_val + + (params, _, _), _ = jax.lax.scan(step, (params, moment_state, hyper_state), jnp.arange(n_iters)) + + return params \ No newline at end of file From 820c7527b42a2f13429145c0a10047aed86fa749 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Fri, 17 Jun 2022 18:44:16 +0100 Subject: [PATCH 44/66] Add rough notebook. This is rough. THERE IS A BUG SOMEWHERE. This does not use the training abstraction in natural_gradients.py --- examples/natgrads.ipynb | 306 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 306 insertions(+) create mode 100644 examples/natgrads.ipynb diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb new file mode 100644 index 00000000..2b2df0c2 --- /dev/null +++ b/examples/natgrads.ipynb @@ -0,0 +1,306 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "143ac6b9", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "98f89228", + "metadata": {}, + "source": [ + "# Natural Gradients:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10376231", + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "from jax import jit, lax\n", + "import optax as ox\n", + "\n", + "import gpjax as gpx\n", + "from gpjax.natural_gradients import natural_gradients\n", + "from gpjax.abstractions import progress_bar_scan\n", + "\n", + "#Set seed for reproducibility:\n", + "import tensorflow as tf\n", + "tf.random.set_seed(42)\n", + "key = jr.PRNGKey(123)" + ] + }, + { + "cell_type": "markdown", + "id": "3b851a25", + "metadata": {}, + "source": [ + "# Dataset:" + ] + }, + { + "cell_type": "markdown", + "id": "6f7facf2", + "metadata": {}, + "source": [ + "Generate dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39d6c8e6", + "metadata": {}, + "outputs": [], + "source": [ + "n = 5000\n", + "noise = 0.2\n", + "\n", + "x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)\n", + "f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)\n", + "signal = f(x)\n", + "y = signal + jr.normal(key, shape=signal.shape) * noise\n", + "\n", + "D = gpx.Dataset(X=x, y=y)\n", + "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1)\n", + "\n", + "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" + ] + }, + { + "cell_type": "markdown", + "id": "af57fb31", + "metadata": {}, + "source": [ + "Intialise inducing points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf6533b6", + "metadata": {}, + "outputs": [], + "source": [ + "z = jnp.linspace(-5.0, 5.0, 100).reshape(-1, 1)\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", + "ax.plot(x, y, \"o\", alpha=0.3)\n", + "ax.plot(xtest, f(xtest))\n", + "[ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1) for z_i in z]\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "664c204b", + "metadata": {}, + "source": [ + "# Model and variational inference strategy:" + ] + }, + { + "cell_type": "markdown", + "id": "ce4de494", + "metadata": {}, + "source": [ + "Define model, variational family and variational inference strategy:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2284fbb2", + "metadata": {}, + "outputs": [], + "source": [ + "likelihood = gpx.Gaussian(num_datapoints=n)\n", + "kernel = gpx.RBF()\n", + "prior = gpx.Prior(kernel=kernel)\n", + "p = prior * likelihood\n", + "\n", + "\n", + "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", + "svgp = gpx.StochasticVI(posterior=p, variational_family=q)" + ] + }, + { + "cell_type": "markdown", + "id": "55e697ec", + "metadata": {}, + "source": [ + "Get default parameters and transform these to the uncontrained space:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fe96023", + "metadata": {}, + "outputs": [], + "source": [ + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "\n", + "params = gpx.transform(params, unconstrainers)" + ] + }, + { + "cell_type": "markdown", + "id": "8969b14e", + "metadata": {}, + "source": [ + "# Natural gradients:" + ] + }, + { + "cell_type": "markdown", + "id": "e793c24f", + "metadata": {}, + "source": [ + "Define natural gradient and hyperparameter gradient functions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfab0cfc", + "metadata": {}, + "outputs": [], + "source": [ + "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)" + ] + }, + { + "cell_type": "markdown", + "id": "6fbf5e7a", + "metadata": {}, + "source": [ + "Run optimisation loop:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d88917f3", + "metadata": {}, + "outputs": [], + "source": [ + "#Optimisation example:\n", + "\n", + "n_iters = 10000\n", + "log_rate = 10\n", + "train_data = Dbatched\n", + "\n", + "\n", + "#Define optimisers:\n", + "adam = ox.adam(1e-3) #<- hyperparameters\n", + "sgd = ox.sgd(1e-3) #<- for natgrads\n", + " \n", + "\n", + "sgd_state = sgd.init(params)\n", + "adam_state = adam.init(params)\n", + "\n", + "next_batch = train_data.get_batcher()\n", + "\n", + "# Optimisation step:\n", + "@progress_bar_scan(n_iters, log_rate)\n", + "def step(params_opt_state, i):\n", + " params, sgd_state, adam_state = params_opt_state\n", + " batch = next_batch()\n", + " \n", + " # Natural gradients update:\n", + " loss_val, loss_gradient = nat_grads_fn(params, trainables, batch)\n", + " updates, opt_state = sgd.update(loss_gradient, sgd_state, params)\n", + " params = ox.apply_updates(params, updates)\n", + " \n", + " \n", + " # Hyperparameters update:\n", + " loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch)\n", + " updates, adam_state = adam.update(loss_gradient, adam_state, params)\n", + " params = ox.apply_updates(params, updates)\n", + " \n", + " \n", + " params_opt_state = params, sgd_state, adam_state\n", + " \n", + " \n", + " return params_opt_state, loss_val\n", + " \n", + " \n", + "# Optimisation loop:\n", + "(params, _, _), _ = lax.scan(step, (params, sgd_state, adam_state), jnp.arange(n_iters))" + ] + }, + { + "cell_type": "markdown", + "id": "fbcdd41c", + "metadata": {}, + "source": [ + "Plot results:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cff40778", + "metadata": {}, + "outputs": [], + "source": [ + "learned_params = gpx.transform(params, constrainers)\n", + "\n", + "latent_dist = q(learned_params)(xtest)\n", + "predictive_dist = likelihood(latent_dist, learned_params)\n", + "\n", + "meanf = predictive_dist.mean()\n", + "sigma = predictive_dist.stddev()\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", + "ax.plot(x, y, \"o\", alpha=0.15, label=\"Training Data\", color=\"tab:gray\")\n", + "ax.plot(xtest, meanf, label=\"Posterior mean\", color=\"tab:blue\")\n", + "ax.fill_between(xtest.flatten(), meanf - sigma, meanf + sigma, alpha=0.3)\n", + "[\n", + " ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1)\n", + " for z_i in learned_params[\"variational_family\"][\"inducing_inputs\"]\n", + "]\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.7 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "7eb1cfec58eecaa2e5422163254bd25a3275ed109df9a51c3c95d775723db6f0" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 8bf35a8d6b25f12735e7665eb35decc616851dce Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Sun, 26 Jun 2022 14:20:07 +0800 Subject: [PATCH 45/66] Minimal working example complete. --- examples/natgrads.ipynb | 113 ++++++++++++----------------- gpjax/natural_gradients.py | 143 +++++++++++++++++++++---------------- 2 files changed, 130 insertions(+), 126 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 2b2df0c2..c3bf8c04 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -74,7 +74,7 @@ "y = signal + jr.normal(key, shape=signal.shape) * noise\n", "\n", "D = gpx.Dataset(X=x, y=y)\n", - "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1)\n", + "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)\n", "\n", "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" ] @@ -94,7 +94,7 @@ "metadata": {}, "outputs": [], "source": [ - "z = jnp.linspace(-5.0, 5.0, 100).reshape(-1, 1)\n", + "z = jnp.linspace(-5.0, 5.0, 10).reshape(-1, 1)\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", "ax.plot(x, y, \"o\", alpha=0.3)\n", @@ -103,6 +103,26 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "13de5cd9", + "metadata": {}, + "outputs": [], + "source": [ + "likelihood = gpx.Gaussian(num_datapoints=n)\n", + "kernel = gpx.RBF()\n", + "prior = gpx.Prior(kernel=kernel)\n", + "p = prior * likelihood\n", + "\n", + "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", + "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "\n", + "params = gpx.transform(params, unconstrainers)\n", + "\n" + ] + }, { "cell_type": "markdown", "id": "664c204b", @@ -136,24 +156,25 @@ "svgp = gpx.StochasticVI(posterior=p, variational_family=q)" ] }, - { - "cell_type": "markdown", - "id": "55e697ec", - "metadata": {}, - "source": [ - "Get default parameters and transform these to the uncontrained space:" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "3fe96023", + "id": "5190b12d", "metadata": {}, "outputs": [], "source": [ "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "params = gpx.transform(params, unconstrainers)\n", "\n", - "params = gpx.transform(params, unconstrainers)" + "loss_fn = jit(svgp.elbo(D, constrainers, negative=True))" + ] + }, + { + "cell_type": "markdown", + "id": "55e697ec", + "metadata": {}, + "source": [ + "Get default parameters and transform these to the uncontrained space:" ] }, { @@ -179,67 +200,29 @@ "metadata": {}, "outputs": [], "source": [ + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "\n", + "params = gpx.transform(params, unconstrainers)\n", + "\n", "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)" ] }, - { - "cell_type": "markdown", - "id": "6fbf5e7a", - "metadata": {}, - "source": [ - "Run optimisation loop:" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "d88917f3", + "id": "7e9884f2", "metadata": {}, "outputs": [], "source": [ - "#Optimisation example:\n", - "\n", - "n_iters = 10000\n", - "log_rate = 10\n", - "train_data = Dbatched\n", + "learned_params = gpx.natural_gradients.fit_natgrads(svgp,\n", + " params = params,\n", + " trainables = trainables, \n", + " transformations = constrainers,\n", + " train_data = Dbatched,\n", + " n_iters = 5000\n", + ")\n", "\n", - "\n", - "#Define optimisers:\n", - "adam = ox.adam(1e-3) #<- hyperparameters\n", - "sgd = ox.sgd(1e-3) #<- for natgrads\n", - " \n", - "\n", - "sgd_state = sgd.init(params)\n", - "adam_state = adam.init(params)\n", - "\n", - "next_batch = train_data.get_batcher()\n", - "\n", - "# Optimisation step:\n", - "@progress_bar_scan(n_iters, log_rate)\n", - "def step(params_opt_state, i):\n", - " params, sgd_state, adam_state = params_opt_state\n", - " batch = next_batch()\n", - " \n", - " # Natural gradients update:\n", - " loss_val, loss_gradient = nat_grads_fn(params, trainables, batch)\n", - " updates, opt_state = sgd.update(loss_gradient, sgd_state, params)\n", - " params = ox.apply_updates(params, updates)\n", - " \n", - " \n", - " # Hyperparameters update:\n", - " loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch)\n", - " updates, adam_state = adam.update(loss_gradient, adam_state, params)\n", - " params = ox.apply_updates(params, updates)\n", - " \n", - " \n", - " params_opt_state = params, sgd_state, adam_state\n", - " \n", - " \n", - " return params_opt_state, loss_val\n", - " \n", - " \n", - "# Optimisation loop:\n", - "(params, _, _), _ = lax.scan(step, (params, sgd_state, adam_state), jnp.arange(n_iters))" + "learned_params = gpx.transform(learned_params, constrainers)" ] }, { @@ -257,8 +240,6 @@ "metadata": {}, "outputs": [], "source": [ - "learned_params = gpx.transform(params, constrainers)\n", - "\n", "latent_dist = q(learned_params)(xtest)\n", "predictive_dist = likelihood(latent_dist, learned_params)\n", "\n", @@ -279,7 +260,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.7 ('base')", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index a830da64..3209fdec 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,23 +1,35 @@ -from copy import deepcopy -from this import d import typing as tp +from copy import deepcopy + +import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -import distrax as dx from jax import lax, value_and_grad from .config import get_defaults -from .variational_families import AbstractVariationalFamily, ExpectationVariationalGaussian, NaturalVariationalGaussian -from .variational_inference import StochasticVI -from .utils import I from .gps import AbstractPosterior -from .types import Dataset, Array -from .parameters import build_identity, transform, build_trainables_false, build_trainables_true, trainable_params +from .parameters import ( + build_identity, + build_trainables_false, + build_trainables_true, + trainable_params, + transform, +) +from .types import Array, Dataset +from .utils import I +from .variational_families import ( + AbstractVariationalFamily, + ExpectationVariationalGaussian, + NaturalVariationalGaussian, +) +from .variational_inference import StochasticVI DEFAULT_JITTER = get_defaults()["jitter"] -def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER) -> dict: +def natural_to_expectation( + natural_moments: dict, jitter: float = DEFAULT_JITTER +) -> dict: """ Converts natural parameters to expectation parameters. Args: @@ -26,11 +38,11 @@ def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER Returns: tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ - + natural_matrix = natural_moments["natural_matrix"] natural_vector = natural_moments["natural_vector"] m = natural_vector.shape[0] - + # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix S_inv += I(m) * jitter @@ -43,23 +55,27 @@ def natural_to_expectation(natural_moments: dict, jitter: float = DEFAULT_JITTER # S = CᵀC S = jnp.matmul(C.T, C) - + # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) - + # η₁ = μ expectation_vector = mu - + # η₂ = S + μ μᵀ expectation_matrix = S + jnp.matmul(mu, mu.T) - - return {"expectation_vector": expectation_vector, "expectation_matrix": expectation_matrix} + + return { + "expectation_vector": expectation_vector, + "expectation_matrix": expectation_matrix, + } -def _expectation_elbo(posterior: AbstractPosterior, - variational_family: AbstractVariationalFamily, - train_data: Dataset, - ) -> tp.Callable[[dict, Dataset], float]: +def _expectation_elbo( + posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + train_data: Dataset, +) -> tp.Callable[[dict, Dataset], float]: """ Construct evidence lower bound (ELBO) for varational Gaussian under the expectation parameterisation. Args: @@ -68,9 +84,10 @@ def _expectation_elbo(posterior: AbstractPosterior, Returns: Callable: A function that computes ELBO. """ - evg = ExpectationVariationalGaussian(prior= variational_family.prior, - inducing_inputs = variational_family.inducing_inputs, - ) + evg = ExpectationVariationalGaussian( + prior=variational_family.prior, + inducing_inputs=variational_family.inducing_inputs, + ) svgp = StochasticVI(posterior=posterior, variational_family=evg) return svgp.elbo(train_data, build_identity(svgp.params), negative=True) @@ -90,6 +107,7 @@ def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: params = trainable_params(params, trainables) return params + def _stop_gradients_moments(params: tp.Dict) -> tp.Dict: """ Stops gradients for moment parameters. @@ -109,8 +127,8 @@ def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, transformations: dict, - #bijector = tp.Optional[dx.Bijector] = Identity, #bijector: A bijector to convert between the user chosen parameterisation and the natural parameters. - ) -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]: + # bijector = tp.Optional[dx.Bijector] = Identity, #bijector: A bijector to convert between the user chosen parameterisation and the natural parameters. +) -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]: """ Computes natural gradients for variational Gaussian. Args: @@ -126,24 +144,25 @@ def natural_gradients( # The ELBO under the user chosen parameterisation xi. xi_elbo = stochastic_vi.elbo(train_data, transformations, negative=True) - + # The ELBO under the expectation parameterisation, L(η). expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) if isinstance(variational_family, NaturalVariationalGaussian): + def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: """ Computes the natural gradients of the ELBO. Args: params: A dictionary of parameters. - trainables: A dictionary of trainables. + trainables: A dictionary of trainables. batch: A Dataset. Returns: dict: A dictionary of natural gradients. """ # Transform parameters to constrained space. params = transform(params, transformations) - + # Get natural moments θ. natural_moments = params["variational_family"]["moments"] @@ -158,9 +177,11 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: def loss_fn(params: dict, batch: Dataset) -> Array: # Determine hyperparameters that should be trained. trains = deepcopy(trainables) - trains["variational_family"]["moments"] = build_trainables_true(params["variational_family"]["moments"]) - params = trainable_params(params, trains) - + trains["variational_family"]["moments"] = build_trainables_true( + params["variational_family"]["moments"] + ) + params = trainable_params(params, trains) + # Stop gradients for non-moment parameters. params = _stop_gradients_nonmoments(params) @@ -168,21 +189,25 @@ def loss_fn(params: dict, batch: Dataset) -> Array: value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch) - - # This is a renaming of the gradient components to match the natural parameterisation pytree. - nat_grad = dL_dnat - nat_grad["variational_family"]["moments"] = {"natural_vector": dL_dnat["variational_family"]["moments"]["expectation_vector"], - "natural_matrix": dL_dnat["variational_family"]["moments"]["expectation_matrix"]} + nat_grad = dL_dnat + nat_grad["variational_family"]["moments"] = { + "natural_vector": dL_dnat["variational_family"]["moments"][ + "expectation_vector" + ], + "natural_matrix": dL_dnat["variational_family"]["moments"][ + "expectation_matrix" + ], + } return value, nat_grad else: - #To Do: (DD) add general parameterisation case. + # To Do: (DD) add general parameterisation case. raise NotImplementedError # BELOW is (almost working) PSUEDO CODE of what this will look like. - + # def nat_grads_fn(params: dict, batch: Dataset) -> dict: # # Transform parameters to constrained space. # params = transform(params, transformations) @@ -221,18 +246,16 @@ def loss_fn(params: dict, batch: Dataset) -> Array: # return value, nat_grads - def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: """ Computes the hyperparameter gradients of the ELBO. Args: params: A dictionary of parameters. - trainables: A dictionary of trainables. + trainables: A dictionary of trainables. batch: A Dataset. Returns: dict: A dictionary of hyperparameter gradients. """ - # Transform parameters to constrained space. - params = transform(params, transformations) def loss_fn(params: dict, batch: Dataset) -> Array: # Determine hyperparameters that should be trained. @@ -243,19 +266,18 @@ def loss_fn(params: dict, batch: Dataset) -> Array: return xi_elbo(params, batch) - value, dL_dhyper = value_and_grad(loss_fn)(params, batch) + value, dL_dhyper = value_and_grad(loss_fn)(params, batch) - return value, dL_dhyper + return value, dL_dhyper return nat_grads_fn, hyper_grads_fn -from gpjax.abstractions import progress_bar_scan -import optax as ox import jax +import optax as ox + +from gpjax.abstractions import progress_bar_scan -adam = ox.adam(1e-3) -sgd = ox.sgd(1.) def fit_natgrads( stochastic_vi: StochasticVI, @@ -263,8 +285,8 @@ def fit_natgrads( trainables: tp.Dict, transformations: tp.Dict, train_data: Dataset, - moment_opt = ox.sgd(1e-3), - hyper_opt = ox.adam(1e-3), + moment_opt=ox.sgd(1.0), + hyper_opt=ox.adam(1e-3), n_iters: tp.Optional[int] = 100, log_rate: tp.Optional[int] = 10, ) -> tp.Dict: @@ -272,7 +294,9 @@ def fit_natgrads( hyper_state = hyper_opt.init(params) moment_state = moment_opt.init(params) - nat_grads_fn, hyper_grads_fn = natural_gradients(stochastic_vi, train_data, transformations) + nat_grads_fn, hyper_grads_fn = natural_gradients( + stochastic_vi, train_data, transformations + ) next_batch = train_data.get_batcher() @@ -280,24 +304,23 @@ def fit_natgrads( def step(params_opt_state, i): params, moment_state, hyper_state = params_opt_state batch = next_batch() - + # Natural gradients update: loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) updates, moment_state = moment_opt.update(loss_gradient, moment_state, params) params = ox.apply_updates(params, updates) - - + # Hyper-parameters update: loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) updates, hyper_state = hyper_opt.update(loss_gradient, hyper_state, params) params = ox.apply_updates(params, updates) - - + params_opt_state = params, moment_state, hyper_state - - + return params_opt_state, loss_val - (params, _, _), _ = jax.lax.scan(step, (params, moment_state, hyper_state), jnp.arange(n_iters)) + (params, _, _), _ = jax.lax.scan( + step, (params, moment_state, hyper_state), jnp.arange(n_iters) + ) - return params \ No newline at end of file + return params From 610338ccdebd7ea723cb954cd5a81cee4d552efe Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Sun, 26 Jun 2022 14:34:50 +0800 Subject: [PATCH 46/66] Update natgrads.ipynb --- examples/natgrads.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index c3bf8c04..8371069d 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -38,7 +38,7 @@ "\n", "#Set seed for reproducibility:\n", "import tensorflow as tf\n", - "tf.random.set_seed(42)\n", + "tf.random.set_seed(4)\n", "key = jr.PRNGKey(123)" ] }, @@ -94,7 +94,7 @@ "metadata": {}, "outputs": [], "source": [ - "z = jnp.linspace(-5.0, 5.0, 10).reshape(-1, 1)\n", + "z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", "ax.plot(x, y, \"o\", alpha=0.3)\n", From c9ee548814bf29b246a7bbdaff854e86b686c13e Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Wed, 20 Jul 2022 17:39:29 +0100 Subject: [PATCH 47/66] Rebase master. --- gpjax/natural_gradients.py | 9 +- gpjax/variational_families.py | 157 +++++++++++++++++------------ tests/test_variational_families.py | 91 +++++++++++------ 3 files changed, 157 insertions(+), 100 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 3209fdec..f8a4244c 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import jax.scipy as jsp from jax import lax, value_and_grad +from jaxtyping import f64 from .config import get_defaults from .gps import AbstractPosterior @@ -15,7 +16,7 @@ trainable_params, transform, ) -from .types import Array, Dataset +from .types import Dataset from .utils import I from .variational_families import ( AbstractVariationalFamily, @@ -174,7 +175,7 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: expectation_params["variational_family"]["moments"] = expectation_moments # Compute gradient ∂L/∂η: - def loss_fn(params: dict, batch: Dataset) -> Array: + def loss_fn(params: dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. trains = deepcopy(trainables) trains["variational_family"]["moments"] = build_trainables_true( @@ -230,7 +231,7 @@ def loss_fn(params: dict, batch: Dataset) -> Array: # expectation_params["variational_family"]["moments"] = expectation_moments # # Compute gradient ∂L/∂η: - # def loss_fn(params: dict, batch: Dataset) -> Array: + # def loss_fn(params: dict, batch: Dataset) -> f64["1"]: # # Determine hyperparameters that should be trained. # params = trainable_params(params, trainables) @@ -257,7 +258,7 @@ def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: dict: A dictionary of hyperparameter gradients. """ - def loss_fn(params: dict, batch: Dataset) -> Array: + def loss_fn(params: dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. params = trainable_params(params, trainables) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index fd8a16e7..0198b2b6 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -43,14 +43,17 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Predict the GP's output given the input.""" raise NotImplementedError + @dataclass class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" + prior: Prior - inducing_inputs: Array + inducing_inputs: f64["N D"] name: str = "Gaussian" jitter: Optional[float] = DEFAULT_JITTER + @dataclass class VariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions. @@ -59,10 +62,10 @@ class VariationalGaussian(AbstractVariationalGaussian): and the distribution over the inducing inputs is q(u) = N(μ, S). We parameterise this over μ and sqrt with S = sqrt sqrtᵀ. """ - variational_mean: Optional[Array] = None - variational_root_covariance: Optional[Array] = None + + variational_mean: Optional[f64["N D"]] = None + variational_root_covariance: Optional[f64["N D"]] = None diag: Optional[bool] = False - def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -85,13 +88,16 @@ def __post_init__(self): def _initialise_params(self, key: jnp.DeviceArray) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( - self.prior._initialise_params(key), { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "moments": {"variational_mean": self.variational_mean, - "variational_root_covariance": self.variational_root_covariance} + self.prior._initialise_params(key), + { + "variational_family": { + "inducing_inputs": self.inducing_inputs, + "moments": { + "variational_mean": self.variational_mean, + "variational_root_covariance": self.variational_root_covariance, + }, } - } + }, ) def prior_kl(self, params: Dict) -> Float[Array, "1"]: @@ -104,7 +110,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - Array: The KL-divergence between our variational approximation and the GP prior. + f64["1"]: The KL-divergence between our variational approximation and the GP prior. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -131,7 +137,7 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -197,7 +203,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - Array: The KL-divergence between our variational approximation and the GP prior. + f64["N D"]: The KL-divergence between our variational approximation and the GP prior. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -219,7 +225,7 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -259,14 +265,15 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: ) return predict_fn - + @dataclass class NaturalVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions.""" + name: str = "Natural Gaussian" - natural_vector: Optional[Array] = None - natural_matrix: Optional[Array] = None + natural_vector: Optional[f64["N D"]] = None + natural_matrix: Optional[f64["N D"]] = None def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -280,48 +287,50 @@ def __post_init__(self): add_parameter("natural_vector", Identity) if self.natural_matrix is None: - self.natural_matrix = -.5 * I(m) + self.natural_matrix = -0.5 * I(m) add_parameter("natural_matrix", Identity) @property def params(self) -> Dict: """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" return concat_dictionaries( - self.prior.params, { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "moments": {"natural_vector": self.natural_vector, - "natural_matrix": self.natural_matrix} + self.prior.params, + { + "variational_family": { + "inducing_inputs": self.inducing_inputs, + "moments": { + "natural_vector": self.natural_vector, + "natural_matrix": self.natural_matrix, + }, } - } - + }, ) - def prior_kl(self, params: Dict) -> Array: + def prior_kl(self, params: Dict) -> f64["1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. Args: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - Array: The KL-divergence between our variational approximation and the GP prior. + f64["1"]: The KL-divergence between our variational approximation and the GP prior. """ natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - + S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter L_inv = jnp.linalg.cholesky(S_inv) C = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) - + S = jnp.matmul(C.T, C) mu = jnp.matmul(S, natural_vector) S += I(m) * self.jitter sqrt = jnp.linalg.cholesky(S) - + μz = self.prior.mean_function(z, params["mean_function"]) Kzz = gram(self.prior.kernel, z, params["kernel"]) Kzz += I(m) * self.jitter @@ -332,20 +341,20 @@ def prior_kl(self, params: Dict) -> Array: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: + def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. - """ + Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + """ natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - + # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter @@ -355,7 +364,7 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: # C = L⁻¹I C = jsp.linalg.solve_triangular(L, I(m), lower=True) - + # S = CᵀC S = jnp.matmul(C.T, C) @@ -367,7 +376,7 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: Array) -> dx.Distribution: + def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) @@ -381,12 +390,16 @@ def predict_fn(test_inputs: Array) -> dx.Distribution: # Ktz Kzz⁻¹ Cᵀ Ktz_Kzz_inv_CT = jnp.matmul(Kzz_inv_Kzt.T, C.T) - + # μt + Ktz Kzz⁻¹ (μ - μz) mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = CᵀC] - covariance = Ktt - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + jnp.matmul(Ktz_Kzz_inv_CT, Ktz_Kzz_inv_CT.T) + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(Ktz_Kzz_inv_CT, Ktz_Kzz_inv_CT.T) + ) return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance @@ -398,9 +411,10 @@ def predict_fn(test_inputs: Array) -> dx.Distribution: @dataclass class ExpectationVariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions.""" + name: str = "Expectation Gaussian" - expectation_vector: Optional[Array] = None - expectation_matrix: Optional[Array] = None + expectation_vector: Optional[f64["N D"]] = None + expectation_matrix: Optional[f64["N D"]] = None def __post_init__(self): """Initialise the variational Gaussian distribution.""" @@ -421,34 +435,41 @@ def __post_init__(self): def params(self) -> Dict: """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" return concat_dictionaries( - self.prior.params, { - "variational_family": { - "inducing_inputs": self.inducing_inputs, - "moments": {"expectation_vector": self.expectation_vector, - "expectation_matrix": self.expectation_matrix} + self.prior.params, + { + "variational_family": { + "inducing_inputs": self.inducing_inputs, + "moments": { + "expectation_vector": self.expectation_vector, + "expectation_matrix": self.expectation_matrix, + }, } - } + }, ) - def prior_kl(self, params: Dict) -> Array: + def prior_kl(self, params: Dict) -> f64["1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. Args: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - Array: The KL-divergence between our variational approximation and the GP prior. + f64["1"]: The KL-divergence between our variational approximation and the GP prior. """ - expectation_vector = params["variational_family"]["moments"]["expectation_vector"] - expectation_matrix = params["variational_family"]["moments"]["expectation_matrix"] + expectation_vector = params["variational_family"]["moments"][ + "expectation_vector" + ] + expectation_matrix = params["variational_family"]["moments"][ + "expectation_matrix" + ] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - + mu = expectation_vector S = expectation_matrix - jnp.matmul(mu, mu.T) S += I(m) * self.jitter sqrt = jnp.linalg.cholesky(S) - + μz = self.prior.mean_function(z, params["mean_function"]) Kzz = gram(self.prior.kernel, z, params["kernel"]) Kzz += I(m) * self.jitter @@ -459,20 +480,24 @@ def prior_kl(self, params: Dict) -> Array: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: + def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. - """ - expectation_vector = params["variational_family"]["moments"]["expectation_vector"] - expectation_matrix = params["variational_family"]["moments"]["expectation_matrix"] + Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + """ + expectation_vector = params["variational_family"]["moments"][ + "expectation_vector" + ] + expectation_matrix = params["variational_family"]["moments"][ + "expectation_matrix" + ] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - + # μ = η₁ mu = expectation_vector @@ -488,26 +513,30 @@ def predict(self, params: dict) -> Callable[[Array], dx.Distribution]: Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: Array) -> dx.Distribution: + def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) μt = self.prior.mean_function(t, params["mean_function"]) - + # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt , lower=False) + Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False) # Ktz Kzz⁻¹ sqrt Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) - + # μt + Ktz Kzz⁻¹ (μ - μz) mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] - covariance = Ktt - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt ) + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) + covariance = ( + Ktt + - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) + + jnp.matmul(Ktz_Kzz_inv_sqrt, Ktz_Kzz_inv_sqrt.T) + ) return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance @@ -555,7 +584,7 @@ def predict( Args: params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[Array], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ x, y = train_data.X, train_data.y diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 44e98360..d13fb1df 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -1,7 +1,6 @@ import typing as tp import distrax as dx - import jax.numpy as jnp import jax.random as jr import pytest @@ -57,15 +56,27 @@ def test_variational_gaussian(diag, n_inducing, n_test, whiten): assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() assert "variational_mean" in params["variational_family"]["moments"].keys() - assert "variational_root_covariance" in params["variational_family"]["moments"].keys() + assert ( + "variational_root_covariance" in params["variational_family"]["moments"].keys() + ) assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["variational_mean"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["variational_root_covariance"].shape == (n_inducing, n_inducing) + assert params["variational_family"]["moments"]["variational_mean"].shape == ( + n_inducing, + 1, + ) + assert params["variational_family"]["moments"][ + "variational_root_covariance" + ].shape == (n_inducing, n_inducing) assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["moments"]["variational_mean"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["moments"]["variational_root_covariance"], jnp.DeviceArray) + assert isinstance( + params["variational_family"]["moments"]["variational_mean"], jnp.DeviceArray + ) + assert isinstance( + params["variational_family"]["moments"]["variational_root_covariance"], + jnp.DeviceArray, + ) params = gpx.config.get_defaults() assert "variational_root_covariance" in params["transformations"].keys() @@ -159,17 +170,18 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ assert isinstance(sigma, jnp.ndarray) assert mu.shape == (n_test,) assert sigma.shape == (n_test, n_test) + + +@pytest.mark.parametrize("n_test", [1, 10]) @pytest.mark.parametrize("n_inducing", [1, 10, 20]) def test_natural_variational_gaussian(n_inducing, n_test): prior = gpx.Prior(kernel=gpx.RBF()) - + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - variational_family = gpx.variational_families.NaturalVariationalGaussian( - prior=prior, - inducing_inputs=inducing_inputs + prior=prior, inducing_inputs=inducing_inputs ) # Test init @@ -182,13 +194,13 @@ def test_natural_variational_gaussian(n_inducing, n_test): n_inducing, n_inducing, ) - assert jnp.all(jnp.diag(variational_family.natural_matrix) == -.5) + assert jnp.all(jnp.diag(variational_family.natural_matrix) == -0.5) params = gpx.config.get_defaults() assert "variational_root_covariance" in params["transformations"].keys() assert "variational_mean" in params["transformations"].keys() - assert (variational_family.natural_matrix == -.5 * jnp.eye(n_inducing)).all() + assert (variational_family.natural_matrix == -0.5 * jnp.eye(n_inducing)).all() assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() # params @@ -199,22 +211,31 @@ def test_natural_variational_gaussian(n_inducing, n_test): assert "natural_matrix" in params["variational_family"]["moments"].keys() assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["natural_vector"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["natural_matrix"].shape == (n_inducing, n_inducing) + assert params["variational_family"]["moments"]["natural_vector"].shape == ( + n_inducing, + 1, + ) + assert params["variational_family"]["moments"]["natural_matrix"].shape == ( + n_inducing, + n_inducing, + ) assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["moments"]["natural_vector"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["moments"]["natural_matrix"], jnp.DeviceArray) - + assert isinstance( + params["variational_family"]["moments"]["natural_vector"], jnp.DeviceArray + ) + assert isinstance( + params["variational_family"]["moments"]["natural_matrix"], jnp.DeviceArray + ) + params = gpx.config.get_defaults() assert "natural_vector" in params["transformations"].keys() assert "natural_matrix" in params["transformations"].keys() - assert (variational_family.natural_matrix == -.5 * jnp.eye(n_inducing)).all() + assert (variational_family.natural_matrix == -0.5 * jnp.eye(n_inducing)).all() assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() - - #Test KL + # Test KL params = variational_family.params kl = variational_family.prior_kl(params) assert isinstance(kl, jnp.ndarray) @@ -239,14 +260,12 @@ def test_natural_variational_gaussian(n_inducing, n_test): @pytest.mark.parametrize("n_inducing", [1, 10, 20]) def test_expectation_variational_gaussian(n_inducing, n_test): prior = gpx.Prior(kernel=gpx.RBF()) - + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - variational_family = gpx.variational_families.ExpectationVariationalGaussian( - prior=prior, - inducing_inputs=inducing_inputs + prior=prior, inducing_inputs=inducing_inputs ) # Test init @@ -276,13 +295,23 @@ def test_expectation_variational_gaussian(n_inducing, n_test): assert "expectation_matrix" in params["variational_family"]["moments"].keys() assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["expectation_vector"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["expectation_matrix"].shape == (n_inducing, n_inducing) + assert params["variational_family"]["moments"]["expectation_vector"].shape == ( + n_inducing, + 1, + ) + assert params["variational_family"]["moments"]["expectation_matrix"].shape == ( + n_inducing, + n_inducing, + ) assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["moments"]["expectation_vector"], jnp.DeviceArray) - assert isinstance(params["variational_family"]["moments"]["expectation_matrix"], jnp.DeviceArray) - + assert isinstance( + params["variational_family"]["moments"]["expectation_vector"], jnp.DeviceArray + ) + assert isinstance( + params["variational_family"]["moments"]["expectation_matrix"], jnp.DeviceArray + ) + params = gpx.config.get_defaults() assert "expectation_vector" in params["transformations"].keys() assert "expectation_matrix" in params["transformations"].keys() @@ -290,8 +319,7 @@ def test_expectation_variational_gaussian(n_inducing, n_test): assert (variational_family.expectation_matrix == jnp.eye(n_inducing)).all() assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() - - #Test KL + # Test KL params = variational_family.params kl = variational_family.prior_kl(params) assert isinstance(kl, jnp.ndarray) @@ -310,4 +338,3 @@ def test_expectation_variational_gaussian(n_inducing, n_test): assert isinstance(sigma, jnp.ndarray) assert mu.shape == (n_test,) assert sigma.shape == (n_test, n_test) - From 6fe44397769b612babfb17fc761828784efe56c3 Mon Sep 17 00:00:00 2001 From: Daniel-Dodd Date: Thu, 21 Jul 2022 16:30:30 +0100 Subject: [PATCH 48/66] Add development notebook for general --- examples/Natural Gradient General case.ipynb | 467 +++++++++++++++++++ gpjax/natural_gradients.py | 49 +- 2 files changed, 471 insertions(+), 45 deletions(-) create mode 100644 examples/Natural Gradient General case.ipynb diff --git a/examples/Natural Gradient General case.ipynb b/examples/Natural Gradient General case.ipynb new file mode 100644 index 00000000..9ea5e492 --- /dev/null +++ b/examples/Natural Gradient General case.ipynb @@ -0,0 +1,467 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "import gpjax as gpx" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "from jax import jit, lax\n", + "import optax as ox\n", + "\n", + "import gpjax as gpx\n", + "from gpjax.natural_gradients import natural_gradients\n", + "from gpjax.abstractions import progress_bar_scan\n", + "\n", + "#Set seed for reproducibility:\n", + "import tensorflow as tf\n", + "tf.random.set_seed(4)\n", + "key = jr.PRNGKey(123)\n", + "\n", + "import typing as tp\n", + "from copy import deepcopy\n", + "\n", + "import distrax as dx\n", + "import jax.numpy as jnp\n", + "import jax.scipy as jsp\n", + "from jax import lax, value_and_grad, jacobian\n", + "from jaxtyping import f64\n", + "\n", + "from gpjax.config import get_defaults\n", + "from gpjax.gps import AbstractPosterior\n", + "from gpjax.parameters import (\n", + " build_identity,\n", + " build_trainables_false,\n", + " build_trainables_true,\n", + " trainable_params,\n", + " transform,\n", + ")\n", + "from gpjax.types import Dataset\n", + "from gpjax.utils import I\n", + "from gpjax.variational_families import (\n", + " AbstractVariationalFamily,\n", + " ExpectationVariationalGaussian,\n", + " NaturalVariationalGaussian,\n", + ")\n", + "from gpjax.variational_inference import StochasticVI\n", + "DEFAULT_JITTER = get_defaults()[\"jitter\"]\n", + "\n", + "\n", + "from gpjax.natural_gradients import natural_gradients, natural_to_expectation, _expectation_elbo, _stop_gradients_nonmoments, _stop_gradients_moments, fit_natgrads" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataset and inducing points:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n = 5000\n", + "noise = 0.2\n", + "\n", + "x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)\n", + "f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)\n", + "signal = f(x)\n", + "y = signal + jr.normal(key, shape=signal.shape) * noise\n", + "\n", + "D = gpx.Dataset(X=x, y=y)\n", + "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)\n", + "\n", + "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "z = jnp.linspace(-5.0, 5.0, 2).reshape(-1, 1)\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", + "ax.plot(x, y, \"o\", alpha=0.3)\n", + "ax.plot(xtest, f(xtest))\n", + "[ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1) for z_i in z]\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Natgrads code" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def natural_gradients(\n", + " stochastic_vi: StochasticVI,\n", + " train_data: Dataset,\n", + " transformations: dict,\n", + " xi_to_nat: tp.Callable[[tp.Dict], tp.Dict],\n", + " nat_to_xi: tp.Callable[[tp.Dict], tp.Dict],\n", + ") -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]:\n", + " \"\"\"\n", + " Computes natural gradients for variational Gaussian.\n", + " Args:\n", + " posterior: An instance of AbstractPosterior.\n", + " variational_family: An instance of AbstractVariationalFamily.\n", + " train_data: A Dataset.\n", + " transformations: A dictionary of transformations.\n", + " Returns:\n", + " Tuple[tp.Callable[[dict, Dataset], dict]]: Functions that compute natural gradients and hyperparameter gradients respectively.\n", + " \"\"\"\n", + " posterior = stochastic_vi.posterior\n", + " variational_family = stochastic_vi.variational_family\n", + "\n", + " # The ELBO under the user chosen parameterisation xi.\n", + " xi_elbo = stochastic_vi.elbo(train_data, transformations, negative=True)\n", + "\n", + " # The ELBO under the expectation parameterisation, L(η).\n", + " expectation_elbo = _expectation_elbo(posterior, variational_family, train_data)\n", + "\n", + " if isinstance(variational_family, NaturalVariationalGaussian):\n", + "\n", + " def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict:\n", + " \"\"\"\n", + " Computes the natural gradients of the ELBO.\n", + " Args:\n", + " params: A dictionary of parameters.\n", + " trainables: A dictionary of trainables.\n", + " batch: A Dataset.\n", + " Returns:\n", + " dict: A dictionary of natural gradients.\n", + " \"\"\"\n", + " # Transform parameters to constrained space.\n", + " params = transform(params, transformations)\n", + "\n", + " # Get natural moments θ.\n", + " natural_moments = params[\"variational_family\"][\"moments\"]\n", + "\n", + " # Get expectation moments η.\n", + " expectation_moments = natural_to_expectation(natural_moments)\n", + "\n", + " # Full params with expectation moments.\n", + " expectation_params = deepcopy(params)\n", + " expectation_params[\"variational_family\"][\"moments\"] = expectation_moments\n", + "\n", + " # Compute gradient ∂L/∂η:\n", + " def loss_fn(params: dict, batch: Dataset) -> f64[\"1\"]:\n", + " # Determine hyperparameters that should be trained.\n", + " trains = deepcopy(trainables)\n", + " trains[\"variational_family\"][\"moments\"] = build_trainables_true(\n", + " params[\"variational_family\"][\"moments\"]\n", + " )\n", + " params = trainable_params(params, trains)\n", + "\n", + " # Stop gradients for non-moment parameters.\n", + " params = _stop_gradients_nonmoments(params)\n", + "\n", + " return expectation_elbo(params, batch)\n", + "\n", + " value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch)\n", + "\n", + " # This is a renaming of the gradient components to match the natural parameterisation pytree.\n", + " natural_gradient = dL_dnat\n", + " natural_gradient[\"variational_family\"][\"moments\"] = {\n", + " \"natural_vector\": dL_dnat[\"variational_family\"][\"moments\"][\n", + " \"expectation_vector\"\n", + " ],\n", + " \"natural_matrix\": dL_dnat[\"variational_family\"][\"moments\"][\n", + " \"expectation_matrix\"\n", + " ],\n", + " }\n", + "\n", + " return value, natural_gradient\n", + "\n", + " else:\n", + "\n", + " def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict:\n", + " # Transform parameters to constrained space.\n", + " params = transform(params, transformations)\n", + "\n", + " # Get natural moments θ.\n", + " natural_moments = xi_to_nat(params[\"variational_family\"][\"moments\"])\n", + "\n", + " # Get expectation moments η.\n", + " expectation_moments = natural_to_expectation(natural_moments)\n", + "\n", + " # Gradient function ∂ξ/∂θ:\n", + " dxi_dnat = jacobian(nat_to_xi)(natural_moments)\n", + "\n", + " # Full params with expectation moments.\n", + " expectation_params = deepcopy(params)\n", + " expectation_params[\"variational_family\"][\"moments\"] = expectation_moments\n", + "\n", + " # Compute gradient ∂L/∂η:\n", + " def loss_fn(params: dict, batch: Dataset) -> f64[\"1\"]:\n", + " # Determine hyperparameters that should be trained.\n", + " trains = deepcopy(trainables)\n", + " trains[\"variational_family\"][\"moments\"] = build_trainables_true(\n", + " params[\"variational_family\"][\"moments\"]\n", + " )\n", + " params = trainable_params(params, trains)\n", + "\n", + " # Stop gradients for non-moment parameters.\n", + " params = _stop_gradients_nonmoments(params)\n", + "\n", + " return expectation_elbo(params, batch)\n", + "\n", + " value, dL_dexp = value_and_grad(loss_fn)(expectation_params, batch)\n", + "\n", + " \n", + " # The issue is combining: ∂ξ/∂θ ∂L/∂η\n", + " natural_gradient = None\n", + " \n", + " return value, natural_gradient\n", + "\n", + " def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict:\n", + " \"\"\"\n", + " Computes the hyperparameter gradients of the ELBO.\n", + " Args:\n", + " params: A dictionary of parameters.\n", + " trainables: A dictionary of trainables.\n", + " batch: A Dataset.\n", + " Returns:\n", + " dict: A dictionary of hyperparameter gradients.\n", + " \"\"\"\n", + "\n", + " def loss_fn(params: dict, batch: Dataset) -> f64[\"1\"]:\n", + " # Determine hyperparameters that should be trained.\n", + " params = trainable_params(params, trainables)\n", + "\n", + " # Stop gradients for the moment parameters.\n", + " params = _stop_gradients_moments(params)\n", + "\n", + " return xi_elbo(params, batch)\n", + "\n", + " value, dL_dhyper = value_and_grad(loss_fn)(params, batch)\n", + "\n", + " return value, dL_dhyper\n", + "\n", + " return nat_grads_fn, hyper_grads_fn" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Example" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will consider using the expectation family as a test for computing natural gradients $\\xi = \\eta$ (though of course this simplifies in reality: $\\frac{d\\xi}{d\\theta} \\frac{d\\mathcal{L}}{d\\eta} = \\frac{d\\mathcal{L}}{d\\theta}$)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We begin by defining the bijection between $\\xi$ and $\\theta$:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def xi_to_nat(moments: dict) -> dict:\n", + " \n", + " expectation_vector = moments[\"expectation_vector\"]\n", + " expectation_matrix = moments[\"expectation_matrix\"]\n", + " \n", + " m = expectation_vector.shape[0]\n", + " \n", + " mu = expectation_vector\n", + " \n", + " S = expectation_matrix - jnp.matmul(mu, mu.T)\n", + " S += I(m) * 1e-6\n", + " \n", + " L = jnp.linalg.cholesky(S)\n", + " \n", + " L_inv = jsp.linalg.solve_triangular(L, S, lower=True)\n", + " \n", + " S_inv = jnp.matmul(L_inv.T, L_inv)\n", + " \n", + " natural_matrix = - 0.5 * S_inv\n", + " natural_vector = jnp.matmul(S_inv, mu)\n", + " \n", + " return {\"natural_matrix\": natural_matrix, \"natural_vector\": natural_vector}\n", + "\n", + "def nat_to_xi(moments: dict) -> dict:\n", + " \n", + " natural_vector = moments[\"natural_vector\"]\n", + " natural_matrix = moments[\"natural_matrix\"]\n", + " \n", + " m = natural_vector.shape[0]\n", + " \n", + " S_inv = -2 * natural_matrix\n", + " S_inv += I(m) * 1e-6\n", + " L = jnp.linalg.cholesky(S_inv)\n", + " \n", + " C = jsp.linalg.solve_triangular(L, I(m), lower=True)\n", + " S = jnp.matmul(C.T, C)\n", + " \n", + " mu = jnp.matmul(S, natural_vector)\n", + " \n", + " expectation_vector = mu\n", + " expectation_matrix = S + jnp.matmul(mu, mu.T)\n", + " \n", + " \n", + " return {\"expectation_vector\": expectation_vector, \"expectation_matrix\": expectation_matrix}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then would do:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "likelihood = gpx.Gaussian(num_datapoints=n)\n", + "kernel = gpx.RBF()\n", + "prior = gpx.Prior(kernel=kernel)\n", + "p = prior * likelihood\n", + "\n", + "\n", + "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", + "\n", + "q = gpx.ExpectationVariationalGaussian(prior=prior, inducing_inputs=z)\n", + "\n", + "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", + "\n", + "\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "\n", + "params = gpx.transform(params, unconstrainers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then obtain our gradient functions as follows: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers, xi_to_nat = xi_to_nat, nat_to_xi = nat_to_xi)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And evaluate them e.g., like" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nat_grads_fn(params=params, trainables=trainables, batch=D)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This gives a tuple of the loss function value and the gradient that is None for now, as we have not implemented it." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In reality, we won't see these as we have a training loop abstraction that could look something like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "learned_params = fit_natgrads(svgp,\n", + " params = params,\n", + " trainables = trainables, \n", + " transformations = constrainers,\n", + " train_data = Dbatched,\n", + " n_iters = 5000,\n", + " xi_to_nat= xi_to_nat,\n", + " nat_to_xi = nat_to_xi\n", + ")\n", + "\n", + "learned_params = gpx.transform(learned_params, constrainers)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.7 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "7eb1cfec58eecaa2e5422163254bd25a3275ed109df9a51c3c95d775723db6f0" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index f8a4244c..b6a03db6 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -188,15 +188,15 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: return expectation_elbo(params, batch) - value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch) + value, dL_dexp = value_and_grad(loss_fn)(expectation_params, batch) # This is a renaming of the gradient components to match the natural parameterisation pytree. - nat_grad = dL_dnat + nat_grad = dL_dexp nat_grad["variational_family"]["moments"] = { - "natural_vector": dL_dnat["variational_family"]["moments"][ + "natural_vector": dL_dexp["variational_family"]["moments"][ "expectation_vector" ], - "natural_matrix": dL_dnat["variational_family"]["moments"][ + "natural_matrix": dL_dexp["variational_family"]["moments"][ "expectation_matrix" ], } @@ -204,49 +204,8 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: return value, nat_grad else: - # To Do: (DD) add general parameterisation case. raise NotImplementedError - # BELOW is (almost working) PSUEDO CODE of what this will look like. - - # def nat_grads_fn(params: dict, batch: Dataset) -> dict: - # # Transform parameters to constrained space. - # params = transform(params, transformations) - - # # Stop gradients for non-moment parameters. - # params = _stop_gradients_nonmoments(params) - - # # Get natural moments θ. - # natural_moments = bijector.inverse(params["variational_family"]["moments"]) - - # # Get expectation moments η. - # expectation_moments = natural_to_expectation(natural_moments) - - # # Gradient function ∂ξ/∂θ: - # #### NEED TO STOP GRADIENTS FOR NON MOMENTS HERE!#### - # dxi_dnat = jacobian(nat_to_moments.forward)(natural_moments) - - # # Full params with expectation moments. - # expectation_params = deepcopy(params) - # expectation_params["variational_family"]["moments"] = expectation_moments - - # # Compute gradient ∂L/∂η: - # def loss_fn(params: dict, batch: Dataset) -> f64["1"]: - # # Determine hyperparameters that should be trained. - # params = trainable_params(params, trainables) - - # # Stop gradients for non-moment parameters. - # params = _stop_gradients_nonmoments(params) - - # return expectation_elbo(expectation_params, batch) - - # value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch) - - # # ∂ξ/∂θ ∂L/∂η - # nat_grads = jax.tree_multimap(lambda x, y: jnp.matmul(x.T, y), dxi_dnat, dL_dnat) - - # return value, nat_grads - def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: """ Computes the hyperparameter gradients of the ELBO. From 20473e62ea0c1b22e7e702501e51e32d66307144 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 19 Aug 2022 19:02:31 +0100 Subject: [PATCH 49/66] Update training loop. --- examples/Natural Gradient General case.ipynb | 467 ------------------- examples/natgrads.ipynb | 16 +- gpjax/natural_gradients.py | 59 ++- tests/test_natural_gradients.py | 69 ++- 4 files changed, 95 insertions(+), 516 deletions(-) delete mode 100644 examples/Natural Gradient General case.ipynb diff --git a/examples/Natural Gradient General case.ipynb b/examples/Natural Gradient General case.ipynb deleted file mode 100644 index 9ea5e492..00000000 --- a/examples/Natural Gradient General case.ipynb +++ /dev/null @@ -1,467 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import gpjax as gpx" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import jax.numpy as jnp\n", - "import jax.random as jr\n", - "import matplotlib.pyplot as plt\n", - "from jax import jit, lax\n", - "import optax as ox\n", - "\n", - "import gpjax as gpx\n", - "from gpjax.natural_gradients import natural_gradients\n", - "from gpjax.abstractions import progress_bar_scan\n", - "\n", - "#Set seed for reproducibility:\n", - "import tensorflow as tf\n", - "tf.random.set_seed(4)\n", - "key = jr.PRNGKey(123)\n", - "\n", - "import typing as tp\n", - "from copy import deepcopy\n", - "\n", - "import distrax as dx\n", - "import jax.numpy as jnp\n", - "import jax.scipy as jsp\n", - "from jax import lax, value_and_grad, jacobian\n", - "from jaxtyping import f64\n", - "\n", - "from gpjax.config import get_defaults\n", - "from gpjax.gps import AbstractPosterior\n", - "from gpjax.parameters import (\n", - " build_identity,\n", - " build_trainables_false,\n", - " build_trainables_true,\n", - " trainable_params,\n", - " transform,\n", - ")\n", - "from gpjax.types import Dataset\n", - "from gpjax.utils import I\n", - "from gpjax.variational_families import (\n", - " AbstractVariationalFamily,\n", - " ExpectationVariationalGaussian,\n", - " NaturalVariationalGaussian,\n", - ")\n", - "from gpjax.variational_inference import StochasticVI\n", - "DEFAULT_JITTER = get_defaults()[\"jitter\"]\n", - "\n", - "\n", - "from gpjax.natural_gradients import natural_gradients, natural_to_expectation, _expectation_elbo, _stop_gradients_nonmoments, _stop_gradients_moments, fit_natgrads" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Dataset and inducing points:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n = 5000\n", - "noise = 0.2\n", - "\n", - "x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)\n", - "f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)\n", - "signal = f(x)\n", - "y = signal + jr.normal(key, shape=signal.shape) * noise\n", - "\n", - "D = gpx.Dataset(X=x, y=y)\n", - "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)\n", - "\n", - "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "z = jnp.linspace(-5.0, 5.0, 2).reshape(-1, 1)\n", - "\n", - "fig, ax = plt.subplots(figsize=(12, 5))\n", - "ax.plot(x, y, \"o\", alpha=0.3)\n", - "ax.plot(xtest, f(xtest))\n", - "[ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1) for z_i in z]\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Natgrads code" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def natural_gradients(\n", - " stochastic_vi: StochasticVI,\n", - " train_data: Dataset,\n", - " transformations: dict,\n", - " xi_to_nat: tp.Callable[[tp.Dict], tp.Dict],\n", - " nat_to_xi: tp.Callable[[tp.Dict], tp.Dict],\n", - ") -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]:\n", - " \"\"\"\n", - " Computes natural gradients for variational Gaussian.\n", - " Args:\n", - " posterior: An instance of AbstractPosterior.\n", - " variational_family: An instance of AbstractVariationalFamily.\n", - " train_data: A Dataset.\n", - " transformations: A dictionary of transformations.\n", - " Returns:\n", - " Tuple[tp.Callable[[dict, Dataset], dict]]: Functions that compute natural gradients and hyperparameter gradients respectively.\n", - " \"\"\"\n", - " posterior = stochastic_vi.posterior\n", - " variational_family = stochastic_vi.variational_family\n", - "\n", - " # The ELBO under the user chosen parameterisation xi.\n", - " xi_elbo = stochastic_vi.elbo(train_data, transformations, negative=True)\n", - "\n", - " # The ELBO under the expectation parameterisation, L(η).\n", - " expectation_elbo = _expectation_elbo(posterior, variational_family, train_data)\n", - "\n", - " if isinstance(variational_family, NaturalVariationalGaussian):\n", - "\n", - " def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict:\n", - " \"\"\"\n", - " Computes the natural gradients of the ELBO.\n", - " Args:\n", - " params: A dictionary of parameters.\n", - " trainables: A dictionary of trainables.\n", - " batch: A Dataset.\n", - " Returns:\n", - " dict: A dictionary of natural gradients.\n", - " \"\"\"\n", - " # Transform parameters to constrained space.\n", - " params = transform(params, transformations)\n", - "\n", - " # Get natural moments θ.\n", - " natural_moments = params[\"variational_family\"][\"moments\"]\n", - "\n", - " # Get expectation moments η.\n", - " expectation_moments = natural_to_expectation(natural_moments)\n", - "\n", - " # Full params with expectation moments.\n", - " expectation_params = deepcopy(params)\n", - " expectation_params[\"variational_family\"][\"moments\"] = expectation_moments\n", - "\n", - " # Compute gradient ∂L/∂η:\n", - " def loss_fn(params: dict, batch: Dataset) -> f64[\"1\"]:\n", - " # Determine hyperparameters that should be trained.\n", - " trains = deepcopy(trainables)\n", - " trains[\"variational_family\"][\"moments\"] = build_trainables_true(\n", - " params[\"variational_family\"][\"moments\"]\n", - " )\n", - " params = trainable_params(params, trains)\n", - "\n", - " # Stop gradients for non-moment parameters.\n", - " params = _stop_gradients_nonmoments(params)\n", - "\n", - " return expectation_elbo(params, batch)\n", - "\n", - " value, dL_dnat = value_and_grad(loss_fn)(expectation_params, batch)\n", - "\n", - " # This is a renaming of the gradient components to match the natural parameterisation pytree.\n", - " natural_gradient = dL_dnat\n", - " natural_gradient[\"variational_family\"][\"moments\"] = {\n", - " \"natural_vector\": dL_dnat[\"variational_family\"][\"moments\"][\n", - " \"expectation_vector\"\n", - " ],\n", - " \"natural_matrix\": dL_dnat[\"variational_family\"][\"moments\"][\n", - " \"expectation_matrix\"\n", - " ],\n", - " }\n", - "\n", - " return value, natural_gradient\n", - "\n", - " else:\n", - "\n", - " def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict:\n", - " # Transform parameters to constrained space.\n", - " params = transform(params, transformations)\n", - "\n", - " # Get natural moments θ.\n", - " natural_moments = xi_to_nat(params[\"variational_family\"][\"moments\"])\n", - "\n", - " # Get expectation moments η.\n", - " expectation_moments = natural_to_expectation(natural_moments)\n", - "\n", - " # Gradient function ∂ξ/∂θ:\n", - " dxi_dnat = jacobian(nat_to_xi)(natural_moments)\n", - "\n", - " # Full params with expectation moments.\n", - " expectation_params = deepcopy(params)\n", - " expectation_params[\"variational_family\"][\"moments\"] = expectation_moments\n", - "\n", - " # Compute gradient ∂L/∂η:\n", - " def loss_fn(params: dict, batch: Dataset) -> f64[\"1\"]:\n", - " # Determine hyperparameters that should be trained.\n", - " trains = deepcopy(trainables)\n", - " trains[\"variational_family\"][\"moments\"] = build_trainables_true(\n", - " params[\"variational_family\"][\"moments\"]\n", - " )\n", - " params = trainable_params(params, trains)\n", - "\n", - " # Stop gradients for non-moment parameters.\n", - " params = _stop_gradients_nonmoments(params)\n", - "\n", - " return expectation_elbo(params, batch)\n", - "\n", - " value, dL_dexp = value_and_grad(loss_fn)(expectation_params, batch)\n", - "\n", - " \n", - " # The issue is combining: ∂ξ/∂θ ∂L/∂η\n", - " natural_gradient = None\n", - " \n", - " return value, natural_gradient\n", - "\n", - " def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict:\n", - " \"\"\"\n", - " Computes the hyperparameter gradients of the ELBO.\n", - " Args:\n", - " params: A dictionary of parameters.\n", - " trainables: A dictionary of trainables.\n", - " batch: A Dataset.\n", - " Returns:\n", - " dict: A dictionary of hyperparameter gradients.\n", - " \"\"\"\n", - "\n", - " def loss_fn(params: dict, batch: Dataset) -> f64[\"1\"]:\n", - " # Determine hyperparameters that should be trained.\n", - " params = trainable_params(params, trainables)\n", - "\n", - " # Stop gradients for the moment parameters.\n", - " params = _stop_gradients_moments(params)\n", - "\n", - " return xi_elbo(params, batch)\n", - "\n", - " value, dL_dhyper = value_and_grad(loss_fn)(params, batch)\n", - "\n", - " return value, dL_dhyper\n", - "\n", - " return nat_grads_fn, hyper_grads_fn" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Example" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We will consider using the expectation family as a test for computing natural gradients $\\xi = \\eta$ (though of course this simplifies in reality: $\\frac{d\\xi}{d\\theta} \\frac{d\\mathcal{L}}{d\\eta} = \\frac{d\\mathcal{L}}{d\\theta}$)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We begin by defining the bijection between $\\xi$ and $\\theta$:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def xi_to_nat(moments: dict) -> dict:\n", - " \n", - " expectation_vector = moments[\"expectation_vector\"]\n", - " expectation_matrix = moments[\"expectation_matrix\"]\n", - " \n", - " m = expectation_vector.shape[0]\n", - " \n", - " mu = expectation_vector\n", - " \n", - " S = expectation_matrix - jnp.matmul(mu, mu.T)\n", - " S += I(m) * 1e-6\n", - " \n", - " L = jnp.linalg.cholesky(S)\n", - " \n", - " L_inv = jsp.linalg.solve_triangular(L, S, lower=True)\n", - " \n", - " S_inv = jnp.matmul(L_inv.T, L_inv)\n", - " \n", - " natural_matrix = - 0.5 * S_inv\n", - " natural_vector = jnp.matmul(S_inv, mu)\n", - " \n", - " return {\"natural_matrix\": natural_matrix, \"natural_vector\": natural_vector}\n", - "\n", - "def nat_to_xi(moments: dict) -> dict:\n", - " \n", - " natural_vector = moments[\"natural_vector\"]\n", - " natural_matrix = moments[\"natural_matrix\"]\n", - " \n", - " m = natural_vector.shape[0]\n", - " \n", - " S_inv = -2 * natural_matrix\n", - " S_inv += I(m) * 1e-6\n", - " L = jnp.linalg.cholesky(S_inv)\n", - " \n", - " C = jsp.linalg.solve_triangular(L, I(m), lower=True)\n", - " S = jnp.matmul(C.T, C)\n", - " \n", - " mu = jnp.matmul(S, natural_vector)\n", - " \n", - " expectation_vector = mu\n", - " expectation_matrix = S + jnp.matmul(mu, mu.T)\n", - " \n", - " \n", - " return {\"expectation_vector\": expectation_vector, \"expectation_matrix\": expectation_matrix}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We then would do:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "likelihood = gpx.Gaussian(num_datapoints=n)\n", - "kernel = gpx.RBF()\n", - "prior = gpx.Prior(kernel=kernel)\n", - "p = prior * likelihood\n", - "\n", - "\n", - "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", - "\n", - "q = gpx.ExpectationVariationalGaussian(prior=prior, inducing_inputs=z)\n", - "\n", - "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", - "\n", - "\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", - "\n", - "params = gpx.transform(params, unconstrainers)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We then obtain our gradient functions as follows: " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers, xi_to_nat = xi_to_nat, nat_to_xi = nat_to_xi)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And evaluate them e.g., like" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nat_grads_fn(params=params, trainables=trainables, batch=D)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This gives a tuple of the loss function value and the gradient that is None for now, as we have not implemented it." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In reality, we won't see these as we have a training loop abstraction that could look something like this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "learned_params = fit_natgrads(svgp,\n", - " params = params,\n", - " trainables = trainables, \n", - " transformations = constrainers,\n", - " train_data = Dbatched,\n", - " n_iters = 5000,\n", - " xi_to_nat= xi_to_nat,\n", - " nat_to_xi = nat_to_xi\n", - ")\n", - "\n", - "learned_params = gpx.transform(learned_params, constrainers)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.9.7 ('base')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - }, - "vscode": { - "interpreter": { - "hash": "7eb1cfec58eecaa2e5422163254bd25a3275ed109df9a51c3c95d775723db6f0" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 8371069d..45c00397 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -214,13 +214,17 @@ "metadata": {}, "outputs": [], "source": [ - "learned_params = gpx.natural_gradients.fit_natgrads(svgp,\n", + "learned_params, training_history = gpx.natural_gradients.fit_natgrads(svgp,\n", " params = params,\n", " trainables = trainables, \n", " transformations = constrainers,\n", " train_data = Dbatched,\n", - " n_iters = 5000\n", - ")\n", + " n_iters = 5000,\n", + " batch_size=100,\n", + " seed = 42,\n", + " moment_optim = ox.sgd(1.0),\n", + " hyper_optim = ox.adam(1e-3),\n", + " )\n", "\n", "learned_params = gpx.transform(learned_params, constrainers)" ] @@ -260,7 +264,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3.10.0 ('base')", "language": "python", "name": "python3" }, @@ -274,11 +278,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.10.0" }, "vscode": { "interpreter": { - "hash": "7eb1cfec58eecaa2e5422163254bd25a3275ed109df9a51c3c95d775723db6f0" + "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf" } } }, diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index b6a03db6..acc0679b 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,10 +1,11 @@ import typing as tp from copy import deepcopy -import distrax as dx import jax.numpy as jnp +import jax.random as jr import jax.scipy as jsp -from jax import lax, value_and_grad +from chex import PRNGKey +from jax import value_and_grad from jaxtyping import f64 from .config import get_defaults @@ -245,42 +246,60 @@ def fit_natgrads( trainables: tp.Dict, transformations: tp.Dict, train_data: Dataset, - moment_opt=ox.sgd(1.0), - hyper_opt=ox.adam(1e-3), + batch_size: int, + moment_optim, + hyper_optim, + seed: tp.Union[int, PRNGKey], n_iters: tp.Optional[int] = 100, log_rate: tp.Optional[int] = 10, ) -> tp.Dict: - hyper_state = hyper_opt.init(params) - moment_state = moment_opt.init(params) + hyper_state = hyper_optim.init(params) + moment_state = moment_optim.init(params) nat_grads_fn, hyper_grads_fn = natural_gradients( stochastic_vi, train_data, transformations ) - next_batch = train_data.get_batcher() + x, y, n = train_data.X, train_data.y, train_data.n + + prng = convert_seed(seed) @progress_bar_scan(n_iters, log_rate) - def step(params_opt_state, i): - params, moment_state, hyper_state = params_opt_state - batch = next_batch() + def step(carry, _): + params, moment_state, hyper_state, prng = carry - # Natural gradients update: - loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) - updates, moment_state = moment_opt.update(loss_gradient, moment_state, params) - params = ox.apply_updates(params, updates) + indicies = jr.choice(prng, n, (batch_size,), replace=True) + + batch = Dataset(X=x[indicies], y=y[indicies]) # Hyper-parameters update: loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) - updates, hyper_state = hyper_opt.update(loss_gradient, hyper_state, params) + updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params) params = ox.apply_updates(params, updates) - params_opt_state = params, moment_state, hyper_state + # Natural gradients update: + loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) + updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) + params = ox.apply_updates(params, updates) + + prng, _ = jr.split(prng) - return params_opt_state, loss_val + carry = params, moment_state, hyper_state, prng + return carry, loss_val - (params, _, _), _ = jax.lax.scan( - step, (params, moment_state, hyper_state), jnp.arange(n_iters) + (params, _, _, _), history = jax.lax.scan( + step, (params, moment_state, hyper_state, prng), jnp.arange(n_iters) ) + return params, history - return params + +def convert_seed(seed: tp.Union[int, PRNGKey]) -> PRNGKey: + """Ensure that seeds type.""" + + if isinstance(seed, int): + rng = jr.PRNGKey(seed) + else: # key is of type PRNGKey + rng = seed + + return rng diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index 4fb98d8b..bb0499c0 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -1,23 +1,28 @@ -import pytest -import jax -import jax.numpy as jnp -from gpjax.natural_gradients import natural_to_expectation, _stop_gradients_nonmoments, _stop_gradients_moments, _expectation_elbo, natural_gradients import typing as tp -import gpjax as gpx -import jax.random as jr -from gpjax.parameters import recursive_items + +import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt -from jax import jit import optax as ox +import pytest +import tensorflow as tf +from jax import jit import gpjax as gpx -import tensorflow as tf +from gpjax.natural_gradients import ( + _expectation_elbo, + _stop_gradients_moments, + _stop_gradients_nonmoments, + natural_gradients, + natural_to_expectation, +) +from gpjax.parameters import recursive_items tf.random.set_seed(42) key = jr.PRNGKey(123) + @pytest.mark.parametrize("dim", [1, 2, 3]) def test_natural_to_expectation(dim): """ @@ -29,17 +34,26 @@ def test_natural_to_expectation(dim): tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ - natural_matrix = -.5 * jnp.eye(dim) + natural_matrix = -0.5 * jnp.eye(dim) natural_vector = jnp.zeros((dim, 1)) - - natural_moments = {"natural_matrix": natural_matrix, "natural_vector": natural_vector} + + natural_moments = { + "natural_matrix": natural_matrix, + "natural_vector": natural_vector, + } expectation_moments = natural_to_expectation(natural_moments, jitter=1e-6) assert "expectation_vector" in expectation_moments.keys() assert "expectation_matrix" in expectation_moments.keys() - assert expectation_moments["expectation_vector"].shape == natural_moments["natural_vector"].shape - assert expectation_moments["expectation_matrix"].shape == natural_moments["natural_matrix"].shape + assert ( + expectation_moments["expectation_vector"].shape + == natural_moments["natural_vector"].shape + ) + assert ( + expectation_moments["expectation_matrix"].shape + == natural_moments["natural_matrix"].shape + ) def get_data_and_gp(n_datapoints): @@ -52,15 +66,18 @@ def get_data_and_gp(n_datapoints): post = p * lik return D, post, p + @pytest.mark.parametrize("jit_fns", [True, False]) def test_expectation_elbo(jit_fns): """ Tests the expectation ELBO. """ D, posterior, prior = get_data_and_gp(10) - + z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) - variational_family = gpx.variational_families.ExpectationVariationalGaussian(prior = prior, inducing_inputs=z) + variational_family = gpx.variational_families.ExpectationVariationalGaussian( + prior=prior, inducing_inputs=z + ) svgp = gpx.StochasticVI(posterior=posterior, variational_family=variational_family) @@ -83,23 +100,25 @@ def test_expectation_elbo(jit_fns): assert len(grads) == len(params) - # def test_stop_gradients_nonmoments(): # pass - + # def test_stop_gradients_moments(): # pass + def test_natural_gradients(): """ Tests the expectation ELBO. """ D, p, prior = get_data_and_gp(10) - + z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) - Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1) + Dbatched = ( + D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1) + ) likelihood = gpx.Gaussian(num_datapoints=D.n) prior = gpx.Prior(kernel=gpx.RBF()) @@ -126,11 +145,15 @@ def test_natural_gradients(): assert isinstance(hyper_grads, tp.Dict) # Need to check moments are zero in hyper_grads: - assert jnp.array([ (v == 0.).all() for v in hyper_grads["variational_family"]["moments"].values()]).all() + assert jnp.array( + [ + (v == 0.0).all() + for v in hyper_grads["variational_family"]["moments"].values() + ] + ).all() # Check non-moments are zero in nat_grads: - d = jax.tree_map(lambda x: (x==0.).all(), nat_grads) + d = jax.tree_map(lambda x: (x == 0.0).all(), nat_grads) d["variational_family"]["moments"] = True assert jnp.array([v1 == True for k, v1, v2 in recursive_items(d, d)]).all() - From 7f7c4245d8b36ef3a31520347067bc6df4d83324 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 19 Aug 2022 22:43:03 +0100 Subject: [PATCH 50/66] Update training loop. Add collapsed bound and natural gradient relationship in the notebook. --- examples/natgrads.ipynb | 161 ++++++++++++++++++++++++++++++++++--- gpjax/natural_gradients.py | 40 +++------ 2 files changed, 161 insertions(+), 40 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 45c00397..41e898e7 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -1,22 +1,23 @@ { "cells": [ { - "cell_type": "code", - "execution_count": null, - "id": "143ac6b9", + "cell_type": "markdown", + "id": "98f89228", "metadata": {}, - "outputs": [], "source": [ - "%load_ext autoreload\n", - "%autoreload 2" + "# Natural Gradients:" ] }, { "cell_type": "markdown", - "id": "98f89228", + "id": "02dcd16f", "metadata": {}, "source": [ - "# Natural Gradients:" + "In this notebook we demonstrate how to implement natural gradients. \n", + "\n", + "As well explained in Salimbeni et al. (2018),\n", + "\n", + "\"The ordinary gradient turns out to be an unnatural direction to follow for variational inference since we are optimizing a distribution, rather than a set of pa- rameters directly. One way to define the gradient is the direction that achieves maximum change subject to a perturbation within a small euclidean ball. To see why the euclidean distance is an unnatural metric for probability distributions, consider the two Gaussians $\\mathcal{N}(0, 0.1)$ and $\\mathcal{N}(0, 0.2)$, compared to $\\mathcal{N}(0, 1000.1)$ and $\\mathcal{N}N(0,1000.2)$.\"" ] }, { @@ -202,9 +203,7 @@ "source": [ "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", "\n", - "params = gpx.transform(params, unconstrainers)\n", - "\n", - "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)" + "params = gpx.transform(params, unconstrainers)" ] }, { @@ -221,7 +220,7 @@ " train_data = Dbatched,\n", " n_iters = 5000,\n", " batch_size=100,\n", - " seed = 42,\n", + " key = jr.PRNGKey(42),\n", " moment_optim = ox.sgd(1.0),\n", " hyper_optim = ox.adam(1e-3),\n", " )\n", @@ -260,6 +259,144 @@ "]\n", "plt.show()" ] + }, + { + "cell_type": "markdown", + "id": "5db1e2e3", + "metadata": {}, + "source": [ + "# Natural gradients and sparse varational Gaussian process regression:" + ] + }, + { + "cell_type": "markdown", + "id": "649d29ec", + "metadata": {}, + "source": [ + "As mentioned in Hensman et al 2013, ....\n", + "\n", + "We demonstrate this now:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0995d1f2", + "metadata": {}, + "outputs": [], + "source": [ + "n = 1000\n", + "noise = 0.2\n", + "\n", + "x = jr.uniform(key=key, minval=-5.0, maxval=5.0, shape=(n,)).sort().reshape(-1, 1)\n", + "f = lambda x: jnp.sin(4 * x) + jnp.cos(2 * x)\n", + "signal = f(x)\n", + "y = signal + jr.normal(key, shape=signal.shape) * noise\n", + "\n", + "D = gpx.Dataset(X=x, y=y)\n", + "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)\n", + "\n", + "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ec554e0", + "metadata": {}, + "outputs": [], + "source": [ + "z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n", + "\n", + "fig, ax = plt.subplots(figsize=(12, 5))\n", + "ax.plot(x, y, \"o\", alpha=0.3)\n", + "ax.plot(xtest, f(xtest))\n", + "[ax.axvline(x=z_i, color=\"black\", alpha=0.3, linewidth=1) for z_i in z]\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eee8115a", + "metadata": {}, + "outputs": [], + "source": [ + "likelihood = gpx.Gaussian(num_datapoints=n)\n", + "kernel = gpx.RBF()\n", + "prior = gpx.Prior(kernel=kernel)\n", + "p = prior * likelihood" + ] + }, + { + "cell_type": "markdown", + "id": "6640c071", + "metadata": {}, + "source": [ + "We begin with natgrads:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "078e03c4", + "metadata": {}, + "outputs": [], + "source": [ + "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", + "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "\n", + "params = gpx.transform(params, unconstrainers)\n", + "\n", + "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)\n", + "\n", + "moment_optim = ox.sgd(1.0)\n", + "\n", + "moment_state = moment_optim.init(params)\n", + "\n", + "# Natural gradients update:\n", + "loss_val, loss_gradient = nat_grads_fn(params, trainables, D)\n", + "print(loss_val)\n", + "\n", + "updates, moment_state = moment_optim.update(loss_gradient, moment_state, params)\n", + "params = ox.apply_updates(params, updates)\n", + "\n", + "loss_val, _ = nat_grads_fn(params, trainables, D)\n", + "\n", + "print(loss_val)" + ] + }, + { + "cell_type": "markdown", + "id": "c7c16824", + "metadata": {}, + "source": [ + "Let us now run it for SGPR:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6482af13", + "metadata": {}, + "outputs": [], + "source": [ + "from gpjax.parameters import build_identity\n", + "\n", + "q = gpx.CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)\n", + "sgpr = gpx.CollapsedVI(posterior=p, variational_family=q)\n", + "\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "\n", + "params = gpx.transform(params, unconstrainers)\n", + "\n", + "loss_fn = sgpr.elbo(D, constrainers, negative=True)\n", + "\n", + "loss_val = loss_fn(params)\n", + "\n", + "print(loss_val)" + ] } ], "metadata": { diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index acc0679b..94e2d600 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,13 +1,15 @@ import typing as tp from copy import deepcopy +import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp -from chex import PRNGKey +import optax as ox from jax import value_and_grad from jaxtyping import f64 +from .abstractions import progress_bar_scan from .config import get_defaults from .gps import AbstractPosterior from .parameters import ( @@ -17,7 +19,7 @@ trainable_params, transform, ) -from .types import Dataset +from .types import Dataset, PRNGKeyType from .utils import I from .variational_families import ( AbstractVariationalFamily, @@ -91,8 +93,9 @@ def _expectation_elbo( inducing_inputs=variational_family.inducing_inputs, ) svgp = StochasticVI(posterior=posterior, variational_family=evg) + identity_transformation = build_identity(svgp.params) - return svgp.elbo(train_data, build_identity(svgp.params), negative=True) + return svgp.elbo(train_data, identity_transformation, negative=True) def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: @@ -234,12 +237,6 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: return nat_grads_fn, hyper_grads_fn -import jax -import optax as ox - -from gpjax.abstractions import progress_bar_scan - - def fit_natgrads( stochastic_vi: StochasticVI, params: tp.Dict, @@ -249,7 +246,7 @@ def fit_natgrads( batch_size: int, moment_optim, hyper_optim, - seed: tp.Union[int, PRNGKey], + key: PRNGKeyType, n_iters: tp.Optional[int] = 100, log_rate: tp.Optional[int] = 10, ) -> tp.Dict: @@ -263,13 +260,11 @@ def fit_natgrads( x, y, n = train_data.X, train_data.y, train_data.n - prng = convert_seed(seed) - @progress_bar_scan(n_iters, log_rate) def step(carry, _): - params, moment_state, hyper_state, prng = carry + params, moment_state, hyper_state, current_key = carry - indicies = jr.choice(prng, n, (batch_size,), replace=True) + indicies = jr.choice(current_key, n, (batch_size,), replace=True) batch = Dataset(X=x[indicies], y=y[indicies]) @@ -283,23 +278,12 @@ def step(carry, _): updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) params = ox.apply_updates(params, updates) - prng, _ = jr.split(prng) + _, new_key = jr.split(current_key) - carry = params, moment_state, hyper_state, prng + carry = params, moment_state, hyper_state, new_key return carry, loss_val (params, _, _, _), history = jax.lax.scan( - step, (params, moment_state, hyper_state, prng), jnp.arange(n_iters) + step, (params, moment_state, hyper_state, key), jnp.arange(n_iters) ) return params, history - - -def convert_seed(seed: tp.Union[int, PRNGKey]) -> PRNGKey: - """Ensure that seeds type.""" - - if isinstance(seed, int): - rng = jr.PRNGKey(seed) - else: # key is of type PRNGKey - rng = seed - - return rng From 48fcfa86fbf363a641730ad397b316217249f42e Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 19 Aug 2022 22:49:05 +0100 Subject: [PATCH 51/66] Update test_natural_gradients.py --- tests/test_natural_gradients.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index bb0499c0..214e4ee0 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -3,7 +3,6 @@ import jax import jax.numpy as jnp import jax.random as jr -import matplotlib.pyplot as plt import optax as ox import pytest import tensorflow as tf From 1e8a8e033b6c9697e484cf820be4580a9d547696 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 23 Aug 2022 12:47:59 +0100 Subject: [PATCH 52/66] Fix variational families. --- gpjax/natural_gradients.py | 28 +++++++++++++--------------- gpjax/parameters.py | 6 ++++-- gpjax/variational_families.py | 18 ++++++++---------- tests/test_natural_gradients.py | 6 ------ tests/test_variational_families.py | 8 ++++---- 5 files changed, 29 insertions(+), 37 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 94e2d600..7e877802 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -9,7 +9,7 @@ from jax import value_and_grad from jaxtyping import f64 -from .abstractions import progress_bar_scan +from .abstractions import InferenceState, get_batch, progress_bar_scan from .config import get_defaults from .gps import AbstractPosterior from .parameters import ( @@ -93,7 +93,7 @@ def _expectation_elbo( inducing_inputs=variational_family.inducing_inputs, ) svgp = StochasticVI(posterior=posterior, variational_family=evg) - identity_transformation = build_identity(svgp.params) + identity_transformation = build_identity(svgp._initialise_params(jr.PRNGKey(123))) return svgp.elbo(train_data, identity_transformation, negative=True) @@ -258,15 +258,15 @@ def fit_natgrads( stochastic_vi, train_data, transformations ) - x, y, n = train_data.X, train_data.y, train_data.n + keys = jax.random.split(key, n_iters) + iter_nums = jnp.arange(n_iters) @progress_bar_scan(n_iters, log_rate) - def step(carry, _): - params, moment_state, hyper_state, current_key = carry + def step(carry, iter_num__and__key): + iter_num, key = iter_num__and__key + params, hyper_state, moment_state = carry - indicies = jr.choice(current_key, n, (batch_size,), replace=True) - - batch = Dataset(X=x[indicies], y=y[indicies]) + batch = get_batch(train_data, batch_size, key) # Hyper-parameters update: loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) @@ -276,14 +276,12 @@ def step(carry, _): # Natural gradients update: loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) - params = ox.apply_updates(params, updates) - - _, new_key = jr.split(current_key) - carry = params, moment_state, hyper_state, new_key + carry = params, hyper_state, moment_state return carry, loss_val - (params, _, _, _), history = jax.lax.scan( - step, (params, moment_state, hyper_state, key), jnp.arange(n_iters) + (params, _, _), history = jax.lax.scan( + step, (params, hyper_state, moment_state), (iter_nums, keys) ) - return params, history + inf_state = InferenceState(params=params, history=history) + return inf_state diff --git a/gpjax/parameters.py b/gpjax/parameters.py index f09e2479..6f1bd28c 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -14,7 +14,7 @@ from .types import PRNGKeyType from .utils import merge_dictionaries -Identity = dx.Lambda(forward = lambda x: x, inverse = lambda x: x) +Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) ################################ @@ -51,6 +51,7 @@ def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: ) return state + def _validate_kwargs(kwargs, params): for k, v in kwargs.items(): if k not in params.keys(): @@ -171,8 +172,9 @@ def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: return jax.tree_util.tree_map(map, params, bijectors) + def build_identity(params: tp.Dict) -> tp.Dict: - """" + """ " Args: params (tp.Dict): The parameter set for which trainable statuses should be derived from. diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 0198b2b6..e70fe595 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -12,7 +12,7 @@ from .gps import Prior from .kernels import cross_covariance, gram from .likelihoods import AbstractLikelihood, Gaussian -from .types import Dataset +from .types import Dataset, PRNGKeyType from .utils import I, concat_dictionaries DEFAULT_JITTER = get_defaults()["jitter"] @@ -34,7 +34,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.predict(*args, **kwargs) @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """The parameters of the distribution. For example, the multivariate Gaussian would return a mean vector and covariance matrix.""" raise NotImplementedError @@ -85,7 +85,7 @@ def __post_init__(self): else: add_parameter("variational_root_covariance", FillTriangular) - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( self.prior._initialise_params(key), @@ -290,11 +290,10 @@ def __post_init__(self): self.natural_matrix = -0.5 * I(m) add_parameter("natural_matrix", Identity) - @property - def params(self) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" return concat_dictionaries( - self.prior.params, + self.prior._initialise_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, @@ -431,11 +430,10 @@ def __post_init__(self): self.expectation_matrix = I(m) add_parameter("expectation_matrix", Identity) - @property - def params(self) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" return concat_dictionaries( - self.prior.params, + self.prior._initialise_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, @@ -565,7 +563,7 @@ def __post_init__(self): if not isinstance(self.likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( self.prior._initialise_params(key), diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index 214e4ee0..f16a14ba 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -3,22 +3,16 @@ import jax import jax.numpy as jnp import jax.random as jr -import optax as ox import pytest -import tensorflow as tf -from jax import jit import gpjax as gpx from gpjax.natural_gradients import ( _expectation_elbo, - _stop_gradients_moments, - _stop_gradients_nonmoments, natural_gradients, natural_to_expectation, ) from gpjax.parameters import recursive_items -tf.random.set_seed(42) key = jr.PRNGKey(123) diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index d13fb1df..1f0eec6e 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -204,7 +204,7 @@ def test_natural_variational_gaussian(n_inducing, n_test): assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() # params - params = variational_family.params + params = variational_family._initialise_params(jr.PRNGKey(123)) assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() assert "natural_vector" in params["variational_family"]["moments"].keys() @@ -236,7 +236,7 @@ def test_natural_variational_gaussian(n_inducing, n_test): assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() # Test KL - params = variational_family.params + params = variational_family._initialise_params(jr.PRNGKey(123)) kl = variational_family.prior_kl(params) assert isinstance(kl, jnp.ndarray) @@ -288,7 +288,7 @@ def test_expectation_variational_gaussian(n_inducing, n_test): assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() # params - params = variational_family.params + params = variational_family._initialise_params(jr.PRNGKey(123)) assert isinstance(params, dict) assert "inducing_inputs" in params["variational_family"].keys() assert "expectation_vector" in params["variational_family"]["moments"].keys() @@ -320,7 +320,7 @@ def test_expectation_variational_gaussian(n_inducing, n_test): assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() # Test KL - params = variational_family.params + params = variational_family._initialise_params(jr.PRNGKey(123)) kl = variational_family.prior_kl(params) assert isinstance(kl, jnp.ndarray) From c265b9f6c977a9ad7930e3f68ff9c7f96dea7ded Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 23 Aug 2022 13:09:19 +0100 Subject: [PATCH 53/66] Update training loop and notebook. --- examples/natgrads.ipynb | 62 ++++++++++----------------------- gpjax/natural_gradients.py | 1 + tests/test_natural_gradients.py | 20 +++++------ 3 files changed, 29 insertions(+), 54 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 41e898e7..de959de9 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -34,7 +34,6 @@ "import optax as ox\n", "\n", "import gpjax as gpx\n", - "from gpjax.natural_gradients import natural_gradients\n", "from gpjax.abstractions import progress_bar_scan\n", "\n", "#Set seed for reproducibility:\n", @@ -75,8 +74,6 @@ "y = signal + jr.normal(key, shape=signal.shape) * noise\n", "\n", "D = gpx.Dataset(X=x, y=y)\n", - "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)\n", - "\n", "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" ] }, @@ -104,26 +101,6 @@ "plt.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "13de5cd9", - "metadata": {}, - "outputs": [], - "source": [ - "likelihood = gpx.Gaussian(num_datapoints=n)\n", - "kernel = gpx.RBF()\n", - "prior = gpx.Prior(kernel=kernel)\n", - "p = prior * likelihood\n", - "\n", - "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", - "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", - "\n", - "params = gpx.transform(params, unconstrainers)\n", - "\n" - ] - }, { "cell_type": "markdown", "id": "664c204b", @@ -164,7 +141,7 @@ "metadata": {}, "outputs": [], "source": [ - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", "params = gpx.transform(params, unconstrainers)\n", "\n", "loss_fn = jit(svgp.elbo(D, constrainers, negative=True))" @@ -194,18 +171,6 @@ "Define natural gradient and hyperparameter gradient functions:" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "bfab0cfc", - "metadata": {}, - "outputs": [], - "source": [ - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", - "\n", - "params = gpx.transform(params, unconstrainers)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -213,17 +178,19 @@ "metadata": {}, "outputs": [], "source": [ - "learned_params, training_history = gpx.natural_gradients.fit_natgrads(svgp,\n", + "from gpjax.natural_gradients import fit_natgrads\n", + "\n", + "learned_params, training_history = fit_natgrads(svgp,\n", " params = params,\n", " trainables = trainables, \n", " transformations = constrainers,\n", - " train_data = Dbatched,\n", - " n_iters = 5000,\n", + " train_data = D,\n", + " n_iters = 10000,\n", " batch_size=100,\n", " key = jr.PRNGKey(42),\n", " moment_optim = ox.sgd(1.0),\n", " hyper_optim = ox.adam(1e-3),\n", - " )\n", + " ).unpack()\n", "\n", "learned_params = gpx.transform(learned_params, constrainers)" ] @@ -294,7 +261,6 @@ "y = signal + jr.normal(key, shape=signal.shape) * noise\n", "\n", "D = gpx.Dataset(X=x, y=y)\n", - "Dbatched = D.cache().repeat().shuffle(D.n).batch(batch_size=256).prefetch(buffer_size=1)\n", "\n", "xtest = jnp.linspace(-5.5, 5.5, 500).reshape(-1, 1)" ] @@ -343,9 +309,11 @@ "metadata": {}, "outputs": [], "source": [ + "from gpjax.natural_gradients import natural_gradients\n", + "\n", "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", "\n", "params = gpx.transform(params, unconstrainers)\n", "\n", @@ -387,7 +355,7 @@ "q = gpx.CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)\n", "sgpr = gpx.CollapsedVI(posterior=p, variational_family=q)\n", "\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp)\n", + "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", "\n", "params = gpx.transform(params, unconstrainers)\n", "\n", @@ -397,6 +365,14 @@ "\n", "print(loss_val)" ] + }, + { + "cell_type": "markdown", + "id": "bdae1c03", + "metadata": {}, + "source": [ + "The discrepancy is due to the quadrature approximation." + ] } ], "metadata": { diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 7e877802..09ea5eed 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -276,6 +276,7 @@ def step(carry, iter_num__and__key): # Natural gradients update: loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) + params = ox.apply_updates(params, updates) carry = params, hyper_state, moment_state return carry, loss_val diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index f16a14ba..4a36119c 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -6,6 +6,7 @@ import pytest import gpjax as gpx +from gpjax.abstractions import get_batch from gpjax.natural_gradients import ( _expectation_elbo, natural_gradients, @@ -74,7 +75,9 @@ def test_expectation_elbo(jit_fns): svgp = gpx.StochasticVI(posterior=posterior, variational_family=variational_family) - params, _, constrainer, unconstrainer = gpx.initialise(svgp) + params, _, constrainer, unconstrainer = gpx.initialise( + svgp, jr.PRNGKey(123) + ).unpack() expectation_elbo = _expectation_elbo(posterior, variational_family, D) @@ -103,27 +106,22 @@ def test_expectation_elbo(jit_fns): def test_natural_gradients(): """ - Tests the expectation ELBO. + Tests the natural gradient and hyperparameter gradients. """ D, p, prior = get_data_and_gp(10) z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) - - Dbatched = ( - D.cache().repeat().shuffle(D.n).batch(batch_size=128).prefetch(buffer_size=1) - ) - - likelihood = gpx.Gaussian(num_datapoints=D.n) prior = gpx.Prior(kernel=gpx.RBF()) q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) svgp = gpx.StochasticVI(posterior=p, variational_family=q) - params, trainables, constrainers, unconstrainers = gpx.initialise(svgp) + params, trainables, constrainers, unconstrainers = gpx.initialise( + svgp, jr.PRNGKey(123) + ).unpack() params = gpx.transform(params, unconstrainers) - batcher = Dbatched.get_batcher() - batch = batcher() + batch = get_batch(D, batch_size=10, key=jr.PRNGKey(42)) nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers) From 5a9f2322b549043ba5b0362a3aacd075be4f6cc9 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 23 Aug 2022 17:19:25 +0100 Subject: [PATCH 54/66] Clean variational families. This commit updates variational families and their tests. --- examples/natgrads.ipynb | 3 - gpjax/config.py | 34 ++- gpjax/natural_gradients.py | 13 +- gpjax/variational_families.py | 87 ++------ tests/test_variational_families.py | 327 ++++++++-------------------- tests/test_variational_inference.py | 35 +-- 6 files changed, 156 insertions(+), 343 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index de959de9..19e6158c 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -36,9 +36,6 @@ "import gpjax as gpx\n", "from gpjax.abstractions import progress_bar_scan\n", "\n", - "#Set seed for reproducibility:\n", - "import tensorflow as tf\n", - "tf.random.set_seed(4)\n", "key = jr.PRNGKey(123)" ] }, diff --git a/gpjax/config.py b/gpjax/config.py index c330a3da..5f6dcdac 100644 --- a/gpjax/config.py +++ b/gpjax/config.py @@ -1,6 +1,7 @@ import distrax as dx import jax.numpy as jnp import jax.random as jr +import tensorflow_probability.substrates.jax.bijectors as tfb from ml_collections import ConfigDict __config = None @@ -10,27 +11,10 @@ inverse=lambda x: jnp.log(jnp.exp(x) - 1.0), ) -# class Softplus(dx.Bijector): -# def __init__(self): -# super().__init__(event_ndims_in=0) +# TODO: Remove this once 'FillTriangular' is added to Distrax. +FillTriangular = dx.Chain([tfb.FillTriangular()]) -# def forward_and_log_det(self, x): -# softplus = lambda xx: jnp.log(1 + jnp.exp(xx)) -# y = softplus(x) -# logdet = softplus(-x) -# return y, logdet - -# def inverse_and_log_det(self, y): -# """ -# Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] -# ==> dX/dY = exp{Y} / (exp{Y} - 1) -# = 1 / (1 - exp{-Y}) -# """ -# x = jnp.log(jnp.exp(y) - 1.0) -# logdet = 1 / (1 - jnp.exp(-y)) -# return x, logdet - -Identity = dx.Lambda(forward = lambda x: x, inverse = lambda x: x) +Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) def get_defaults() -> ConfigDict: @@ -41,6 +25,7 @@ def get_defaults() -> ConfigDict: """ config = ConfigDict() config.key = jr.PRNGKey(123) + # Covariance matrix stabilising jitter config.jitter = 1e-6 @@ -48,6 +33,7 @@ def get_defaults() -> ConfigDict: config.transformations = transformations = ConfigDict() transformations.positive_transform = Softplus transformations.identity_transform = Identity + transformations.triangular_transform = FillTriangular # Default parameter transforms transformations.lengthscale = "positive_transform" @@ -58,6 +44,14 @@ def get_defaults() -> ConfigDict: transformations.latent = "identity_transform" transformations.basis_fns = "identity_transform" transformations.offset = "identity_transform" + transformations.inducing_inputs = "identity_transform" + transformations.variational_mean = "identity_transform" + transformations.variational_root_covariance = "triangular_transform" + transformations.natural_vector = "identity_transform" + transformations.natural_matrix = "identity_transform" + transformations.expectation_vector = "identity_transform" + transformations.expectation_matrix = "identity_transform" + global __config if not __config: __config = config diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 09ea5eed..1f68ac0a 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -81,18 +81,20 @@ def _expectation_elbo( train_data: Dataset, ) -> tp.Callable[[dict, Dataset], float]: """ - Construct evidence lower bound (ELBO) for varational Gaussian under the expectation parameterisation. + Construct evidence lower bound (ELBO) for variational Gaussian under the expectation parameterisation. Args: posterior: An instance of AbstractPosterior. variational_family: An instance of AbstractVariationalFamily. Returns: - Callable: A function that computes ELBO. + Callable: A function that computes the ELBO. """ - evg = ExpectationVariationalGaussian( + expectation_vartiational_gaussian = ExpectationVariationalGaussian( prior=variational_family.prior, inducing_inputs=variational_family.inducing_inputs, ) - svgp = StochasticVI(posterior=posterior, variational_family=evg) + svgp = StochasticVI( + posterior=posterior, variational_family=expectation_vartiational_gaussian + ) identity_transformation = build_identity(svgp._initialise_params(jr.PRNGKey(123))) return svgp.elbo(train_data, identity_transformation, negative=True) @@ -132,10 +134,9 @@ def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, transformations: dict, - # bijector = tp.Optional[dx.Bijector] = Identity, #bijector: A bijector to convert between the user chosen parameterisation and the natural parameters. ) -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]: """ - Computes natural gradients for variational Gaussian. + Computes the gradient with respect to the natural parameters. Currently only implemented for the natural variational Gaussian family. Args: posterior: An instance of AbstractPosterior. variational_family: An instance of AbstractVariationalFamily. diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e70fe595..1120241f 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -4,11 +4,10 @@ import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -import tensorflow_probability.substrates.jax.bijectors as tfb from chex import dataclass from jaxtyping import Array, Float -from .config import Identity, Softplus, add_parameter, get_defaults +from .config import get_defaults from .gps import Prior from .kernels import cross_covariance, gram from .likelihoods import AbstractLikelihood, Gaussian @@ -17,13 +16,6 @@ DEFAULT_JITTER = get_defaults()["jitter"] -Diagonal = dx.Lambda( - forward=lambda x: jnp.diagflat(x), inverse=lambda x: jnp.diagonal(x) -) - -FillDiagonal = dx.Chain([Diagonal, Softplus]) -FillTriangular = dx.Chain([tfb.FillTriangular()]) - @dataclass class AbstractVariationalFamily: @@ -53,6 +45,10 @@ class AbstractVariationalGaussian(AbstractVariationalFamily): name: str = "Gaussian" jitter: Optional[float] = DEFAULT_JITTER + def __post_init__(self): + """Initialise the variational Gaussian distribution.""" + self.num_inducing = self.inducing_inputs.shape[0] + @dataclass class VariationalGaussian(AbstractVariationalGaussian): @@ -63,38 +59,18 @@ class VariationalGaussian(AbstractVariationalGaussian): """ - variational_mean: Optional[f64["N D"]] = None - variational_root_covariance: Optional[f64["N D"]] = None - diag: Optional[bool] = False - - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" - self.num_inducing = self.inducing_inputs.shape[0] - add_parameter("inducing_inputs", Identity) - - m = self.num_inducing - - if self.variational_mean is None: - self.variational_mean = jnp.zeros((m, 1)) - add_parameter("variational_mean", Identity) - - if self.variational_root_covariance is None: - self.variational_root_covariance = I(m) - if self.diag: - add_parameter("variational_root_covariance", FillDiagonal) - else: - add_parameter("variational_root_covariance", FillTriangular) - def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" + m = self.num_inducing + return concat_dictionaries( self.prior._initialise_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, "moments": { - "variational_mean": self.variational_mean, - "variational_root_covariance": self.variational_root_covariance, + "variational_mean": jnp.zeros((m, 1)), + "variational_root_covariance": I(m), }, } }, @@ -272,34 +248,20 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions.""" name: str = "Natural Gaussian" - natural_vector: Optional[f64["N D"]] = None - natural_matrix: Optional[f64["N D"]] = None - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" - self.num_inducing = self.inducing_inputs.shape[0] - add_parameter("inducing_inputs", Identity) + def _initialise_params(self, key: PRNGKeyType) -> Dict: + """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" m = self.num_inducing - if self.natural_vector is None: - self.natural_vector = jnp.zeros((m, 1)) - add_parameter("natural_vector", Identity) - - if self.natural_matrix is None: - self.natural_matrix = -0.5 * I(m) - add_parameter("natural_matrix", Identity) - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" return concat_dictionaries( self.prior._initialise_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, "moments": { - "natural_vector": self.natural_vector, - "natural_matrix": self.natural_matrix, + "natural_vector": jnp.zeros((m, 1)), + "natural_matrix": -0.5 * I(m), }, } }, @@ -412,34 +374,22 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions.""" name: str = "Expectation Gaussian" - expectation_vector: Optional[f64["N D"]] = None - expectation_matrix: Optional[f64["N D"]] = None - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" + def _initialise_params(self, key: PRNGKeyType) -> Dict: + """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" + self.num_inducing = self.inducing_inputs.shape[0] - add_parameter("inducing_inputs", Identity) m = self.num_inducing - if self.expectation_vector is None: - self.expectation_vector = jnp.zeros((m, 1)) - add_parameter("expectation_vector", Identity) - - if self.expectation_matrix is None: - self.expectation_matrix = I(m) - add_parameter("expectation_matrix", Identity) - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" return concat_dictionaries( self.prior._initialise_params(key), { "variational_family": { "inducing_inputs": self.inducing_inputs, "moments": { - "expectation_vector": self.expectation_vector, - "expectation_matrix": self.expectation_matrix, + "expectation_vector": jnp.zeros((m, 1)), + "expectation_matrix": I(m), }, } }, @@ -558,7 +508,6 @@ class CollapsedVariationalGaussian(AbstractVariationalFamily): def __post_init__(self): """Initialise the variational Gaussian distribution.""" self.num_inducing = self.inducing_inputs.shape[0] - add_parameter("inducing_inputs", Identity) if not isinstance(self.likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") diff --git a/tests/test_variational_families.py b/tests/test_variational_families.py index 1f0eec6e..e6f86827 100644 --- a/tests/test_variational_families.py +++ b/tests/test_variational_families.py @@ -1,4 +1,5 @@ import typing as tp +from mimetypes import init import distrax as dx import jax.numpy as jnp @@ -6,92 +7,124 @@ import pytest import gpjax as gpx +from gpjax.variational_families import ( + AbstractVariationalFamily, + CollapsedVariationalGaussian, + ExpectationVariationalGaussian, + NaturalVariationalGaussian, + VariationalGaussian, + WhitenedVariationalGaussian, +) def test_abstract_variational_family(): with pytest.raises(TypeError): - gpx.variational_families.AbstractVariationalFamily() + AbstractVariationalFamily() -@pytest.mark.parametrize("diag", [True, False]) -@pytest.mark.parametrize("n_test", [1, 10]) -@pytest.mark.parametrize("whiten", [True, False]) -@pytest.mark.parametrize("n_inducing", [1, 10, 20]) -def test_variational_gaussian(diag, n_inducing, n_test, whiten): - prior = gpx.Prior(kernel=gpx.RBF()) +def vector_shape(n_inducing): + """Shape of a vector with n_inducing rows and 1 column""" + return (n_inducing, 1) - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) - test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - if whiten is True: - variational_family = gpx.WhitenedVariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs, diag=diag - ) - else: - variational_family = gpx.VariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs, diag=diag - ) +def matrix_shape(n_inducing): + """Shape of a matrix with n_inducing rows and 1 column""" + return (n_inducing, n_inducing) - # Test init - assert variational_family.num_inducing == n_inducing - assert jnp.sum(variational_family.variational_mean) == 0.0 - assert variational_family.variational_mean.shape == (n_inducing, 1) +def vector_val(val): + """Vector of shape (n_inducing, 1) filled with val""" - assert variational_family.variational_root_covariance.shape == ( - n_inducing, - n_inducing, - ) - assert jnp.all(jnp.diag(variational_family.variational_root_covariance) == 1.0) + def vector_val_fn(n_inducing): + return val * jnp.ones(vector_shape(n_inducing)) - params = gpx.config.get_defaults() - assert "variational_root_covariance" in params["transformations"].keys() - assert "variational_mean" in params["transformations"].keys() + return vector_val_fn - assert (variational_family.variational_root_covariance == jnp.eye(n_inducing)).all() - assert (variational_family.variational_mean == jnp.zeros((n_inducing, 1))).all() - # Test params - params = variational_family._initialise_params(jr.PRNGKey(123)) +def diag_matrix_val(val): + """Diagonal matrix of shape (n_inducing, n_inducing) filled with val""" + + def diag_matrix_fn(n_inducing): + return jnp.eye(n_inducing) * val + + return diag_matrix_fn + + +@pytest.mark.parametrize("n_test", [1, 10]) +@pytest.mark.parametrize("n_inducing", [1, 10, 20]) +@pytest.mark.parametrize( + "variational_family, moment_names, shapes, values", + [ + ( + VariationalGaussian, + ["variational_mean", "variational_root_covariance"], + [vector_shape, matrix_shape], + [vector_val(0.0), diag_matrix_val(1.0)], + ), + ( + WhitenedVariationalGaussian, + ["variational_mean", "variational_root_covariance"], + [vector_shape, matrix_shape], + [vector_val(0.0), diag_matrix_val(1.0)], + ), + ( + NaturalVariationalGaussian, + ["natural_vector", "natural_matrix"], + [vector_shape, matrix_shape], + [vector_val(0.0), diag_matrix_val(-0.5)], + ), + ( + ExpectationVariationalGaussian, + ["expectation_vector", "expectation_matrix"], + [vector_shape, matrix_shape], + [vector_val(0.0), diag_matrix_val(1.0)], + ), + ], +) +def test_variational_gaussians( + n_test, n_inducing, variational_family, moment_names, shapes, values +): + + # Initialise variational family: + prior = gpx.Prior(kernel=gpx.RBF()) + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) + test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) + q = variational_family(prior=prior, inducing_inputs=inducing_inputs) + + # Test init: + assert q.num_inducing == n_inducing + assert isinstance(q, AbstractVariationalFamily) + + # Test params and keys: + params = q._initialise_params(jr.PRNGKey(123)) assert isinstance(params, dict) + + config_params = gpx.config.get_defaults() + + # Test inducing induput parameters: assert "inducing_inputs" in params["variational_family"].keys() - assert "variational_mean" in params["variational_family"]["moments"].keys() - assert ( - "variational_root_covariance" in params["variational_family"]["moments"].keys() - ) + assert "inducing_inputs" in config_params["transformations"].keys() - assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["variational_mean"].shape == ( - n_inducing, - 1, - ) - assert params["variational_family"]["moments"][ - "variational_root_covariance" - ].shape == (n_inducing, n_inducing) + for moment_name, shape, value in zip(moment_names, shapes, values): - assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance( - params["variational_family"]["moments"]["variational_mean"], jnp.DeviceArray - ) - assert isinstance( - params["variational_family"]["moments"]["variational_root_covariance"], - jnp.DeviceArray, - ) + moment_params = params["variational_family"]["moments"] - params = gpx.config.get_defaults() - assert "variational_root_covariance" in params["transformations"].keys() - assert "variational_mean" in params["transformations"].keys() + assert moment_name in moment_params.keys() + assert moment_name in config_params["transformations"].keys() - assert (variational_family.variational_root_covariance == jnp.eye(n_inducing)).all() - assert (variational_family.variational_mean == jnp.zeros((n_inducing, 1))).all() + # Test moment shape and values: + moment = moment_params[moment_name] + assert isinstance(moment, jnp.ndarray) + assert moment.shape == shape(n_inducing) + assert (moment == value(n_inducing)).all() # Test KL - params = variational_family._initialise_params(jr.PRNGKey(123)) - kl = variational_family.prior_kl(params) + params = q._initialise_params(jr.PRNGKey(123)) + kl = q.prior_kl(params) assert isinstance(kl, jnp.ndarray) # Test predictions - predictive_dist_fn = variational_family(params) + predictive_dist_fn = q(params) assert isinstance(predictive_dist_fn, tp.Callable) predictive_dist = predictive_dist_fn(test_inputs) @@ -123,7 +156,7 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) test_inputs = jnp.hstack([test_inputs] * point_dim) - variational_family = gpx.variational_families.CollapsedVariationalGaussian( + variational_family = CollapsedVariationalGaussian( prior=prior, likelihood=gpx.Gaussian(num_datapoints=D.n), inducing_inputs=inducing_inputs, @@ -131,7 +164,7 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ # We should raise an error for non-Gaussian likelihoods: with pytest.raises(TypeError): - gpx.variational_families.CollapsedVariationalGaussian( + CollapsedVariationalGaussian( prior=prior, likelihood=gpx.Bernoulli(num_datapoints=D.n), inducing_inputs=inducing_inputs, @@ -170,171 +203,3 @@ def test_collapsed_variational_gaussian(n_test, n_inducing, n_datapoints, point_ assert isinstance(sigma, jnp.ndarray) assert mu.shape == (n_test,) assert sigma.shape == (n_test, n_test) - - -@pytest.mark.parametrize("n_test", [1, 10]) -@pytest.mark.parametrize("n_inducing", [1, 10, 20]) -def test_natural_variational_gaussian(n_inducing, n_test): - prior = gpx.Prior(kernel=gpx.RBF()) - - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) - test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - - variational_family = gpx.variational_families.NaturalVariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs - ) - - # Test init - assert variational_family.num_inducing == n_inducing - - assert jnp.sum(variational_family.natural_vector) == 0.0 - assert variational_family.natural_vector.shape == (n_inducing, 1) - - assert variational_family.natural_matrix.shape == ( - n_inducing, - n_inducing, - ) - assert jnp.all(jnp.diag(variational_family.natural_matrix) == -0.5) - - params = gpx.config.get_defaults() - assert "variational_root_covariance" in params["transformations"].keys() - assert "variational_mean" in params["transformations"].keys() - - assert (variational_family.natural_matrix == -0.5 * jnp.eye(n_inducing)).all() - assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() - - # params - params = variational_family._initialise_params(jr.PRNGKey(123)) - assert isinstance(params, dict) - assert "inducing_inputs" in params["variational_family"].keys() - assert "natural_vector" in params["variational_family"]["moments"].keys() - assert "natural_matrix" in params["variational_family"]["moments"].keys() - - assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["natural_vector"].shape == ( - n_inducing, - 1, - ) - assert params["variational_family"]["moments"]["natural_matrix"].shape == ( - n_inducing, - n_inducing, - ) - - assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance( - params["variational_family"]["moments"]["natural_vector"], jnp.DeviceArray - ) - assert isinstance( - params["variational_family"]["moments"]["natural_matrix"], jnp.DeviceArray - ) - - params = gpx.config.get_defaults() - assert "natural_vector" in params["transformations"].keys() - assert "natural_matrix" in params["transformations"].keys() - - assert (variational_family.natural_matrix == -0.5 * jnp.eye(n_inducing)).all() - assert (variational_family.natural_vector == jnp.zeros((n_inducing, 1))).all() - - # Test KL - params = variational_family._initialise_params(jr.PRNGKey(123)) - kl = variational_family.prior_kl(params) - assert isinstance(kl, jnp.ndarray) - - # Test predictions - predictive_dist_fn = variational_family(params) - assert isinstance(predictive_dist_fn, tp.Callable) - - predictive_dist = predictive_dist_fn(test_inputs) - assert isinstance(predictive_dist, dx.Distribution) - - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() - - assert isinstance(mu, jnp.ndarray) - assert isinstance(sigma, jnp.ndarray) - assert mu.shape == (n_test,) - assert sigma.shape == (n_test, n_test) - - -@pytest.mark.parametrize("n_test", [1, 10]) -@pytest.mark.parametrize("n_inducing", [1, 10, 20]) -def test_expectation_variational_gaussian(n_inducing, n_test): - prior = gpx.Prior(kernel=gpx.RBF()) - - inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing).reshape(-1, 1) - test_inputs = jnp.linspace(-5.0, 5.0, n_test).reshape(-1, 1) - - variational_family = gpx.variational_families.ExpectationVariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs - ) - - # Test init - assert variational_family.num_inducing == n_inducing - - assert jnp.sum(variational_family.expectation_vector) == 0.0 - assert variational_family.expectation_vector.shape == (n_inducing, 1) - - assert variational_family.expectation_matrix.shape == ( - n_inducing, - n_inducing, - ) - assert jnp.all(jnp.diag(variational_family.expectation_matrix) == 1.0) - - params = gpx.config.get_defaults() - assert "variational_root_covariance" in params["transformations"].keys() - assert "variational_mean" in params["transformations"].keys() - - assert (variational_family.expectation_matrix == jnp.eye(n_inducing)).all() - assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() - - # params - params = variational_family._initialise_params(jr.PRNGKey(123)) - assert isinstance(params, dict) - assert "inducing_inputs" in params["variational_family"].keys() - assert "expectation_vector" in params["variational_family"]["moments"].keys() - assert "expectation_matrix" in params["variational_family"]["moments"].keys() - - assert params["variational_family"]["inducing_inputs"].shape == (n_inducing, 1) - assert params["variational_family"]["moments"]["expectation_vector"].shape == ( - n_inducing, - 1, - ) - assert params["variational_family"]["moments"]["expectation_matrix"].shape == ( - n_inducing, - n_inducing, - ) - - assert isinstance(params["variational_family"]["inducing_inputs"], jnp.DeviceArray) - assert isinstance( - params["variational_family"]["moments"]["expectation_vector"], jnp.DeviceArray - ) - assert isinstance( - params["variational_family"]["moments"]["expectation_matrix"], jnp.DeviceArray - ) - - params = gpx.config.get_defaults() - assert "expectation_vector" in params["transformations"].keys() - assert "expectation_matrix" in params["transformations"].keys() - - assert (variational_family.expectation_matrix == jnp.eye(n_inducing)).all() - assert (variational_family.expectation_vector == jnp.zeros((n_inducing, 1))).all() - - # Test KL - params = variational_family._initialise_params(jr.PRNGKey(123)) - kl = variational_family.prior_kl(params) - assert isinstance(kl, jnp.ndarray) - - # Test predictions - predictive_dist_fn = variational_family(params) - assert isinstance(predictive_dist_fn, tp.Callable) - - predictive_dist = predictive_dist_fn(test_inputs) - assert isinstance(predictive_dist, dx.Distribution) - - mu = predictive_dist.mean() - sigma = predictive_dist.covariance() - - assert isinstance(mu, jnp.ndarray) - assert isinstance(sigma, jnp.ndarray) - assert mu.shape == (n_test,) - assert sigma.shape == (n_test, n_test) diff --git a/tests/test_variational_inference.py b/tests/test_variational_inference.py index 732e1851..c55ec4f5 100644 --- a/tests/test_variational_inference.py +++ b/tests/test_variational_inference.py @@ -6,6 +6,14 @@ import pytest import gpjax as gpx +from gpjax import variational_inference +from gpjax.variational_families import ( + CollapsedVariationalGaussian, + ExpectationVariationalGaussian, + NaturalVariationalGaussian, + VariationalGaussian, + WhitenedVariationalGaussian, +) def test_abstract_variational_inference(): @@ -37,26 +45,25 @@ def get_data_and_gp(n_datapoints, point_dim): @pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) -@pytest.mark.parametrize("n_test", [1, 10]) -@pytest.mark.parametrize("whiten", [True, False]) -@pytest.mark.parametrize("diag", [True, False]) @pytest.mark.parametrize("jit_fns", [False, True]) -@pytest.mark.parametrize("point_dim", [1, 2]) +@pytest.mark.parametrize("point_dim", [1, 2, 3]) +@pytest.mark.parametrize( + "variational_family", + [ + VariationalGaussian, + WhitenedVariationalGaussian, + NaturalVariationalGaussian, + ExpectationVariationalGaussian, + ], +) def test_stochastic_vi( - n_datapoints, n_inducing_points, n_test, whiten, diag, jit_fns, point_dim + n_datapoints, n_inducing_points, jit_fns, point_dim, variational_family ): D, post, prior = get_data_and_gp(n_datapoints, point_dim) inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - if whiten is True: - q = gpx.WhitenedVariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs, diag=diag - ) - else: - q = gpx.VariationalGaussian( - prior=prior, inducing_inputs=inducing_inputs, diag=diag - ) + q = variational_family(prior=prior, inducing_inputs=inducing_inputs) svgp = gpx.StochasticVI(posterior=post, variational_family=q) assert svgp.posterior.prior == post.prior @@ -92,7 +99,7 @@ def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) - q = gpx.variational_families.CollapsedVariationalGaussian( + q = CollapsedVariationalGaussian( prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs ) From c01f733b535d2a423cb219e3f01540d9bd72730f Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 23 Aug 2022 17:57:57 +0100 Subject: [PATCH 55/66] Update typing. --- gpjax/abstractions.py | 14 +++---- gpjax/gps.py | 69 +++++++++++++++++----------------- gpjax/kernels.py | 38 ++++++++++--------- gpjax/likelihoods.py | 13 ++++--- gpjax/mean_functions.py | 12 +++--- gpjax/natural_gradients.py | 48 ++++++++++++------------ gpjax/parameters.py | 70 +++++++++++++++++------------------ gpjax/types.py | 2 +- gpjax/utils.py | 18 ++++----- gpjax/variational_families.py | 14 +++---- 10 files changed, 151 insertions(+), 147 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 0a9522c5..0871ef73 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -1,4 +1,4 @@ -import typing as tp +from typing import Callable, Dict, Optional import jax import jax.numpy as jnp @@ -97,7 +97,7 @@ def wrapper_progress_bar(carry, x): def fit( - objective: tp.Callable, + objective: Callable, parameter_state: ParameterState, optax_optim, n_iters: int = 100, @@ -106,7 +106,7 @@ def fit( """Abstracted method for fitting a GP model with respect to a supplied objective function. Optimisers used here should originate from Optax. Args: - objective (tp.Callable): The objective function that we are optimising with respect to. + objective (Callable): The objective function that we are optimising with respect to. parameter_state (ParameterState): The initial parameter state. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. @@ -149,19 +149,19 @@ def step(carry, iter_num): def fit_batches( - objective: tp.Callable, + objective: Callable, parameter_state: ParameterState, train_data: Dataset, optax_optim, key: PRNGKeyType, batch_size: int, - n_iters: tp.Optional[int] = 100, - log_rate: tp.Optional[int] = 10, + n_iters: Optional[int] = 100, + log_rate: Optional[int] = 10, ) -> InferenceState: """Abstracted method for fitting a GP model with mini-batches respect to a supplied objective function. Optimisers used here should originate from Optax. Args: - objective (tp.Callable): The objective function that we are optimising with respect to. + objective (Callable): The objective function that we are optimising with respect to. parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with. train_data (Dataset): The training dataset. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. diff --git a/gpjax/gps.py b/gpjax/gps.py index b52cd244..a8a748ae 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -1,6 +1,5 @@ -import typing as tp -from abc import abstractmethod, abstractproperty -from typing import Dict +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional import distrax as dx import jax.numpy as jnp @@ -20,8 +19,8 @@ NonConjugateLikelihoodType, ) from .mean_functions import AbstractMeanFunction, Zero -from .parameters import copy_dict_structure, evaluate_priors -from .types import Dataset +from .parameters import copy_dict_structure, evaluate_priors, transform +from .types import Dataset, PRNGKeyType from .utils import I, concat_dictionaries DEFAULT_JITTER = get_defaults()["jitter"] @@ -31,7 +30,7 @@ class AbstractGP: """Abstract Gaussian process object.""" - def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the Gaussian process at the given points. Returns: @@ -40,11 +39,11 @@ def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: return self.predict(*args, **kwargs) @abstractmethod - def predict(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Compute the latent function's multivariate normal distribution.""" raise NotImplementedError - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the GP's parameter set""" raise NotImplementedError @@ -57,9 +56,9 @@ class Prior(AbstractGP): """A Gaussian process prior object. The GP is parameterised by a mean and kernel function.""" kernel: Kernel - mean_function: tp.Optional[AbstractMeanFunction] = Zero() - name: tp.Optional[str] = "GP prior" - jitter: tp.Optional[float] = DEFAULT_JITTER + mean_function: Optional[AbstractMeanFunction] = Zero() + name: Optional[str] = "GP prior" + jitter: Optional[float] = DEFAULT_JITTER def __mul__(self, other: AbstractLikelihood): """The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned. @@ -79,9 +78,9 @@ def predict( ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the GP's prior mean and variance. Args: - params (dict): The specific set of parameters for which the mean function should be defined for. + params (Dict): The specific set of parameters for which the mean function should be defined for. Returns: - tp.Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. + Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. """ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: @@ -97,7 +96,7 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: return predict_fn - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the GP prior's parameter set""" return { "kernel": self.kernel._initialise_params(key), @@ -114,15 +113,15 @@ class AbstractPosterior(AbstractGP): prior: Prior likelihood: AbstractLikelihood - name: tp.Optional[str] = "GP posterior" - jitter: tp.Optional[float] = DEFAULT_JITTER + name: Optional[str] = "GP posterior" + jitter: Optional[float] = DEFAULT_JITTER @abstractmethod - def predict(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Predict the GP's output given the input.""" raise NotImplementedError - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a GP posterior.""" return concat_dictionaries( self.prior._initialise_params(key), @@ -136,8 +135,8 @@ class ConjugatePosterior(AbstractPosterior): prior: Prior likelihood: Gaussian - name: tp.Optional[str] = "Conjugate posterior" - jitter: tp.Optional[float] = DEFAULT_JITTER + name: Optional[str] = "Conjugate posterior" + jitter: Optional[float] = DEFAULT_JITTER def predict( self, train_data: Dataset, params: dict @@ -146,10 +145,10 @@ def predict( Args: train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - params (dict): A dictionary of parameters that should be used to compute the posterior. + params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -194,23 +193,23 @@ def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: def marginal_log_likelihood( self, train_data: Dataset, - priors: dict = None, + priors: Dict = None, negative: bool = False, ) -> tp.Callable[[dict], Float[Array, "1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values. Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. + priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. Returns: - tp.Callable[[dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. + Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ x, y, n = train_data.X, train_data.y, train_data.n def mll( - params: dict, + params: Dict, ): # Observation noise σ² obs_noise = params["likelihood"]["obs_noise"] @@ -245,10 +244,10 @@ class NonConjugatePosterior(AbstractPosterior): prior: Prior likelihood: NonConjugateLikelihoodType - name: tp.Optional[str] = "Non-conjugate posterior" - jitter: tp.Optional[float] = DEFAULT_JITTER + name: Optional[str] = "Non-conjugate posterior" + jitter: Optional[float] = DEFAULT_JITTER - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a non-conjugate GP posterior.""" parameters = concat_dictionaries( self.prior._initialise_params(key), @@ -264,10 +263,10 @@ def predict( Args: train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - params (dict): A dictionary of parameters that should be used to compute the posterior. + params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. """ x, n = train_data.X, train_data.n @@ -301,18 +300,18 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: def marginal_log_likelihood( self, train_data: Dataset, - priors: dict = None, + priors: Dict = None, negative: bool = False, ) -> tp.Callable[[dict], Float[Array, "1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax. Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. + priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. Returns: - tp.Callable[[dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. + Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -320,7 +319,7 @@ def marginal_log_likelihood( priors = copy_dict_structure(self._initialise_params(jr.PRNGKey(0))) priors["latent"] = dx.Normal(loc=0.0, scale=1.0) - def mll(params: dict): + def mll(params: Dict): Kxx = gram(self.prior.kernel, x, params["kernel"]) Kxx += I(n) * self.jitter Lx = jnp.linalg.cholesky(Kxx) diff --git a/gpjax/kernels.py b/gpjax/kernels.py index d88ab25e..5636da84 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -6,6 +6,8 @@ from jax import vmap from jaxtyping import Array, Float +from .types import PRNGKeyType + ########################################## # Abtract classes @@ -30,7 +32,7 @@ def __call__( Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)`. """ @@ -61,7 +63,7 @@ def ard(self): return True if self.ndims > 1 else False @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """A template dictionary of the kernel's parameter set.""" raise NotImplementedError @@ -99,7 +101,7 @@ def _set_kernels(self, kernels: Sequence[Kernel]) -> None: kernels_list.append(k) self.kernel_set = kernels_list - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """A template dictionary of the kernel's parameter set.""" return [kernel._initialise_params(key) for kernel in self.kernel_set] @@ -150,7 +152,7 @@ def __call__( Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -160,7 +162,7 @@ def __call__( K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -187,7 +189,7 @@ def __call__( Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` """ @@ -196,7 +198,7 @@ def __call__( K = params["variance"] * jnp.exp(-0.5 * euclidean_distance(x, y)) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -223,7 +225,7 @@ def __call__( Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -238,7 +240,7 @@ def __call__( ) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -265,7 +267,7 @@ def __call__( Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -280,7 +282,7 @@ def __call__( ) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -309,7 +311,7 @@ def __call__( Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -319,7 +321,7 @@ def __call__( K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "shift": jnp.array([1.0]), "variance": jnp.array([1.0] * self.ndims), @@ -352,7 +354,7 @@ def __call__( Args: x (jnp.DeviceArray): Index of the ith vertex y (jnp.DeviceArray): Index of the jth vertex - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of k(v_i, v_j). @@ -369,7 +371,7 @@ def __call__( ) return kxy.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -399,7 +401,7 @@ def gram( Args: kernel (Kernel): The kernel for which the Gram matrix should be computed for. inputs (Array): The input matrix. - params (dict): The kernel's parameter set. + params (Dict): The kernel's parameter set. Returns: Array: The computed square Gram matrix. @@ -416,7 +418,7 @@ def cross_covariance( kernel (Kernel): The kernel for which the cross-covariance matrix should be computed for. x (Array): The first input matrix. y (Array): The second input matrix. - params (dict): The kernel's parameter set. + params (Dict): The kernel's parameter set. Returns: Array: The computed square Gram matrix. @@ -431,7 +433,7 @@ def diagonal( Args: kernel (Kernel): The kernel for which the variance vector should be computed for. inputs (Array): The input matrix. - params (dict): The kernel's parameter set. + params (Dict): The kernel's parameter set. Returns: Array: The computed diagonal variance matrix. """ diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 2793126a..71ae16d0 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -7,6 +7,7 @@ from chex import dataclass from jaxtyping import Array, Float +from .types import PRNGKeyType from .utils import I @@ -27,7 +28,7 @@ def predict(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the parameters of the likelihood function.""" raise NotImplementedError @@ -54,7 +55,7 @@ class Gaussian(AbstractLikelihood, Conjugate): name: Optional[str] = "Gaussian" - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the variance parameter of the likelihood function.""" return {"obs_noise": jnp.array([1.0])} @@ -66,12 +67,12 @@ def link_function(self) -> Callable: Callable: A link function that maps the predictive distribution to the likelihood function. """ - def link_fn(x, params: dict) -> dx.Distribution: + def link_fn(x, params: Dict) -> dx.Distribution: return dx.Normal(loc=x, scale=params["obs_noise"]) return link_fn - def predict(self, dist: dx.Distribution, params: dict) -> dx.Distribution: + def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: """Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the observation noise term to the diagonal elements of the predictive distribution's covariance matrix..""" n_data = dist.event_shape[0] noisy_cov = dist.covariance() + I(n_data) * params["likelihood"]["obs_noise"] @@ -82,7 +83,7 @@ def predict(self, dist: dx.Distribution, params: dict) -> dx.Distribution: class Bernoulli(AbstractLikelihood, NonConjugate): name: Optional[str] = "Bernoulli" - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a Bernoulli likelihood.""" return {} @@ -115,7 +116,7 @@ def moment_fn( return moment_fn - def predict(self, dist: dx.Distribution, params: dict) -> Any: + def predict(self, dist: dx.Distribution, params: Dict) -> Any: variance = jnp.diag(dist.covariance()) mean = dist.mean() return self.predictive_moment_fn(mean.ravel(), variance, params) diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 4a9df40f..960bface 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -5,6 +5,8 @@ from chex import dataclass from jaxtyping import Array, Float +from .types import PRNGKeyType + @dataclass(repr=False) class AbstractMeanFunction: @@ -26,11 +28,11 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the parameters of the mean function. This method is required for all subclasses. Returns: - dict: The parameters of the mean function. + Dict: The parameters of the mean function. """ raise NotImplementedError @@ -49,7 +51,7 @@ def __call__(self, x: Float[Array, "N D"], params: dict) -> Float[Array, "N Q"]: Args: x (Array): The input points at which to evaluate the mean function. - params (dict): The parameters of the mean function. + params (Dict): The parameters of the mean function. Returns: Array: A vector of zeros. @@ -57,7 +59,7 @@ def __call__(self, x: Float[Array, "N D"], params: dict) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.zeros(shape=out_shape) - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """The parameters of the mean function. For the zero-mean function, this is an empty dictionary.""" return {} @@ -85,6 +87,6 @@ def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.ones(shape=out_shape) * params["constant"] - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value.""" return {"constant": jnp.array([1.0])} diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 1f68ac0a..f67dc52b 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,5 +1,5 @@ -import typing as tp from copy import deepcopy +from typing import Callable, Dict, Optional, Tuple import jax import jax.numpy as jnp @@ -32,15 +32,15 @@ def natural_to_expectation( - natural_moments: dict, jitter: float = DEFAULT_JITTER -) -> dict: + natural_moments: Dict, jitter: float = DEFAULT_JITTER +) -> Dict: """ Converts natural parameters to expectation parameters. Args: natural_moments: A dictionary of natural parameters. jitter (float): A small value to prevent numerical instability. Returns: - tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. + Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ natural_matrix = natural_moments["natural_matrix"] @@ -79,7 +79,7 @@ def _expectation_elbo( posterior: AbstractPosterior, variational_family: AbstractVariationalFamily, train_data: Dataset, -) -> tp.Callable[[dict, Dataset], float]: +) -> Callable[[Dict, Dataset], float]: """ Construct evidence lower bound (ELBO) for variational Gaussian under the expectation parameterisation. Args: @@ -100,13 +100,13 @@ def _expectation_elbo( return svgp.elbo(train_data, identity_transformation, negative=True) -def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: +def _stop_gradients_nonmoments(params: Dict) -> Dict: """ Stops gradients for non-moment parameters. Args: params: A dictionary of parameters. Returns: - tp.Dict: A dictionary of parameters with stopped gradients. + Dict: A dictionary of parameters with stopped gradients. """ trainables = build_trainables_false(params) moment_trainables = build_trainables_true(params["variational_family"]["moments"]) @@ -115,13 +115,13 @@ def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: return params -def _stop_gradients_moments(params: tp.Dict) -> tp.Dict: +def _stop_gradients_moments(params: Dict) -> Dict: """ Stops gradients for moment parameters. Args: params: A dictionary of parameters. Returns: - tp.Dict: A dictionary of parameters with stopped gradients. + Dict: A dictionary of parameters with stopped gradients. """ trainables = build_trainables_true(params) moment_trainables = build_trainables_false(params["variational_family"]["moments"]) @@ -133,8 +133,8 @@ def _stop_gradients_moments(params: tp.Dict) -> tp.Dict: def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, - transformations: dict, -) -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]: + transformations: Dict, +) -> Tuple[Callable[[Dict, Dataset], Dict]]: """ Computes the gradient with respect to the natural parameters. Currently only implemented for the natural variational Gaussian family. Args: @@ -143,7 +143,7 @@ def natural_gradients( train_data: A Dataset. transformations: A dictionary of transformations. Returns: - Tuple[tp.Callable[[dict, Dataset], dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. + Tuple[Callable[[Dict, Dataset], Dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. """ posterior = stochastic_vi.posterior variational_family = stochastic_vi.variational_family @@ -156,7 +156,7 @@ def natural_gradients( if isinstance(variational_family, NaturalVariationalGaussian): - def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + def nat_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: """ Computes the natural gradients of the ELBO. Args: @@ -164,7 +164,7 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: trainables: A dictionary of trainables. batch: A Dataset. Returns: - dict: A dictionary of natural gradients. + Dict: A dictionary of natural gradients. """ # Transform parameters to constrained space. params = transform(params, transformations) @@ -180,7 +180,7 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: expectation_params["variational_family"]["moments"] = expectation_moments # Compute gradient ∂L/∂η: - def loss_fn(params: dict, batch: Dataset) -> f64["1"]: + def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. trains = deepcopy(trainables) trains["variational_family"]["moments"] = build_trainables_true( @@ -211,7 +211,7 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: else: raise NotImplementedError - def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + def hyper_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: """ Computes the hyperparameter gradients of the ELBO. Args: @@ -219,10 +219,10 @@ def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: trainables: A dictionary of trainables. batch: A Dataset. Returns: - dict: A dictionary of hyperparameter gradients. + Dict: A dictionary of hyperparameter gradients. """ - def loss_fn(params: dict, batch: Dataset) -> f64["1"]: + def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. params = trainable_params(params, trainables) @@ -240,17 +240,17 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: def fit_natgrads( stochastic_vi: StochasticVI, - params: tp.Dict, - trainables: tp.Dict, - transformations: tp.Dict, + params: Dict, + trainables: Dict, + transformations: Dict, train_data: Dataset, batch_size: int, moment_optim, hyper_optim, key: PRNGKeyType, - n_iters: tp.Optional[int] = 100, - log_rate: tp.Optional[int] = 10, -) -> tp.Dict: + n_iters: Optional[int] = 100, + log_rate: Optional[int] = 10, +) -> Dict: hyper_state = hyper_optim.init(params) moment_state = moment_optim.init(params) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 6f1bd28c..95359592 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -1,6 +1,6 @@ -import typing as tp import warnings from copy import deepcopy +from typing import Dict, Tuple from warnings import warn import distrax as dx @@ -24,9 +24,9 @@ class ParameterState: """The state of the model. This includes the parameter set, which parameters are to be trained and bijectors that allow parameters to be constrained and unconstrained.""" - params: tp.Dict - trainables: tp.Dict - bijectors: tp.Dict + params: Dict + trainables: Dict + bijectors: Dict def unpack(self): return self.params, self.trainables, self.bijectors @@ -58,7 +58,7 @@ def _validate_kwargs(kwargs, params): raise ValueError(f"Parameter {k} is not a valid parameter.") -def recursive_items(d1: tp.Dict, d2: tp.Dict): +def recursive_items(d1: Dict, d2: Dict): """Recursive loop over pair of dictionaries whereby the value of a given key in either dictionary can be itself a dictionary. Args: @@ -75,15 +75,15 @@ def recursive_items(d1: tp.Dict, d2: tp.Dict): yield (key, value, d2[key]) -def recursive_complete(d1: tp.Dict, d2: tp.Dict) -> tp.Dict: +def recursive_complete(d1: Dict, d2: Dict) -> Dict: """Recursive loop over pair of dictionaries whereby the value of a given key in either dictionary can be itself a dictionary. If the value of the key in the second dictionary is None, the value of the key in the first dictionary is used. Args: - d1 (tp.Dict): The reference dictionary. - d2 (tp.Dict): The potentially incomplete dictionary. + d1 (Dict): The reference dictionary. + d2 (Dict): The potentially incomplete dictionary. Returns: - tp.Dict: A completed form of the second dictionary. + Dict: A completed form of the second dictionary. """ for key, value in d1.items(): if type(value) is dict: @@ -98,14 +98,14 @@ def recursive_complete(d1: tp.Dict, d2: tp.Dict) -> tp.Dict: ################################ # Parameter transformation ################################ -def build_bijectors(params: tp.Dict) -> tp.Dict: +def build_bijectors(params: Dict) -> Dict: """For each parameter, build the bijection pair that allows the parameter to be constrained and unconstrained. Args: - params (tp.Dict): _description_ + params (Dict): _description_ Returns: - tp.Dict: A dictionary that maps each parameter to a bijection. + Dict: A dictionary that maps each parameter to a bijection. """ bijectors = copy_dict_structure(params) config = get_defaults() @@ -114,7 +114,7 @@ def build_bijectors(params: tp.Dict) -> tp.Dict: def recursive_bijectors_list(ps, bs): return [recursive_bijectors(ps[i], bs[i]) for i in range(len(bs))] - def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: + def recursive_bijectors(ps, bs) -> Tuple[Dict, Dict]: if type(ps) is list: bs = recursive_bijectors_list(ps, bs) @@ -139,7 +139,7 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: +def constrain(params: Dict, bijectors: Dict) -> Dict: """Transform the parameters to the constrained space for corresponding bijectors. Args: @@ -156,7 +156,7 @@ def constrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: return jax.tree_util.tree_map(map, params, bijectors) -def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: +def unconstrain(params: Dict, bijectors: Dict) -> Dict: """Transform the parameters to the unconstrained space for corresponding bijectors. Args: @@ -173,13 +173,13 @@ def unconstrain(params: tp.Dict, bijectors: tp.Dict) -> tp.Dict: return jax.tree_util.tree_map(map, params, bijectors) -def build_identity(params: tp.Dict) -> tp.Dict: +def build_identity(params: Dict) -> Dict: """ " Args: - params (tp.Dict): The parameter set for which trainable statuses should be derived from. + params (Dict): The parameter set for which trainable statuses should be derived from. Returns: - tp.Dict: A dictionary of identity forward/backward bijectors. The dictionary is equal in structure to the input params dictionary. + Dict: A dictionary of identity forward/backward bijectors. The dictionary is equal in structure to the input params dictionary. """ # Copy dictionary structure prior_container = deepcopy(params) @@ -200,7 +200,7 @@ def log_density( return log_prob -def copy_dict_structure(params: dict) -> dict: +def copy_dict_structure(params: Dict) -> Dict: # Copy dictionary structure prior_container = deepcopy(params) # Set all values to zero @@ -208,15 +208,15 @@ def copy_dict_structure(params: dict) -> dict: return prior_container -def structure_priors(params: dict, priors: dict) -> dict: +def structure_priors(params: Dict, priors: Dict) -> Dict: """First create a dictionary with equal structure to the parameters. Then, for each supplied prior, overwrite the None value if it exists. Args: - params (dict): [description] - priors (dict): [description] + params (Dict): [description] + priors (Dict): [description] Returns: - dict: [description] + Dict: [description] """ prior_container = copy_dict_structure(params) # Where a prior has been supplied, override the None value by the prior distribution. @@ -224,14 +224,14 @@ def structure_priors(params: dict, priors: dict) -> dict: return complete_prior -def evaluate_priors(params: dict, priors: dict) -> dict: +def evaluate_priors(params: Dict, priors: Dict) -> Dict: """Recursive loop over pair of dictionaries that correspond to a parameter's current value and the parameter's respective prior distribution. For parameters where a prior distribution is specified, the log-prior density is evaluated at the parameter's current value. - Args: params (dict): Dictionary containing the current set of parameter - estimates. priors (dict): Dictionary specifying the parameters' prior + Args: params (Dict): Dictionary containing the current set of parameter + estimates. priors (Dict): Dictionary specifying the parameters' prior distributions. Returns: Array: The log-prior density, summed over all parameters. @@ -243,7 +243,7 @@ def evaluate_priors(params: dict, priors: dict) -> dict: return lpd -def prior_checks(priors: dict) -> dict: +def prior_checks(priors: Dict) -> Dict: """Run checks on th parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian.""" if "latent" in priors.keys(): latent_prior = priors["latent"] @@ -263,14 +263,14 @@ def prior_checks(priors: dict) -> dict: return priors -def build_trainables_true(params: tp.Dict) -> tp.Dict: +def build_trainables_true(params: Dict) -> Dict: """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is trainable. Args: - params (tp.Dict): The parameter set for which trainable statuses should be derived from. + params (Dict): The parameter set for which trainable statuses should be derived from. Returns: - tp.Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. + Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. """ # Copy dictionary structure prior_container = deepcopy(params) @@ -279,14 +279,14 @@ def build_trainables_true(params: tp.Dict) -> tp.Dict: return prior_container -def build_trainables_false(params: tp.Dict) -> tp.Dict: +def build_trainables_false(params: Dict) -> Dict: """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is NOT trainable. Args: - params (tp.Dict): The parameter set for which trainable statuses should be derived from. + params (Dict): The parameter set for which trainable statuses should be derived from. Returns: - tp.Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. + Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. """ # Copy dictionary structure prior_container = deepcopy(params) @@ -295,12 +295,12 @@ def build_trainables_false(params: tp.Dict) -> tp.Dict: return prior_container -def stop_grad(param: tp.Dict, trainable: tp.Dict): +def stop_grad(param: Dict, trainable: Dict): """When taking a gradient, we want to stop the gradient from flowing through a parameter if it is not trainable. This is achieved using the model's dictionary of parameters and the corresponding trainability status.""" return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) -def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: +def trainable_params(params: Dict, trainables: Dict) -> Dict: """Stop the gradients flowing through parameters whose trainable status is False""" return jax.tree_util.tree_map( lambda param, trainable: stop_grad(param, trainable), params, trainables diff --git a/gpjax/types.py b/gpjax/types.py index 946efefc..9cae03ca 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -1,9 +1,9 @@ import jax.numpy as jnp -from chex import PRNGKey as PRNGKeyType from chex import dataclass from jaxtyping import Array, Float NoneType = type(None) +from chex import PRNGKey as PRNGKeyType @dataclass diff --git a/gpjax/utils.py b/gpjax/utils.py index 2d8861e6..0b1c0055 100644 --- a/gpjax/utils.py +++ b/gpjax/utils.py @@ -1,5 +1,5 @@ -import typing as tp from copy import deepcopy +from typing import Callable, Dict, Tuple import jax import jax.numpy as jnp @@ -17,7 +17,7 @@ def I(n: int) -> Float[Array, "N N"]: return jnp.eye(n) -def concat_dictionaries(a: dict, b: dict) -> dict: +def concat_dictionaries(a: Dict, b: Dict) -> Dict: """ Append one dictionary below another. If duplicate keys exist, then the key-value pair of the second supplied dictionary will be used. @@ -25,7 +25,7 @@ def concat_dictionaries(a: dict, b: dict) -> dict: return {**a, **b} -def merge_dictionaries(base_dict: dict, in_dict: dict) -> dict: +def merge_dictionaries(base_dict: Dict, in_dict: Dict) -> Dict: """ This will return a complete dictionary based on the keys of the first matrix. If the same key should exist in the second matrix, then the key-value pair from the first dictionary will be overwritten. The purpose of this is that @@ -42,7 +42,7 @@ def merge_dictionaries(base_dict: dict, in_dict: dict) -> dict: return base_dict -def sort_dictionary(base_dict: dict) -> dict: +def sort_dictionary(base_dict: Dict) -> Dict: """ Sort a dictionary based on the dictionary's key values. @@ -52,7 +52,7 @@ def sort_dictionary(base_dict: dict) -> dict: return dict(sorted(base_dict.items())) -def as_constant(parameter_set: dict, params: list) -> tp.Tuple[dict, dict]: +def as_constant(parameter_set: Dict, params: list) -> Tuple[Dict, Dict]: base_params = deepcopy(parameter_set) sparams = {} for param in params: @@ -61,21 +61,21 @@ def as_constant(parameter_set: dict, params: list) -> tp.Tuple[dict, dict]: return base_params, sparams -def dict_array_coercion(params: tp.Dict) -> tp.Tuple[tp.Callable, tp.Callable]: +def dict_array_coercion(params: Dict) -> Tuple[Callable, Callable]: """Construct the logic required to map a dictionary of parameters to an array of parameters. The values of the dictionary can themselves be dictionaries; the function should work recursively. Args: - params (tp.Dict): The dictionary of parameters that we would like to map into an array. + params (Dict): The dictionary of parameters that we would like to map into an array. Returns: - tp.Tuple[tp.Callable, tp.Callable]: A pair of functions, the first of which maps a dictionary to an array, and the second of which maps an array to a dictionary. The remapped dictionary is equal in structure to the original dictionary. + Tuple[Callable, Callable]: A pair of functions, the first of which maps a dictionary to an array, and the second of which maps an array to a dictionary. The remapped dictionary is equal in structure to the original dictionary. """ flattened_pytree = jax.tree_util.tree_flatten(params) def dict_to_array(parameter_dict) -> jnp.DeviceArray: return jax.tree_util.tree_flatten(parameter_dict)[0] - def array_to_dict(parameter_array) -> tp.Dict: + def array_to_dict(parameter_array) -> Dict: return jax.tree_util.tree_unflatten(flattened_pytree[1], parameter_array) return dict_to_array, array_to_dict diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 1120241f..01ee3e2e 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -110,7 +110,7 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ]. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -198,7 +198,7 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi N[f(t); μt + Ktz Lz⁻ᵀ μ, Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻ᵀ S Lz⁻¹ Kzt]. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -302,11 +302,11 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -428,11 +428,11 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -529,7 +529,7 @@ def predict( ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ From 1cf85144d5c6624e213c3498d7f8ea734d0590ea Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 23 Aug 2022 18:38:01 +0100 Subject: [PATCH 56/66] Address review comments. --- gpjax/natural_gradients.py | 15 ++++++++++++++- gpjax/parameters.py | 4 ++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index f67dc52b..ccc9318d 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -35,7 +35,20 @@ def natural_to_expectation( natural_moments: Dict, jitter: float = DEFAULT_JITTER ) -> Dict: """ - Converts natural parameters to expectation parameters. + Translate natural parameters to expectation parameters. + + In particular, in terms of the Gaussian mean μ and covariance matrix μ for the Gaussian variational family, + + - the natural parameteristaion is θ = (S⁻¹μ, -S⁻¹/2) + - the expectation parameters are η = (μ, S + μ μᵀ). + + This function solves these eqautions in terms of μ and S to convert θ to η. + + Writing θ = (θ₁, θ₂), we have that S⁻¹ = -2θ₂ . Taking the cholesky decomposition of the inverse covariance, + S⁻¹ = LLᵀ and defining C = L⁻¹, we have S = CᵀC and μ = Sθ₁ = CᵀC θ₁. + + Now from here, using μ and S found from θ, we compute η as η₁ = μ, and η₂ = S + μ μᵀ. + Args: natural_moments: A dictionary of natural parameters. jitter (float): A small value to prevent numerical instability. diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 95359592..b451d159 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -184,7 +184,7 @@ def build_identity(params: Dict) -> Dict: # Copy dictionary structure prior_container = deepcopy(params) - return jax.tree_map(lambda _: Identity.forward, prior_container) + return jax.tree_util.tree_map(lambda _: Identity.forward, prior_container) ################################ @@ -291,7 +291,7 @@ def build_trainables_false(params: Dict) -> Dict: # Copy dictionary structure prior_container = deepcopy(params) # Set all values to zero - prior_container = jax.tree_map(lambda _: False, prior_container) + prior_container = jax.tree_util.tree_map(lambda _: False, prior_container) return prior_container From 4833648c89b692d35f1cab17030d065f0f7401d0 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 24 Aug 2022 14:50:19 +0100 Subject: [PATCH 57/66] Address review (except notebook).. --- examples/natgrads.ipynb | 4 +- gpjax/__init__.py | 6 +-- gpjax/abstractions.py | 75 ++++++++++++++++++++++++++++++++++++++ gpjax/natural_gradients.py | 58 +---------------------------- tests/test_abstractions.py | 47 ++++++++++++++++++++++++ 5 files changed, 128 insertions(+), 62 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 19e6158c..e6315c51 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -175,9 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "from gpjax.natural_gradients import fit_natgrads\n", - "\n", - "learned_params, training_history = fit_natgrads(svgp,\n", + "learned_params, training_history = gpx.fit_natgrads(svgp,\n", " params = params,\n", " trainables = trainables, \n", " transformations = constrainers,\n", diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 807d7701..5927d38a 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -5,7 +5,7 @@ # Highlight any potentially unintended broadcasting rank promoting ops. # config.update("jax_numpy_rank_promotion", "warn") -from .abstractions import fit, fit_batches +from .abstractions import fit, fit_batches, fit_natgrads from .gps import Prior, construct_posterior from .kernels import ( RBF, @@ -23,10 +23,10 @@ from .types import Dataset from .variational_families import ( CollapsedVariationalGaussian, + ExpectationVariationalGaussian, + NaturalVariationalGaussian, VariationalGaussian, WhitenedVariationalGaussian, - NaturalVariationalGaussian, - ExpectationVariationalGaussian, ) from .variational_inference import CollapsedVI, StochasticVI diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 0871ef73..34717b10 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -13,6 +13,7 @@ from .parameters import ParameterState, constrain, trainable_params, unconstrain from .parameters import trainable_params, transform from .types import Dataset, PRNGKeyType +from .variational_inference import StochasticVI @dataclass(frozen=True) @@ -208,6 +209,80 @@ def step(carry, iter_num__and__key): return inf_state +def fit_natgrads( + stochastic_vi: StochasticVI, + params: Dict, + trainables: Dict, + transformations: Dict, + train_data: Dataset, + moment_optim, + hyper_optim, + key: PRNGKeyType, + batch_size: int, + n_iters: Optional[int] = 100, + log_rate: Optional[int] = 10, +) -> Dict: + """This is a training loop for natural gradients. See Salimbeni et al. (2018) Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models + + We begin with an initalise natural gradient step to tighten the ELBO for hyperparameter optimisation. There after, each iteration comprises a hyperparameter gradient step followed by natural gradient step to avoid a stale posterior. + + Args: + stochastic_vi (StochasticVI): The stochastic variational inference algorithm to be used for training. + params (Dict): The parameters for which we would like to minimise our objective function with. + trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. + transformations (Dict): The transformations to be applied to the parameters. + train_data (Dataset): The training dataset. + batch_size(int): The batch_size. + key (PRNGKeyType): The PRNG key for the mini-batch sampling. + n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. + log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. + Returns: + InferenceState: A dataclass comprising optimised parameters and training history. + """ + + hyper_state = hyper_optim.init(params) + moment_state = moment_optim.init(params) + + nat_grads_fn, hyper_grads_fn = natural_gradients( + stochastic_vi, train_data, transformations + ) + + # Initial natural gradient step to improve bound for hyperparameters: + batch = get_batch(train_data, batch_size, key) + loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) + updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) + params = optax.apply_updates(params, updates) + + keys = jax.random.split(key, n_iters) + iter_nums = jnp.arange(n_iters) + + @progress_bar_scan(n_iters, log_rate) + def step(carry, iter_num__and__key): + iter_num, key = iter_num__and__key + params, hyper_state, moment_state = carry + + batch = get_batch(train_data, batch_size, key) + + # Hyper-parameters update: + loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) + updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params) + params = optax.apply_updates(params, updates) + + # Natural gradients update: + loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) + updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) + params = optax.apply_updates(params, updates) + + carry = params, hyper_state, moment_state + return carry, loss_val + + (params, _, _), history = jax.lax.scan( + step, (params, hyper_state, moment_state), (iter_nums, keys) + ) + inf_state = InferenceState(params=params, history=history) + return inf_state + + def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset: """Batch the data into mini-batches. Args: diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index ccc9318d..1b75b01b 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,15 +1,12 @@ from copy import deepcopy -from typing import Callable, Dict, Optional, Tuple +from typing import Callable, Dict, Tuple -import jax import jax.numpy as jnp import jax.random as jr import jax.scipy as jsp -import optax as ox from jax import value_and_grad from jaxtyping import f64 -from .abstractions import InferenceState, get_batch, progress_bar_scan from .config import get_defaults from .gps import AbstractPosterior from .parameters import ( @@ -19,7 +16,7 @@ trainable_params, transform, ) -from .types import Dataset, PRNGKeyType +from .types import Dataset from .utils import I from .variational_families import ( AbstractVariationalFamily, @@ -249,54 +246,3 @@ def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: return value, dL_dhyper return nat_grads_fn, hyper_grads_fn - - -def fit_natgrads( - stochastic_vi: StochasticVI, - params: Dict, - trainables: Dict, - transformations: Dict, - train_data: Dataset, - batch_size: int, - moment_optim, - hyper_optim, - key: PRNGKeyType, - n_iters: Optional[int] = 100, - log_rate: Optional[int] = 10, -) -> Dict: - - hyper_state = hyper_optim.init(params) - moment_state = moment_optim.init(params) - - nat_grads_fn, hyper_grads_fn = natural_gradients( - stochastic_vi, train_data, transformations - ) - - keys = jax.random.split(key, n_iters) - iter_nums = jnp.arange(n_iters) - - @progress_bar_scan(n_iters, log_rate) - def step(carry, iter_num__and__key): - iter_num, key = iter_num__and__key - params, hyper_state, moment_state = carry - - batch = get_batch(train_data, batch_size, key) - - # Hyper-parameters update: - loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) - updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params) - params = ox.apply_updates(params, updates) - - # Natural gradients update: - loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) - updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) - params = ox.apply_updates(params, updates) - - carry = params, hyper_state, moment_state - return carry, loss_val - - (params, _, _), history = jax.lax.scan( - step, (params, hyper_state, moment_state), (iter_nums, keys) - ) - inf_state = InferenceState(params=params, history=history) - return inf_state diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index ad1a6394..cb660cba 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -82,6 +82,53 @@ def test_batch_fitting(n_iters, nb, ndata): assert history.shape[0] == n_iters +@pytest.mark.parametrize("n_iters", [5]) +@pytest.mark.parametrize("nb", [1, 20, 50]) +@pytest.mark.parametrize("ndata", [50]) +def test_natural_gradients(ndata, nb, n_iters): + key = jr.PRNGKey(123) + x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) + y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 + D = Dataset(X=x, y=y) + prior = Prior(kernel=RBF()) + likelihood = Gaussian(num_datapoints=ndata) + p = prior * likelihood + z = jnp.linspace(-2.0, 2.0, 10).reshape(-1, 1) + + q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) + + svgp = gpx.StochasticVI(posterior=p, variational_family=q) + params, trainable_status, constrainer, unconstrainer = initialise( + svgp, key + ).unpack() + params = gpx.transform(params, unconstrainer) + + D = Dataset(X=x, y=y) + + hyper_optimiser = optax.adam(learning_rate=0.1) + moment_optimiser = optax.sgd(learning_rate=1.0) + + key = jr.PRNGKey(42) + inference_state = fit_natgrads( + svgp, + params, + trainable_status, + constrainer, + D, + moment_optimiser, + hyper_optimiser, + key, + nb, + n_iters, + ) + optimised_params, history = inference_state.params, inference_state.history + optimised_params = transform(optimised_params, constrainer) + assert isinstance(inference_state, InferenceState) + assert isinstance(optimised_params, dict) + assert isinstance(history, jnp.ndarray) + assert history.shape[0] == n_iters + + @pytest.mark.parametrize("batch_size", [1, 2, 50]) @pytest.mark.parametrize("ndim", [1, 2, 3]) @pytest.mark.parametrize("ndata", [50]) From f45d99b8bfa1e89ea9f505465c29698b0f6ca34f Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Wed, 24 Aug 2022 17:06:12 +0100 Subject: [PATCH 58/66] Address comments. --- examples/natgrads.ipynb | 29 ++++++++++++++++++++++++++- gpjax/natural_gradients.py | 24 ++++++++-------------- gpjax/parameters.py | 41 ++++++++++---------------------------- tests/test_parameters.py | 2 ++ 4 files changed, 48 insertions(+), 48 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index e6315c51..c16d9113 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -20,6 +20,19 @@ "\"The ordinary gradient turns out to be an unnatural direction to follow for variational inference since we are optimizing a distribution, rather than a set of pa- rameters directly. One way to define the gradient is the direction that achieves maximum change subject to a perturbation within a small euclidean ball. To see why the euclidean distance is an unnatural metric for probability distributions, consider the two Gaussians $\\mathcal{N}(0, 0.1)$ and $\\mathcal{N}(0, 0.2)$, compared to $\\mathcal{N}(0, 1000.1)$ and $\\mathcal{N}N(0,1000.2)$.\"" ] }, + { + "cell_type": "markdown", + "id": "a889abb6", + "metadata": {}, + "source": [ + "# Mathematical background\n", + "\n", + "Gradient descent algorithms seek to minimise a function $g(\\boldsymbol{\\theta})$ through a sequence of updates $\\{\\boldsymbol{\\theta}_t\\}_{t}$ via an iterative proceedure, \n", + "\\begin{align}\n", + " \\boldsymbol{\\theta}_{t+1} \\gets \\boldsymbol{\\theta}_{t} + \n", + "\\end{align}" + ] + }, { "cell_type": "code", "execution_count": null, @@ -52,7 +65,11 @@ "id": "6f7facf2", "metadata": {}, "source": [ - "Generate dataset:" + "We simulate a dataset $\\mathcal{D} = (\\boldsymbol{x}, \\boldsymbol{y}) = \\{(x_i, y_i)\\}_{i=1}^{5000}$ with inputs $\\boldsymbol{x}$ sampled uniformly on $(-5, 5)$ and corresponding binary outputs\n", + "\n", + "$$\\boldsymbol{y} \\sim \\mathcal{N} \\left(\\sin(4 * \\boldsymbol{x}) + \\sin(2 * \\boldsymbol{x}), \\textbf{I} * (0.2)^{2} \\right).$$\n", + "\n", + "We store our data $\\mathcal{D}$ as a GPJax `Dataset` and create test inputs for later." ] }, { @@ -190,6 +207,16 @@ "learned_params = gpx.transform(learned_params, constrainers)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "06b2fb33", + "metadata": {}, + "outputs": [], + "source": [ + "elbo = svgp.elbo(D, learned_params, negative=True)" + ] + }, { "cell_type": "markdown", "id": "fbcdd41c", diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 1b75b01b..3217863c 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -2,20 +2,13 @@ from typing import Callable, Dict, Tuple import jax.numpy as jnp -import jax.random as jr import jax.scipy as jsp from jax import value_and_grad from jaxtyping import f64 from .config import get_defaults from .gps import AbstractPosterior -from .parameters import ( - build_identity, - build_trainables_false, - build_trainables_true, - trainable_params, - transform, -) +from .parameters import build_trainables, trainable_params, transform from .types import Dataset from .utils import I from .variational_families import ( @@ -105,9 +98,8 @@ def _expectation_elbo( svgp = StochasticVI( posterior=posterior, variational_family=expectation_vartiational_gaussian ) - identity_transformation = build_identity(svgp._initialise_params(jr.PRNGKey(123))) - return svgp.elbo(train_data, identity_transformation, negative=True) + return svgp.elbo(train_data, transformations=None, negative=True) def _stop_gradients_nonmoments(params: Dict) -> Dict: @@ -118,8 +110,8 @@ def _stop_gradients_nonmoments(params: Dict) -> Dict: Returns: Dict: A dictionary of parameters with stopped gradients. """ - trainables = build_trainables_false(params) - moment_trainables = build_trainables_true(params["variational_family"]["moments"]) + trainables = build_trainables(params, False) + moment_trainables = build_trainables(params["variational_family"]["moments"], True) trainables["variational_family"]["moments"] = moment_trainables params = trainable_params(params, trainables) return params @@ -133,8 +125,8 @@ def _stop_gradients_moments(params: Dict) -> Dict: Returns: Dict: A dictionary of parameters with stopped gradients. """ - trainables = build_trainables_true(params) - moment_trainables = build_trainables_false(params["variational_family"]["moments"]) + trainables = build_trainables(params, True) + moment_trainables = build_trainables(params["variational_family"]["moments"], False) trainables["variational_family"]["moments"] = moment_trainables params = trainable_params(params, trainables) return params @@ -193,8 +185,8 @@ def nat_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. trains = deepcopy(trainables) - trains["variational_family"]["moments"] = build_trainables_true( - params["variational_family"]["moments"] + trains["variational_family"]["moments"] = build_trainables( + params["variational_family"]["moments"], True ) params = trainable_params(params, trains) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index b451d159..4376cbef 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -165,26 +165,20 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ map = lambda param, trans: trans.inverse(param) return jax.tree_util.tree_map(map, params, bijectors) + if transform_map is None: + return params -def build_identity(params: Dict) -> Dict: - """ " - Args: - params (Dict): The parameter set for which trainable statuses should be derived from. - - Returns: - Dict: A dictionary of identity forward/backward bijectors. The dictionary is equal in structure to the input params dictionary. - """ - # Copy dictionary structure - prior_container = deepcopy(params) - - return jax.tree_util.tree_map(lambda _: Identity.forward, prior_container) + else: + return jax.tree_util.tree_map( + lambda param, trans: trans(param), params, transform_map + ) ################################ @@ -263,11 +257,12 @@ def prior_checks(priors: Dict) -> Dict: return priors -def build_trainables_true(params: Dict) -> Dict: +def build_trainables(params: Dict, status: bool = True) -> Dict: """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is trainable. Args: params (Dict): The parameter set for which trainable statuses should be derived from. + status (bool): The status of each parameter. Default is True. Returns: Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. @@ -275,23 +270,7 @@ def build_trainables_true(params: Dict) -> Dict: # Copy dictionary structure prior_container = deepcopy(params) # Set all values to zero - prior_container = jax.tree_util.tree_map(lambda _: True, prior_container) - return prior_container - - -def build_trainables_false(params: Dict) -> Dict: - """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is NOT trainable. - - Args: - params (Dict): The parameter set for which trainable statuses should be derived from. - - Returns: - Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. - """ - # Copy dictionary structure - prior_container = deepcopy(params) - # Set all values to zero - prior_container = jax.tree_util.tree_map(lambda _: False, prior_container) + prior_container = jax.tree_util.tree_map(lambda _: status, prior_container) return prior_container diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 18fd2e3e..76286213 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -12,6 +12,8 @@ from gpjax.parameters import ( build_bijectors, constrain, + build_trainables, + build_transforms, copy_dict_structure, evaluate_priors, initialise, From 8f4923d6b99dabe4f911515e5d10a571bd155ddb Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 25 Aug 2022 16:33:16 +0100 Subject: [PATCH 59/66] Address review. --- gpjax/abstractions.py | 8 +- gpjax/natural_gradients.py | 73 +++++++------- tests/test_natural_gradients.py | 174 ++++++++++++++++++++++++++------ 3 files changed, 183 insertions(+), 72 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 34717b10..e012db7b 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -224,7 +224,7 @@ def fit_natgrads( ) -> Dict: """This is a training loop for natural gradients. See Salimbeni et al. (2018) Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models - We begin with an initalise natural gradient step to tighten the ELBO for hyperparameter optimisation. There after, each iteration comprises a hyperparameter gradient step followed by natural gradient step to avoid a stale posterior. + Each iteration comprises a hyperparameter gradient step followed by natural gradient step to avoid a stale posterior. Args: stochastic_vi (StochasticVI): The stochastic variational inference algorithm to be used for training. @@ -247,12 +247,6 @@ def fit_natgrads( stochastic_vi, train_data, transformations ) - # Initial natural gradient step to improve bound for hyperparameters: - batch = get_batch(train_data, batch_size, key) - loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) - updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) - params = optax.apply_updates(params, updates) - keys = jax.random.split(key, n_iters) iter_nums = jnp.arange(n_iters) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 3217863c..a83df62b 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -21,9 +21,7 @@ DEFAULT_JITTER = get_defaults()["jitter"] -def natural_to_expectation( - natural_moments: Dict, jitter: float = DEFAULT_JITTER -) -> Dict: +def natural_to_expectation(params: Dict, jitter: float = DEFAULT_JITTER) -> Dict: """ Translate natural parameters to expectation parameters. @@ -40,14 +38,14 @@ def natural_to_expectation( Now from here, using μ and S found from θ, we compute η as η₁ = μ, and η₂ = S + μ μᵀ. Args: - natural_moments: A dictionary of natural parameters. + params: A dictionary of variational Gaussian parameters under the natural parameterisation. jitter (float): A small value to prevent numerical instability. Returns: Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ - natural_matrix = natural_moments["natural_matrix"] - natural_vector = natural_moments["natural_vector"] + natural_matrix = params["variational_family"]["moments"]["natural_matrix"] + natural_vector = params["variational_family"]["moments"]["natural_vector"] m = natural_vector.shape[0] # S⁻¹ = -2θ₂ @@ -72,11 +70,13 @@ def natural_to_expectation( # η₂ = S + μ μᵀ expectation_matrix = S + jnp.matmul(mu, mu.T) - return { + params["variational_family"]["moments"] = { "expectation_vector": expectation_vector, "expectation_matrix": expectation_matrix, } + return params + def _expectation_elbo( posterior: AbstractPosterior, @@ -132,6 +132,28 @@ def _stop_gradients_moments(params: Dict) -> Dict: return params +# TODO: Write unit test: +def _rename_expectation_to_natural(params: Dict) -> Dict: + """This function renames the gradient components (that have expectation parameterisation keys) to match the natural parameterisation pytree.""" + params["variational_family"]["moments"] = { + "natural_vector": params["variational_family"]["moments"]["expectation_vector"], + "natural_matrix": params["variational_family"]["moments"]["expectation_matrix"], + } + + return params + + +# TODO: Write unit test: +def _rename_natural_to_expectation(params: Dict) -> Dict: + """This function renames the gradient components (that have natural parameterisation keys) to match the expectation parameterisation pytree.""" + params["variational_family"]["moments"] = { + "expectation_vector": params["variational_family"]["moments"]["natural_vector"], + "expectation_matrix": params["variational_family"]["moments"]["natural_matrix"], + } + + return params + + def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, @@ -171,42 +193,23 @@ def nat_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: # Transform parameters to constrained space. params = transform(params, transformations) - # Get natural moments θ. - natural_moments = params["variational_family"]["moments"] - - # Get expectation moments η. - expectation_moments = natural_to_expectation(natural_moments) - - # Full params with expectation moments. - expectation_params = deepcopy(params) - expectation_params["variational_family"]["moments"] = expectation_moments + # Convert natural parameterisation θ to the expectation parametersation η. + expectation_params = natural_to_expectation(params) # Compute gradient ∂L/∂η: def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: - # Determine hyperparameters that should be trained. - trains = deepcopy(trainables) - trains["variational_family"]["moments"] = build_trainables( - params["variational_family"]["moments"], True + # Stop gradients for non-trainable and non-moment parameters. + expectation_trainables = _rename_natural_to_expectation( + deepcopy(trainables) ) - params = trainable_params(params, trains) - - # Stop gradients for non-moment parameters. + params = trainable_params(params, expectation_trainables) params = _stop_gradients_nonmoments(params) return expectation_elbo(params, batch) value, dL_dexp = value_and_grad(loss_fn)(expectation_params, batch) - # This is a renaming of the gradient components to match the natural parameterisation pytree. - nat_grad = dL_dexp - nat_grad["variational_family"]["moments"] = { - "natural_vector": dL_dexp["variational_family"]["moments"][ - "expectation_vector" - ], - "natural_matrix": dL_dexp["variational_family"]["moments"][ - "expectation_matrix" - ], - } + nat_grad = _rename_expectation_to_natural(dL_dexp) return value, nat_grad @@ -225,10 +228,8 @@ def hyper_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: """ def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: - # Determine hyperparameters that should be trained. + # Stop gradients for non-trainable and moment parameters. params = trainable_params(params, trainables) - - # Stop gradients for the moment parameters. params = _stop_gradients_moments(params) return xi_elbo(params, batch) diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index 4a36119c..e84ccb32 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -9,6 +9,8 @@ from gpjax.abstractions import get_batch from gpjax.natural_gradients import ( _expectation_elbo, + _rename_expectation_to_natural, + _rename_natural_to_expectation, natural_gradients, natural_to_expectation, ) @@ -17,6 +19,17 @@ key = jr.PRNGKey(123) +def get_data_and_gp(n_datapoints): + x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) + y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 + D = gpx.Dataset(X=x, y=y) + + p = gpx.Prior(kernel=gpx.RBF()) + lik = gpx.Gaussian(num_datapoints=n_datapoints) + post = p * lik + return D, post, p + + @pytest.mark.parametrize("dim", [1, 2, 3]) def test_natural_to_expectation(dim): """ @@ -28,37 +41,148 @@ def test_natural_to_expectation(dim): tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ - natural_matrix = -0.5 * jnp.eye(dim) - natural_vector = jnp.zeros((dim, 1)) + _, posterior, prior = get_data_and_gp(10) + + z = jnp.linspace(-5.0, 5.0, 5 * dim).reshape(-1, dim) + expectation_variational_family = ( + gpx.variational_families.ExpectationVariationalGaussian( + prior=prior, inducing_inputs=z + ) + ) + + natural_variational_family = gpx.variational_families.NaturalVariationalGaussian( + prior=prior, inducing_inputs=z + ) + + natural_svgp = gpx.StochasticVI( + posterior=posterior, variational_family=natural_variational_family + ) + expectation_svgp = gpx.StochasticVI( + posterior=posterior, variational_family=expectation_variational_family + ) - natural_moments = { - "natural_matrix": natural_matrix, - "natural_vector": natural_vector, - } + key = jr.PRNGKey(123) + natural_params, *_ = gpx.initialise(natural_svgp, key).unpack() + expectation_params, *_ = gpx.initialise(expectation_svgp, key).unpack() - expectation_moments = natural_to_expectation(natural_moments, jitter=1e-6) + expectation_params_test = natural_to_expectation(natural_params, jitter=1e-6) - assert "expectation_vector" in expectation_moments.keys() - assert "expectation_matrix" in expectation_moments.keys() assert ( - expectation_moments["expectation_vector"].shape - == natural_moments["natural_vector"].shape + "expectation_vector" + in expectation_params_test["variational_family"]["moments"].keys() ) assert ( - expectation_moments["expectation_matrix"].shape - == natural_moments["natural_matrix"].shape + "expectation_matrix" + in expectation_params_test["variational_family"]["moments"].keys() + ) + assert ( + expectation_params_test["variational_family"]["moments"][ + "expectation_vector" + ].shape + == expectation_params["variational_family"]["moments"][ + "expectation_vector" + ].shape + ) + assert ( + expectation_params_test["variational_family"]["moments"][ + "expectation_matrix" + ].shape + == expectation_params["variational_family"]["moments"][ + "expectation_matrix" + ].shape ) -def get_data_and_gp(n_datapoints): - x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) - y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 - D = gpx.Dataset(X=x, y=y) +from copy import deepcopy - p = gpx.Prior(kernel=gpx.RBF()) - lik = gpx.Gaussian(num_datapoints=n_datapoints) - post = p * lik - return D, post, p + +def test_renaming(): + """ + Converts natural parameters to expectation parameters. + Args: + natural_moments: A dictionary of natural parameters. + jitter (float): A small value to prevent numerical instability. + Returns: + tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. + """ + + _, posterior, prior = get_data_and_gp(10) + + z = jnp.linspace(-5.0, 5.0, 5).reshape(-1, 1) + expectation_variational_family = ( + gpx.variational_families.ExpectationVariationalGaussian( + prior=prior, inducing_inputs=z + ) + ) + + natural_variational_family = gpx.variational_families.NaturalVariationalGaussian( + prior=prior, inducing_inputs=z + ) + + natural_svgp = gpx.StochasticVI( + posterior=posterior, variational_family=natural_variational_family + ) + expectation_svgp = gpx.StochasticVI( + posterior=posterior, variational_family=expectation_variational_family + ) + + key = jr.PRNGKey(123) + natural_params, *_ = gpx.initialise(natural_svgp, key).unpack() + expectation_params, *_ = gpx.initialise(expectation_svgp, key).unpack() + + _nat = deepcopy(natural_params) + _exp = deepcopy(expectation_params) + + rename_expectation_to_natural = _rename_expectation_to_natural(_exp) + rename_natural_to_expectation = _rename_natural_to_expectation(_nat) + + # Check correct names are in the dictionaries: + assert ( + "expectation_vector" + in rename_natural_to_expectation["variational_family"]["moments"].keys() + ) + assert ( + "expectation_matrix" + in rename_natural_to_expectation["variational_family"]["moments"].keys() + ) + assert ( + "natural_vector" + not in rename_natural_to_expectation["variational_family"]["moments"].keys() + ) + assert ( + "natural_matrix" + not in rename_natural_to_expectation["variational_family"]["moments"].keys() + ) + + assert ( + "natural_vector" + in rename_expectation_to_natural["variational_family"]["moments"].keys() + ) + assert ( + "natural_matrix" + in rename_expectation_to_natural["variational_family"]["moments"].keys() + ) + assert ( + "expectation_vector" + not in rename_expectation_to_natural["variational_family"]["moments"].keys() + ) + assert ( + "expectation_matrix" + not in rename_expectation_to_natural["variational_family"]["moments"].keys() + ) + + # Check the values are unchanged: + for v1, v2 in zip( + rename_natural_to_expectation["variational_family"]["moments"].values(), + natural_params["variational_family"]["moments"].values(), + ): + assert jnp.all(v1 == v2) + + for v1, v2 in zip( + rename_expectation_to_natural["variational_family"]["moments"].values(), + expectation_params["variational_family"]["moments"].values(), + ): + assert jnp.all(v1 == v2) @pytest.mark.parametrize("jit_fns", [True, False]) @@ -96,14 +220,6 @@ def test_expectation_elbo(jit_fns): assert len(grads) == len(params) -# def test_stop_gradients_nonmoments(): -# pass - - -# def test_stop_gradients_moments(): -# pass - - def test_natural_gradients(): """ Tests the natural gradient and hyperparameter gradients. From bd64f3a5bd159f6ad6f03f532692a9658a81cd77 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 25 Aug 2022 17:05:39 +0100 Subject: [PATCH 60/66] Create skeleton notebook. --- docs/refs.bib | 16 ++++++++++++ examples/natgrads.ipynb | 57 +++++------------------------------------ 2 files changed, 22 insertions(+), 51 deletions(-) diff --git a/docs/refs.bib b/docs/refs.bib index 2d52ff07..4f7a77cb 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -79,3 +79,19 @@ @InProceedings{titsias2009 series = {Proceedings of Machine Learning Research}, publisher = {PMLR}, } + +@misc{salimbeni2018, + doi = {10.48550/ARXIV.1803.09151}, + + url = {https://arxiv.org/abs/1803.09151}, + + author = {Salimbeni, Hugh and Eleftheriadis, Stefanos and Hensman, James}, + + keywords = {Machine Learning (stat.ML), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + + title = {Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models}, + + publisher = {arXiv}, + + year = {2018}, +} diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index c16d9113..89dca6a4 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -13,24 +13,7 @@ "id": "02dcd16f", "metadata": {}, "source": [ - "In this notebook we demonstrate how to implement natural gradients. \n", - "\n", - "As well explained in Salimbeni et al. (2018),\n", - "\n", - "\"The ordinary gradient turns out to be an unnatural direction to follow for variational inference since we are optimizing a distribution, rather than a set of pa- rameters directly. One way to define the gradient is the direction that achieves maximum change subject to a perturbation within a small euclidean ball. To see why the euclidean distance is an unnatural metric for probability distributions, consider the two Gaussians $\\mathcal{N}(0, 0.1)$ and $\\mathcal{N}(0, 0.2)$, compared to $\\mathcal{N}(0, 1000.1)$ and $\\mathcal{N}N(0,1000.2)$.\"" - ] - }, - { - "cell_type": "markdown", - "id": "a889abb6", - "metadata": {}, - "source": [ - "# Mathematical background\n", - "\n", - "Gradient descent algorithms seek to minimise a function $g(\\boldsymbol{\\theta})$ through a sequence of updates $\\{\\boldsymbol{\\theta}_t\\}_{t}$ via an iterative proceedure, \n", - "\\begin{align}\n", - " \\boldsymbol{\\theta}_{t+1} \\gets \\boldsymbol{\\theta}_{t} + \n", - "\\end{align}" + "In this notebook, we show how to create natural gradients. Ordinary gradient descent algorithms are an undesirable for variational inference because we are minimising the KL divergence between distributions rather than a set of parameters directly. Natural gradients, on the other hand, accounts for the curvature induced by the KL divergence that has the capacity to considerably improve performance (see e.g., Salimbeni et al. (2018) for further details)." ] }, { @@ -120,7 +103,7 @@ "id": "664c204b", "metadata": {}, "source": [ - "# Model and variational inference strategy:" + "# Natural gradients:" ] }, { @@ -128,7 +111,7 @@ "id": "ce4de494", "metadata": {}, "source": [ - "Define model, variational family and variational inference strategy:" + "We begin by defining our model, variational family and variational inference strategy:" ] }, { @@ -161,28 +144,12 @@ "loss_fn = jit(svgp.elbo(D, constrainers, negative=True))" ] }, - { - "cell_type": "markdown", - "id": "55e697ec", - "metadata": {}, - "source": [ - "Get default parameters and transform these to the uncontrained space:" - ] - }, - { - "cell_type": "markdown", - "id": "8969b14e", - "metadata": {}, - "source": [ - "# Natural gradients:" - ] - }, { "cell_type": "markdown", "id": "e793c24f", "metadata": {}, "source": [ - "Define natural gradient and hyperparameter gradient functions:" + "Next, we can conduct natural gradients as follows:" ] }, { @@ -207,22 +174,12 @@ "learned_params = gpx.transform(learned_params, constrainers)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "06b2fb33", - "metadata": {}, - "outputs": [], - "source": [ - "elbo = svgp.elbo(D, learned_params, negative=True)" - ] - }, { "cell_type": "markdown", "id": "fbcdd41c", "metadata": {}, "source": [ - "Plot results:" + "Here is the fitted model:" ] }, { @@ -262,9 +219,7 @@ "id": "649d29ec", "metadata": {}, "source": [ - "As mentioned in Hensman et al 2013, ....\n", - "\n", - "We demonstrate this now:" + "As mentioned in Hensman et al. (2013), in the case of a Gaussian likelihood, taking a step of unit length for natural gradients on a full batch of data recovers the same solution as Titsias (2009). We now illustrate this." ] }, { From f8157d0e386140947d920b9c04cf018bfe735d39 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 26 Aug 2022 11:20:29 +0100 Subject: [PATCH 61/66] Move trainable dictionaries outside of gradient functions. --- examples/natgrads.ipynb | 12 +++---- gpjax/abstractions.py | 6 ++-- gpjax/natural_gradients.py | 60 +++++++++++---------------------- tests/test_natural_gradients.py | 6 ++-- 4 files changed, 29 insertions(+), 55 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 89dca6a4..f195ca65 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -294,20 +294,20 @@ "\n", "params = gpx.transform(params, unconstrainers)\n", "\n", - "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers)\n", + "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers, trainables)\n", "\n", "moment_optim = ox.sgd(1.0)\n", "\n", "moment_state = moment_optim.init(params)\n", "\n", "# Natural gradients update:\n", - "loss_val, loss_gradient = nat_grads_fn(params, trainables, D)\n", + "loss_val, loss_gradient = nat_grads_fn(params, D)\n", "print(loss_val)\n", "\n", "updates, moment_state = moment_optim.update(loss_gradient, moment_state, params)\n", "params = ox.apply_updates(params, updates)\n", "\n", - "loss_val, _ = nat_grads_fn(params, trainables, D)\n", + "loss_val, _ = nat_grads_fn(params, D)\n", "\n", "print(loss_val)" ] @@ -327,16 +327,12 @@ "metadata": {}, "outputs": [], "source": [ - "from gpjax.parameters import build_identity\n", - "\n", "q = gpx.CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)\n", "sgpr = gpx.CollapsedVI(posterior=p, variational_family=q)\n", "\n", "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", "\n", - "params = gpx.transform(params, unconstrainers)\n", - "\n", - "loss_fn = sgpr.elbo(D, constrainers, negative=True)\n", + "loss_fn = sgpr.elbo(D, transformations=None, negative=True)\n", "\n", "loss_val = loss_fn(params)\n", "\n", diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index e012db7b..18b8f50d 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -244,7 +244,7 @@ def fit_natgrads( moment_state = moment_optim.init(params) nat_grads_fn, hyper_grads_fn = natural_gradients( - stochastic_vi, train_data, transformations + stochastic_vi, train_data, transformations, trainables ) keys = jax.random.split(key, n_iters) @@ -258,12 +258,12 @@ def step(carry, iter_num__and__key): batch = get_batch(train_data, batch_size, key) # Hyper-parameters update: - loss_val, loss_gradient = hyper_grads_fn(params, trainables, batch) + loss_val, loss_gradient = hyper_grads_fn(params, batch) updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params) params = optax.apply_updates(params, updates) # Natural gradients update: - loss_val, loss_gradient = nat_grads_fn(params, trainables, batch) + loss_val, loss_gradient = nat_grads_fn(params, batch) updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) params = optax.apply_updates(params, updates) diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index a83df62b..aaac2458 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -2,6 +2,7 @@ from typing import Callable, Dict, Tuple import jax.numpy as jnp +import jax.random as jr import jax.scipy as jsp from jax import value_and_grad from jaxtyping import f64 @@ -102,37 +103,6 @@ def _expectation_elbo( return svgp.elbo(train_data, transformations=None, negative=True) -def _stop_gradients_nonmoments(params: Dict) -> Dict: - """ - Stops gradients for non-moment parameters. - Args: - params: A dictionary of parameters. - Returns: - Dict: A dictionary of parameters with stopped gradients. - """ - trainables = build_trainables(params, False) - moment_trainables = build_trainables(params["variational_family"]["moments"], True) - trainables["variational_family"]["moments"] = moment_trainables - params = trainable_params(params, trainables) - return params - - -def _stop_gradients_moments(params: Dict) -> Dict: - """ - Stops gradients for moment parameters. - Args: - params: A dictionary of parameters. - Returns: - Dict: A dictionary of parameters with stopped gradients. - """ - trainables = build_trainables(params, True) - moment_trainables = build_trainables(params["variational_family"]["moments"], False) - trainables["variational_family"]["moments"] = moment_trainables - params = trainable_params(params, trainables) - return params - - -# TODO: Write unit test: def _rename_expectation_to_natural(params: Dict) -> Dict: """This function renames the gradient components (that have expectation parameterisation keys) to match the natural parameterisation pytree.""" params["variational_family"]["moments"] = { @@ -143,7 +113,6 @@ def _rename_expectation_to_natural(params: Dict) -> Dict: return params -# TODO: Write unit test: def _rename_natural_to_expectation(params: Dict) -> Dict: """This function renames the gradient components (that have natural parameterisation keys) to match the expectation parameterisation pytree.""" params["variational_family"]["moments"] = { @@ -158,6 +127,7 @@ def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, transformations: Dict, + trainables: Dict, ) -> Tuple[Callable[[Dict, Dataset], Dict]]: """ Computes the gradient with respect to the natural parameters. Currently only implemented for the natural variational Gaussian family. @@ -178,9 +148,22 @@ def natural_gradients( # The ELBO under the expectation parameterisation, L(η). expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) + # Stop nonment params: + expectation_trainables = _rename_natural_to_expectation(deepcopy(trainables)) + moment_trainables = build_trainables(expectation_trainables, False) + moment_trainables["variational_family"]["moments"] = expectation_trainables[ + "variational_family" + ]["moments"] + + # Stop moment params: + hyper_trainables = deepcopy(trainables) + hyper_trainables["variational_family"]["moments"] = build_trainables( + trainables["variational_family"]["moments"], False + ) + if isinstance(variational_family, NaturalVariationalGaussian): - def nat_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: + def nat_grads_fn(params: Dict, batch: Dataset) -> Dict: """ Computes the natural gradients of the ELBO. Args: @@ -199,11 +182,7 @@ def nat_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: # Compute gradient ∂L/∂η: def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Stop gradients for non-trainable and non-moment parameters. - expectation_trainables = _rename_natural_to_expectation( - deepcopy(trainables) - ) - params = trainable_params(params, expectation_trainables) - params = _stop_gradients_nonmoments(params) + params = trainable_params(params, moment_trainables) return expectation_elbo(params, batch) @@ -216,7 +195,7 @@ def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: else: raise NotImplementedError - def hyper_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: + def hyper_grads_fn(params: Dict, batch: Dataset) -> Dict: """ Computes the hyperparameter gradients of the ELBO. Args: @@ -229,8 +208,7 @@ def hyper_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Stop gradients for non-trainable and moment parameters. - params = trainable_params(params, trainables) - params = _stop_gradients_moments(params) + params = trainable_params(params, hyper_trainables) return xi_elbo(params, batch) diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index e84ccb32..e8a2dd00 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -239,13 +239,13 @@ def test_natural_gradients(): batch = get_batch(D, batch_size=10, key=jr.PRNGKey(42)) - nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers) + nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers, trainables) assert isinstance(nat_grads_fn, tp.Callable) assert isinstance(hyper_grads_fn, tp.Callable) - val, nat_grads = nat_grads_fn(params, trainables, batch) - val, hyper_grads = hyper_grads_fn(params, trainables, batch) + val, nat_grads = nat_grads_fn(params, batch) + val, hyper_grads = hyper_grads_fn(params, batch) assert isinstance(val, jnp.ndarray) assert isinstance(nat_grads, tp.Dict) From 9d92f5a22cbd1ce8c0f2a7cac34cb55a40c45ff7 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Fri, 26 Aug 2022 17:11:42 +0100 Subject: [PATCH 62/66] Update documentation. --- gpjax/gps.py | 4 +-- gpjax/variational_families.py | 51 ++++++++++++++++++++++++++++++----- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/gpjax/gps.py b/gpjax/gps.py index a8a748ae..e70c1db6 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -177,10 +177,10 @@ def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: # L⁻¹ Kxt L_inv_Kxt = jsp.linalg.solve_triangular(L, Kxt, lower=True) - # μt + Ktx (Kzz + Iσ²)⁻¹ (y - μx) + # μt + Ktx (Kxx + Iσ²)⁻¹ (y - μx) mean = μt + jnp.matmul(L_inv_Kxt.T, w) - # Ktt - Ktz (Kzz + Iσ²)⁻¹ Kxt [recall (Kzz + Iσ²)⁻¹ = (LLᵀ)⁻¹ = L⁻ᵀL⁻¹] + # Ktt - Ktx (Kxx + Iσ²)⁻¹ Kxt [recall (Kxx + Iσ²)⁻¹ = (LLᵀ)⁻¹ = L⁻ᵀL⁻¹] covariance = Ktt - jnp.matmul(L_inv_Kxt.T, L_inv_Kxt) covariance += I(n_test) * self.jitter diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 01ee3e2e..01a18515 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -164,7 +164,7 @@ class WhitenedVariationalGaussian(VariationalGaussian): """The whitened variational Gaussian family of probability distributions. The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing inputs z - and the distribution over the inducing inputs is q(u) = N(Lz μ + mz, Lz S). We parameterise this over μ and sqrt with S = sqrt sqrtᵀ. + and the distribution over the inducing inputs is q(u) = N(Lz μ + mz, Lz S Lzᵀ). We parameterise this over μ and sqrt with S = sqrt sqrtᵀ. """ @@ -245,7 +245,14 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: @dataclass class NaturalVariationalGaussian(AbstractVariationalGaussian): - """The natural variational Gaussian family of probability distributions.""" + """The natural variational Gaussian family of probability distributions. + + The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing inputs z + and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the + exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform + model inference, where T(u) = [u, uuᵀ] are the sufficient statistics. + + """ name: str = "Natural Gaussian" @@ -270,6 +277,10 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: def prior_kl(self, params: Dict) -> f64["1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. + For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], + + with μ and S computed from the natural paramerisation θ = (S⁻¹μ, -S⁻¹/2). + Args: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. @@ -303,7 +314,13 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: - """Compute the predictive distribution of the GP at the test inputs. + """Compute the predictive distribution of the GP at the test inputs t. + + This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as + + N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ], + + with μ and S computed from the natural paramerisation θ = (S⁻¹μ, -S⁻¹/2). Args: params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. @@ -371,7 +388,14 @@ def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: @dataclass class ExpectationVariationalGaussian(AbstractVariationalGaussian): - """The variational Gaussian family of probability distributions.""" + """The natural variational Gaussian family of probability distributions. + + The variational family is q(f(·)) = ∫ p(f(·)|u) q(u) du, where u = f(z) are the function values at the inducing inputs z + and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the + exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2) and + sufficient stastics T(u) = [u, uuᵀ]. The expectation parameters are given by η = ∫ T(u) q(u) du. This gives a parameterisation, + η = (η₁, η₁) = (μ, S + uuᵀ) to perform model inference over. + """ name: str = "Expectation Gaussian" @@ -398,6 +422,10 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: def prior_kl(self, params: Dict) -> f64["1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. + For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], + + with μ and S computed from the expectation paramerisation η = (μ, S + uuᵀ). + Args: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. @@ -413,9 +441,14 @@ def prior_kl(self, params: Dict) -> f64["1"]: z = params["variational_family"]["inducing_inputs"] m = self.num_inducing + # μ = η₁ mu = expectation_vector - S = expectation_matrix - jnp.matmul(mu, mu.T) + + # S = η₂ - η₁ η₁ᵀ + S = expectation_matrix - jnp.outer(mu, mu) S += I(m) * self.jitter + + # S = sqrt sqrtᵀ sqrt = jnp.linalg.cholesky(S) μz = self.prior.mean_function(z, params["mean_function"]) @@ -429,7 +462,13 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: - """Compute the predictive distribution of the GP at the test inputs. + """Compute the predictive distribution of the GP at the test inputs t. + + This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as + + N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ], + + with μ and S computed from the expectation paramerisation η = (μ, S + uuᵀ). Args: params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. From f1d3f81ab8d5a9340da529fac1ce8b8fb828229e Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 18 Sep 2022 13:10:23 +0100 Subject: [PATCH 63/66] Finish rebase issues. This commit finishes rebase issues, need to refactor code before merging to v0.5_update. --- examples/classification.ipynb | 4 +-- examples/natgrads.ipynb | 49 +++++++++++++++------------------ gpjax/abstractions.py | 45 +++++++++++++++--------------- gpjax/gps.py | 18 ++++++------ gpjax/natural_gradients.py | 19 +++++++------ gpjax/parameters.py | 10 +++---- gpjax/variational_families.py | 32 ++++++++++----------- tests/test_abstractions.py | 10 ++----- tests/test_natural_gradients.py | 27 ++---------------- tests/test_parameters.py | 3 +- 10 files changed, 93 insertions(+), 124 deletions(-) diff --git a/examples/classification.ipynb b/examples/classification.ipynb index 94cf9c28..14d80849 100644 --- a/examples/classification.ipynb +++ b/examples/classification.ipynb @@ -27,7 +27,7 @@ "import distrax as dx\n", "from gpjax.utils import I\n", "import jax.scipy as jsp\n", - "from jaxtyping import f64\n", + "from jaxtyping import Float, Array\n", "\n", "key = jr.PRNGKey(123)" ] @@ -294,7 +294,7 @@ "metadata": {}, "outputs": [], "source": [ - "def construct_laplace(test_inputs: f64[\"N D\"]) -> dx.MultivariateNormalFullCovariance:\n", + "def construct_laplace(test_inputs: Float[Array, \"N D\"]) -> dx.MultivariateNormalFullCovariance:\n", " \n", " map_latent_dist = posterior(D, map_estimate)(test_inputs)\n", "\n", diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index f195ca65..efb94386 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -28,9 +28,7 @@ "import matplotlib.pyplot as plt\n", "from jax import jit, lax\n", "import optax as ox\n", - "\n", "import gpjax as gpx\n", - "from gpjax.abstractions import progress_bar_scan\n", "\n", "key = jr.PRNGKey(123)" ] @@ -127,21 +125,18 @@ "p = prior * likelihood\n", "\n", "\n", - "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", - "svgp = gpx.StochasticVI(posterior=p, variational_family=q)" + "natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", + "natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)" ] }, { "cell_type": "code", "execution_count": null, - "id": "5190b12d", + "id": "60293d59", "metadata": {}, "outputs": [], "source": [ - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", - "params = gpx.transform(params, unconstrainers)\n", - "\n", - "loss_fn = jit(svgp.elbo(D, constrainers, negative=True))" + "params, trainables, bijectors = gpx.initialise(natural_svgp).unpack()" ] }, { @@ -159,19 +154,19 @@ "metadata": {}, "outputs": [], "source": [ - "learned_params, training_history = gpx.fit_natgrads(svgp,\n", - " params = params,\n", - " trainables = trainables, \n", - " transformations = constrainers,\n", - " train_data = D,\n", - " n_iters = 10000,\n", - " batch_size=100,\n", - " key = jr.PRNGKey(42),\n", - " moment_optim = ox.sgd(1.0),\n", - " hyper_optim = ox.adam(1e-3),\n", - " ).unpack()\n", + "inference_state = gpx.fit_natgrads(natural_svgp,\n", + " params,\n", + " trainables,\n", + " bijectors,\n", + " train_data = D,\n", + " n_iters = 10000,\n", + " batch_size=1000,\n", + " key = jr.PRNGKey(42),\n", + " moment_optim = ox.sgd(1.0),\n", + " hyper_optim = ox.adam(1e-3),\n", + " )\n", "\n", - "learned_params = gpx.transform(learned_params, constrainers)" + "learned_params, training_history = inference_state.unpack()" ] }, { @@ -189,7 +184,7 @@ "metadata": {}, "outputs": [], "source": [ - "latent_dist = q(learned_params)(xtest)\n", + "latent_dist = natural_q(learned_params)(xtest)\n", "predictive_dist = likelihood(latent_dist, learned_params)\n", "\n", "meanf = predictive_dist.mean()\n", @@ -290,11 +285,11 @@ "\n", "q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", "svgp = gpx.StochasticVI(posterior=p, variational_family=q)\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", + "params, trainables, bijectors = gpx.initialise(svgp).unpack()\n", "\n", - "params = gpx.transform(params, unconstrainers)\n", + "params = gpx.unconstrain(params, bijectors)\n", "\n", - "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers, trainables)\n", + "nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, bijectors, trainables)\n", "\n", "moment_optim = ox.sgd(1.0)\n", "\n", @@ -330,9 +325,9 @@ "q = gpx.CollapsedVariationalGaussian(prior=prior, likelihood=likelihood, inducing_inputs=z)\n", "sgpr = gpx.CollapsedVI(posterior=p, variational_family=q)\n", "\n", - "params, trainables, constrainers, unconstrainers = gpx.initialise(svgp).unpack()\n", + "params, _, _ = gpx.initialise(svgp).unpack()\n", "\n", - "loss_fn = sgpr.elbo(D, transformations=None, negative=True)\n", + "loss_fn = sgpr.elbo(D, negative=True)\n", "\n", "loss_val = loss_fn(params)\n", "\n", diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 18b8f50d..0f235385 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -10,15 +10,15 @@ from jaxtyping import Array, Float from tqdm.auto import tqdm +from .natural_gradients import natural_gradients from .parameters import ParameterState, constrain, trainable_params, unconstrain -from .parameters import trainable_params, transform from .types import Dataset, PRNGKeyType from .variational_inference import StochasticVI @dataclass(frozen=True) class InferenceState: - params: tp.Dict + params: Dict history: Float[Array, "n_iters"] def unpack(self): @@ -209,11 +209,26 @@ def step(carry, iter_num__and__key): return inf_state +def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset: + """Batch the data into mini-batches. + Args: + train_data (Dataset): The training dataset. + batch_size (int): The batch size. + Returns: + Dataset: The batched dataset. + """ + x, y, n = train_data.X, train_data.y, train_data.n + + indicies = jr.choice(key, n, (batch_size,), replace=True) + + return Dataset(X=x[indicies], y=y[indicies]) + + def fit_natgrads( stochastic_vi: StochasticVI, params: Dict, trainables: Dict, - transformations: Dict, + bijectors: Dict, train_data: Dataset, moment_optim, hyper_optim, @@ -223,14 +238,12 @@ def fit_natgrads( log_rate: Optional[int] = 10, ) -> Dict: """This is a training loop for natural gradients. See Salimbeni et al. (2018) Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models - Each iteration comprises a hyperparameter gradient step followed by natural gradient step to avoid a stale posterior. - Args: stochastic_vi (StochasticVI): The stochastic variational inference algorithm to be used for training. params (Dict): The parameters for which we would like to minimise our objective function with. trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. - transformations (Dict): The transformations to be applied to the parameters. + bijectors (Dict): The bijectors to be applied to the parameters. train_data (Dataset): The training dataset. batch_size(int): The batch_size. key (PRNGKeyType): The PRNG key for the mini-batch sampling. @@ -240,11 +253,13 @@ def fit_natgrads( InferenceState: A dataclass comprising optimised parameters and training history. """ + params = unconstrain(params, bijectors) + hyper_state = hyper_optim.init(params) moment_state = moment_optim.init(params) nat_grads_fn, hyper_grads_fn = natural_gradients( - stochastic_vi, train_data, transformations, trainables + stochastic_vi, train_data, bijectors, trainables ) keys = jax.random.split(key, n_iters) @@ -273,20 +288,6 @@ def step(carry, iter_num__and__key): (params, _, _), history = jax.lax.scan( step, (params, hyper_state, moment_state), (iter_nums, keys) ) + params = constrain(params, bijectors) inf_state = InferenceState(params=params, history=history) return inf_state - - -def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset: - """Batch the data into mini-batches. - Args: - train_data (Dataset): The training dataset. - batch_size (int): The batch size. - Returns: - Dataset: The batched dataset. - """ - x, y, n = train_data.X, train_data.y, train_data.n - - indicies = jr.choice(key, n, (batch_size,), replace=True) - - return Dataset(X=x[indicies], y=y[indicies]) diff --git a/gpjax/gps.py b/gpjax/gps.py index e70c1db6..f6b732d3 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -19,7 +19,7 @@ NonConjugateLikelihoodType, ) from .mean_functions import AbstractMeanFunction, Zero -from .parameters import copy_dict_structure, evaluate_priors, transform +from .parameters import copy_dict_structure, evaluate_priors from .types import Dataset, PRNGKeyType from .utils import I, concat_dictionaries @@ -73,9 +73,7 @@ def __rmul__(self, other: AbstractLikelihood): """Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior.""" return self.__mul__(other) - def predict( - self, params: dict - ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the GP's prior mean and variance. Args: params (Dict): The specific set of parameters for which the mean function should be defined for. @@ -139,8 +137,8 @@ class ConjugatePosterior(AbstractPosterior): jitter: Optional[float] = DEFAULT_JITTER def predict( - self, train_data: Dataset, params: dict - ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]: + self, train_data: Dataset, params: Dict + ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Args: @@ -195,7 +193,7 @@ def marginal_log_likelihood( train_data: Dataset, priors: Dict = None, negative: bool = False, - ) -> tp.Callable[[dict], Float[Array, "1"]]: + ) -> Callable[[Dict], Float[Array, "1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values. Args: @@ -257,8 +255,8 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return parameters def predict( - self, train_data: Dataset, params: dict - ) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]: + self, train_data: Dataset, params: Dict + ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function. Args: @@ -302,7 +300,7 @@ def marginal_log_likelihood( train_data: Dataset, priors: Dict = None, negative: bool = False, - ) -> tp.Callable[[dict], Float[Array, "1"]]: + ) -> Callable[[Dict], Float[Array, "1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax. Args: diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index aaac2458..e3b79ce3 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -5,11 +5,11 @@ import jax.random as jr import jax.scipy as jsp from jax import value_and_grad -from jaxtyping import f64 +from jaxtyping import Array, Float from .config import get_defaults from .gps import AbstractPosterior -from .parameters import build_trainables, trainable_params, transform +from .parameters import build_trainables, constrain, trainable_params, unconstrain from .types import Dataset from .utils import I from .variational_families import ( @@ -100,7 +100,7 @@ def _expectation_elbo( posterior=posterior, variational_family=expectation_vartiational_gaussian ) - return svgp.elbo(train_data, transformations=None, negative=True) + return svgp.elbo(train_data, negative=True) def _rename_expectation_to_natural(params: Dict) -> Dict: @@ -126,7 +126,7 @@ def _rename_natural_to_expectation(params: Dict) -> Dict: def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, - transformations: Dict, + bijectors: Dict, trainables: Dict, ) -> Tuple[Callable[[Dict, Dataset], Dict]]: """ @@ -135,7 +135,7 @@ def natural_gradients( posterior: An instance of AbstractPosterior. variational_family: An instance of AbstractVariationalFamily. train_data: A Dataset. - transformations: A dictionary of transformations. + bijectors: A dictionary of bijectors. Returns: Tuple[Callable[[Dict, Dataset], Dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. """ @@ -143,7 +143,7 @@ def natural_gradients( variational_family = stochastic_vi.variational_family # The ELBO under the user chosen parameterisation xi. - xi_elbo = stochastic_vi.elbo(train_data, transformations, negative=True) + xi_elbo = stochastic_vi.elbo(train_data, negative=True) # The ELBO under the expectation parameterisation, L(η). expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) @@ -174,13 +174,13 @@ def nat_grads_fn(params: Dict, batch: Dataset) -> Dict: Dict: A dictionary of natural gradients. """ # Transform parameters to constrained space. - params = transform(params, transformations) + params = constrain(params, bijectors) # Convert natural parameterisation θ to the expectation parametersation η. expectation_params = natural_to_expectation(params) # Compute gradient ∂L/∂η: - def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: + def loss_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: # Stop gradients for non-trainable and non-moment parameters. params = trainable_params(params, moment_trainables) @@ -206,8 +206,9 @@ def hyper_grads_fn(params: Dict, batch: Dataset) -> Dict: Dict: A dictionary of hyperparameter gradients. """ - def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: + def loss_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: # Stop gradients for non-trainable and moment parameters. + params = constrain(params, bijectors) params = trainable_params(params, hyper_trainables) return xi_elbo(params, batch) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 4376cbef..c6c8fb04 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -143,12 +143,12 @@ def constrain(params: Dict, bijectors: Dict) -> Dict: """Transform the parameters to the constrained space for corresponding bijectors. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + params (Dict): The parameters that are to be transformed. + transform_map (Dict): The corresponding dictionary of transforms that should be applied to the parameter set. foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: - tp.Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. + Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. """ map = lambda param, trans: trans.forward(param) @@ -160,8 +160,8 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: """Transform the parameters to the unconstrained space for corresponding bijectors. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + params (Dict): The parameters that are to be transformed. + transform_map (Dict): The corresponding dictionary of transforms that should be applied to the parameter set. foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). Returns: diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 01a18515..24371e2a 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -41,7 +41,7 @@ class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" prior: Prior - inducing_inputs: f64["N D"] + inducing_inputs: Float[Array, "N D"] name: str = "Gaussian" jitter: Optional[float] = DEFAULT_JITTER @@ -86,7 +86,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - f64["1"]: The KL-divergence between our variational approximation and the GP prior. + Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -113,7 +113,7 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -179,7 +179,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - f64["N D"]: The KL-divergence between our variational approximation and the GP prior. + Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -201,7 +201,7 @@ def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distributi params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] @@ -274,7 +274,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: }, ) - def prior_kl(self, params: Dict) -> f64["1"]: + def prior_kl(self, params: Dict) -> Float[Array, "1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], @@ -285,7 +285,7 @@ def prior_kl(self, params: Dict) -> f64["1"]: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - f64["1"]: The KL-divergence between our variational approximation and the GP prior. + Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] @@ -313,7 +313,7 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -326,7 +326,7 @@ def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] @@ -354,7 +354,7 @@ def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) @@ -419,7 +419,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: }, ) - def prior_kl(self, params: Dict) -> f64["1"]: + def prior_kl(self, params: Dict) -> Float[Array, "1"]: """Compute the KL-divergence between our current variational approximation and the Gaussian process prior. For this variational family, we have KL[q(f(·))||p(·)] = KL[q(u)||p(u)] = KL[N(μ, S)||N(mz, Kzz)], @@ -430,7 +430,7 @@ def prior_kl(self, params: Dict) -> f64["1"]: params (Dict): The parameters at which our variational distribution and GP prior are to be evaluated. Returns: - f64["1"]: The KL-divergence between our variational approximation and the GP prior. + Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ expectation_vector = params["variational_family"]["moments"][ "expectation_vector" @@ -461,7 +461,7 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -474,7 +474,7 @@ def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ expectation_vector = params["variational_family"]["moments"][ "expectation_vector" @@ -500,7 +500,7 @@ def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: Lz = jnp.linalg.cholesky(Kzz) μz = self.prior.mean_function(z, params["mean_function"]) - def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: + def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: t = test_inputs Ktt = gram(self.prior.kernel, t, params["kernel"]) Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) @@ -570,7 +570,7 @@ def predict( Args: params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: - Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ x, y = train_data.X, train_data.y diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index cb660cba..50259852 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -5,7 +5,7 @@ import gpjax as gpx from gpjax import RBF, Dataset, Gaussian, Prior, initialise -from gpjax.abstractions import InferenceState, fit, fit_batches, get_batch +from gpjax.abstractions import InferenceState, fit, fit_batches, fit_natgrads, get_batch from gpjax.parameters import ParameterState, build_bijectors @@ -98,10 +98,7 @@ def test_natural_gradients(ndata, nb, n_iters): q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) svgp = gpx.StochasticVI(posterior=p, variational_family=q) - params, trainable_status, constrainer, unconstrainer = initialise( - svgp, key - ).unpack() - params = gpx.transform(params, unconstrainer) + params, trainable_status, bijectors = initialise(svgp, key).unpack() D = Dataset(X=x, y=y) @@ -113,7 +110,7 @@ def test_natural_gradients(ndata, nb, n_iters): svgp, params, trainable_status, - constrainer, + bijectors, D, moment_optimiser, hyper_optimiser, @@ -122,7 +119,6 @@ def test_natural_gradients(ndata, nb, n_iters): n_iters, ) optimised_params, history = inference_state.params, inference_state.history - optimised_params = transform(optimised_params, constrainer) assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert isinstance(history, jnp.ndarray) diff --git a/tests/test_natural_gradients.py b/tests/test_natural_gradients.py index e8a2dd00..06f61362 100644 --- a/tests/test_natural_gradients.py +++ b/tests/test_natural_gradients.py @@ -32,14 +32,6 @@ def get_data_and_gp(n_datapoints): @pytest.mark.parametrize("dim", [1, 2, 3]) def test_natural_to_expectation(dim): - """ - Converts natural parameters to expectation parameters. - Args: - natural_moments: A dictionary of natural parameters. - jitter (float): A small value to prevent numerical instability. - Returns: - tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. - """ _, posterior, prior = get_data_and_gp(10) @@ -97,14 +89,6 @@ def test_natural_to_expectation(dim): def test_renaming(): - """ - Converts natural parameters to expectation parameters. - Args: - natural_moments: A dictionary of natural parameters. - jitter (float): A small value to prevent numerical instability. - Returns: - tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. - """ _, posterior, prior = get_data_and_gp(10) @@ -199,9 +183,7 @@ def test_expectation_elbo(jit_fns): svgp = gpx.StochasticVI(posterior=posterior, variational_family=variational_family) - params, _, constrainer, unconstrainer = gpx.initialise( - svgp, jr.PRNGKey(123) - ).unpack() + params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() expectation_elbo = _expectation_elbo(posterior, variational_family, D) @@ -232,14 +214,11 @@ def test_natural_gradients(): svgp = gpx.StochasticVI(posterior=p, variational_family=q) - params, trainables, constrainers, unconstrainers = gpx.initialise( - svgp, jr.PRNGKey(123) - ).unpack() - params = gpx.transform(params, unconstrainers) + params, trainables, bijectors = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() batch = get_batch(D, batch_size=10, key=jr.PRNGKey(42)) - nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, constrainers, trainables) + nat_grads_fn, hyper_grads_fn = natural_gradients(svgp, D, bijectors, trainables) assert isinstance(nat_grads_fn, tp.Callable) assert isinstance(hyper_grads_fn, tp.Callable) diff --git a/tests/test_parameters.py b/tests/test_parameters.py index 76286213..f384d046 100644 --- a/tests/test_parameters.py +++ b/tests/test_parameters.py @@ -11,9 +11,8 @@ from gpjax.likelihoods import Bernoulli, Gaussian from gpjax.parameters import ( build_bijectors, - constrain, build_trainables, - build_transforms, + constrain, copy_dict_structure, evaluate_priors, initialise, From cc1c3183ea0c1022fe835ecffb6193c82b64d0c3 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Sun, 18 Sep 2022 18:04:07 +0100 Subject: [PATCH 64/66] Add copyright, update typing. --- examples/natgrads.ipynb | 26 +++---- gpjax/__init__.py | 56 ++++++++++++++- gpjax/abstractions.py | 68 +++++++++++++----- gpjax/gps.py | 110 +++++++++++++++++++++++------ gpjax/kernels.py | 55 ++++++++++++--- gpjax/likelihoods.py | 119 +++++++++++++++++++++++++++----- gpjax/mean_functions.py | 57 ++++++++++++--- gpjax/natural_gradients.py | 113 +++++++++++++++++++++++------- gpjax/parameters.py | 122 +++++++++++++++++++++++++++------ gpjax/quadrature.py | 30 ++++++-- gpjax/types.py | 15 ++++ gpjax/utils.py | 69 +++++++++++++------ gpjax/variational_families.py | 72 ++++++++++++++++--- gpjax/variational_inference.py | 26 ++++++- tests/test_abstractions.py | 6 +- tests/test_gp.py | 4 +- tests/test_utilities.py | 11 --- 17 files changed, 769 insertions(+), 190 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index efb94386..35cdc7cb 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -87,7 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n", + "z = jnp.linspace(-5.0, 5.0, 5000).reshape(-1, 1)\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", "ax.plot(x, y, \"o\", alpha=0.3)\n", @@ -126,17 +126,9 @@ "\n", "\n", "natural_q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z)\n", - "natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "60293d59", - "metadata": {}, - "outputs": [], - "source": [ - "params, trainables, bijectors = gpx.initialise(natural_svgp).unpack()" + "natural_svgp = gpx.StochasticVI(posterior=p, variational_family=natural_q)\n", + "\n", + "parameter_state = gpx.initialise(natural_svgp)" ] }, { @@ -154,13 +146,13 @@ "metadata": {}, "outputs": [], "source": [ + "\n", + "\n", "inference_state = gpx.fit_natgrads(natural_svgp,\n", - " params,\n", - " trainables,\n", - " bijectors,\n", + " parameter_state=parameter_state,\n", " train_data = D,\n", - " n_iters = 10000,\n", - " batch_size=1000,\n", + " n_iters = 4000,\n", + " batch_size=128,\n", " key = jr.PRNGKey(42),\n", " moment_optim = ox.sgd(1.0),\n", " hyper_optim = ox.adam(1e-3),\n", diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 5927d38a..661c5cb8 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -1,9 +1,22 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from jax.config import config -# Enable Float64 - this is crucial for more stable matrix inversions. +# Enable Float64 for more stable matrix inversions. config.update("jax_enable_x64", True) -# Highlight any potentially unintended broadcasting rank promoting ops. -# config.update("jax_numpy_rank_promotion", "warn") from .abstractions import fit, fit_batches, fit_natgrads from .gps import Prior, construct_posterior @@ -30,4 +43,41 @@ ) from .variational_inference import CollapsedVI, StochasticVI +__license__ = "MIT" +__description__ = "Didactic Gaussian processes in JAX" +__url__ = "https://github.com/thomaspinder/GPJax" +__contributors__ = "https://github.com/thomaspinder/GPJax/graphs/contributors" __version__ = "0.4.13" + + +__all__ = [ + "fit", + "fit_batches", + "fit_natgrads", + "Prior", + "construct_posterior", + "RBF", + "GraphKernel", + "Matern12", + "Matern32", + "Matern52", + "Polynomial", + "ProductKernel", + "SumKernel", + "Bernoulli", + "Gaussian", + "Constant", + "Zero", + "constrain", + "copy_dict_structure", + "initialise", + "unconstrain", + "Dataset", + "CollapsedVariationalGaussian", + "ExpectationVariationalGaussian", + "NaturalVariationalGaussian", + "VariationalGaussian", + "WhitenedVariationalGaussian", + "CollapsedVI", + "StochasticVI", +] diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 0f235385..16747738 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -1,9 +1,24 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Callable, Dict, Optional import jax import jax.numpy as jnp import jax.random as jr -import optax +import optax as ox from chex import dataclass from jax import lax from jax.experimental import host_callback @@ -100,18 +115,20 @@ def wrapper_progress_bar(carry, x): def fit( objective: Callable, parameter_state: ParameterState, - optax_optim, - n_iters: int = 100, - log_rate: int = 10, + optax_optim: ox.GradientTransformation, + n_iters: Optional[int] = 100, + log_rate: Optional[int] = 10, ) -> InferenceState: """Abstracted method for fitting a GP model with respect to a supplied objective function. Optimisers used here should originate from Optax. + Args: objective (Callable): The objective function that we are optimising with respect to. parameter_state (ParameterState): The initial parameter state. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. + Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ @@ -135,7 +152,7 @@ def step(carry, iter_num): params, opt_state = carry loss_val, loss_gradient = jax.value_and_grad(loss)(params) updates, opt_state = optax_optim.update(loss_gradient, opt_state, params) - params = optax.apply_updates(params, updates) + params = ox.apply_updates(params, updates) carry = params, opt_state return carry, loss_val @@ -153,7 +170,7 @@ def fit_batches( objective: Callable, parameter_state: ParameterState, train_data: Dataset, - optax_optim, + optax_optim: ox.GradientTransformation, key: PRNGKeyType, batch_size: int, n_iters: Optional[int] = 100, @@ -161,6 +178,7 @@ def fit_batches( ) -> InferenceState: """Abstracted method for fitting a GP model with mini-batches respect to a supplied objective function. Optimisers used here should originate from Optax. + Args: objective (Callable): The objective function that we are optimising with respect to. parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with. @@ -170,6 +188,7 @@ def fit_batches( batch_size(int): The batch_size. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. + Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. """ @@ -196,7 +215,7 @@ def step(carry, iter_num__and__key): loss_val, loss_gradient = jax.value_and_grad(loss)(params, batch) updates, opt_state = optax_optim.update(loss_gradient, opt_state, params) - params = optax.apply_updates(params, updates) + params = ox.apply_updates(params, updates) carry = params, opt_state return carry, loss_val @@ -211,9 +230,11 @@ def step(carry, iter_num__and__key): def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset: """Batch the data into mini-batches. + Args: train_data (Dataset): The training dataset. batch_size (int): The batch size. + Returns: Dataset: The batched dataset. """ @@ -226,12 +247,10 @@ def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset def fit_natgrads( stochastic_vi: StochasticVI, - params: Dict, - trainables: Dict, - bijectors: Dict, + parameter_state: ParameterState, train_data: Dataset, - moment_optim, - hyper_optim, + moment_optim: ox.GradientTransformation, + hyper_optim: ox.GradientTransformation, key: PRNGKeyType, batch_size: int, n_iters: Optional[int] = 100, @@ -239,20 +258,24 @@ def fit_natgrads( ) -> Dict: """This is a training loop for natural gradients. See Salimbeni et al. (2018) Natural Gradients in Practice: Non-Conjugate Variational Inference in Gaussian Process Models Each iteration comprises a hyperparameter gradient step followed by natural gradient step to avoid a stale posterior. + Args: stochastic_vi (StochasticVI): The stochastic variational inference algorithm to be used for training. - params (Dict): The parameters for which we would like to minimise our objective function with. - trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. - bijectors (Dict): The bijectors to be applied to the parameters. + parameter_state (ParameterState): The initial parameter state. train_data (Dataset): The training dataset. - batch_size(int): The batch_size. + moment_optim (GradientTransformation): The Optax optimiser for the natural gradient updates on the moments. + hyper_optim (GradientTransformation): The Optax optimiser for gradient updates on the hyperparameters. key (PRNGKeyType): The PRNG key for the mini-batch sampling. + batch_size(int): The batch_size. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. + Returns: InferenceState: A dataclass comprising optimised parameters and training history. """ + params, trainables, bijectors = parameter_state.unpack() + params = unconstrain(params, bijectors) hyper_state = hyper_optim.init(params) @@ -275,12 +298,12 @@ def step(carry, iter_num__and__key): # Hyper-parameters update: loss_val, loss_gradient = hyper_grads_fn(params, batch) updates, hyper_state = hyper_optim.update(loss_gradient, hyper_state, params) - params = optax.apply_updates(params, updates) + params = ox.apply_updates(params, updates) # Natural gradients update: loss_val, loss_gradient = nat_grads_fn(params, batch) updates, moment_state = moment_optim.update(loss_gradient, moment_state, params) - params = optax.apply_updates(params, updates) + params = ox.apply_updates(params, updates) carry = params, hyper_state, moment_state return carry, loss_val @@ -291,3 +314,12 @@ def step(carry, iter_num__and__key): params = constrain(params, bijectors) inf_state = InferenceState(params=params, history=history) return inf_state + + +__all__ = [ + "fit", + "fit_natgrads", + "get_batch", + "natural_gradients", + "progress_bar_scan", +] diff --git a/gpjax/gps.py b/gpjax/gps.py index f6b732d3..29056df9 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from abc import abstractmethod from typing import Any, Callable, Dict, Optional @@ -10,14 +25,7 @@ from .config import get_defaults from .kernels import Kernel, cross_covariance, gram -from .likelihoods import ( - AbstractLikelihood, - Conjugate, - Gaussian, - NonConjugate, - NonConjugateLikelihoods, - NonConjugateLikelihoodType, -) +from .likelihoods import AbstractLikelihood, Conjugate, Gaussian, NonConjugate from .mean_functions import AbstractMeanFunction, Zero from .parameters import copy_dict_structure, evaluate_priors from .types import Dataset, PRNGKeyType @@ -33,6 +41,10 @@ class AbstractGP: def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the Gaussian process at the given points. + Args: + *args (Any): The arguments to pass to the GP's `predict` method. + **kwargs (Any): The keyword arguments to pass to the GP's `predict` method. + Returns: dx.Distribution: A multivariate normal random variable representation of the Gaussian process. """ @@ -40,11 +52,26 @@ def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: - """Compute the latent function's multivariate normal distribution.""" + """Compute the latent function's multivariate normal distribution. + + Args: + *args (Any): Arguments to the predict method. + **kwargs (Any): Keyword arguments to the predict method. + + Returns: + dx.Distribution: A multivariate normal random variable representation of the Gaussian process. + """ raise NotImplementedError def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Initialise the GP's parameter set""" + """Initialise the GP's parameter set. + + Args: + key (PRNGKeyType): The PRNG key. + + Returns: + Dict: The initialised parameter set. + """ raise NotImplementedError @@ -62,23 +89,34 @@ class Prior(AbstractGP): def __mul__(self, other: AbstractLikelihood): """The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned. + Args: other (Likelihood): The likelihood distribution of the observed dataset. + Returns: Posterior: The relevant GP posterior for the given prior and likelihood. Special cases are accounted for where the model is conjugate. """ return construct_posterior(prior=self, likelihood=other) def __rmul__(self, other: AbstractLikelihood): - """Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior.""" + """Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior. + + Args: + other (Likelihood): The likelihood distribution of the observed dataset. + + Returns: + Posterior: The relevant GP posterior for the given prior and likelihood. Special cases are accounted for where the model is conjugate. + """ return self.__mul__(other) def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the GP's prior mean and variance. + Args: params (Dict): The specific set of parameters for which the mean function should be defined for. + Returns: - Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. + Callable[[Float[Array, "N D"]], dx.Distribution]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. """ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: @@ -95,7 +133,14 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: return predict_fn def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Initialise the GP prior's parameter set""" + """Initialise the GP prior's parameter set. + + Args: + key (PRNGKeyType): The PRNG key. + + Returns: + Dict: The initialised parameter set. + """ return { "kernel": self.kernel._initialise_params(key), "mean_function": self.mean_function._initialise_params(key), @@ -116,11 +161,26 @@ class AbstractPosterior(AbstractGP): @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: - """Predict the GP's output given the input.""" + """Predict the GP's output given the input. + + Args: + *args (Any): Arguments to the predict method. + **kwargs (Any): Keyword arguments to the predict method. + + Returns: + dx.Distribution: A multivariate normal random variable representation of the Gaussian process. + """ raise NotImplementedError def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Initialise the parameter set of a GP posterior.""" + """Initialise the parameter set of a GP posterior. + + Args: + key (PRNGKeyType): The PRNG key. + + Returns: + Dict: The initialised parameter set. + """ return concat_dictionaries( self.prior._initialise_params(key), {"likelihood": self.likelihood._initialise_params(key)}, @@ -146,7 +206,7 @@ def predict( params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -202,7 +262,7 @@ def marginal_log_likelihood( negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. Returns: - Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. + Callable[[Dict], Float[Array, "1"]]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -241,7 +301,7 @@ class NonConjugatePosterior(AbstractPosterior): """Generic Gaussian process posterior object for models where the likelihood is non-Gaussian.""" prior: Prior - likelihood: NonConjugateLikelihoodType + likelihood: AbstractLikelihood name: Optional[str] = "Non-conjugate posterior" jitter: Optional[float] = DEFAULT_JITTER @@ -264,7 +324,7 @@ def predict( params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Float[Array, "N D"]], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. """ x, n = train_data.X, train_data.n @@ -309,7 +369,7 @@ def marginal_log_likelihood( negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. Returns: - Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. + Callable[[Dict], Float[Array, "1"]]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -351,3 +411,13 @@ def construct_posterior( f"No posterior implemented for {likelihood.name} likelihood" ) return PosteriorGP(prior=prior, likelihood=likelihood) + + +__all__ = [ + AbstractGP, + Prior, + AbstractPosterior, + ConjugatePosterior, + NonConjugatePosterior, + construct_posterior, +] diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 5636da84..a95deb17 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import abc from typing import Callable, Dict, List, Optional, Sequence @@ -26,7 +41,7 @@ def __post_init__(self): @abc.abstractmethod def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs. Args: @@ -106,7 +121,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return [kernel._initialise_params(key) for kernel in self.kernel_set] def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: return self.combination_fn( jnp.stack([k(x, y, p) for k, p in zip(self.kernel_set, params)]) @@ -142,7 +157,7 @@ def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` @@ -179,7 +194,7 @@ def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` @@ -215,7 +230,7 @@ def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma` @@ -257,7 +272,7 @@ def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma` @@ -301,7 +316,7 @@ def __post_init__(self): self.name = f"Polynomial Degree: {self.degree}" def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\alpha` and variance :math:`\sigma` through @@ -347,7 +362,7 @@ def __post_init__(self): self.num_vertex = self.laplacian.shape[0] def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict ) -> Float[Array, "1"]: """Evaluate the graph kernel on a pair of vertices v_i, v_j. @@ -394,7 +409,7 @@ def euclidean_distance( def gram( - kernel: Kernel, inputs: Float[Array, "N D"], params: dict + kernel: Kernel, inputs: Float[Array, "N D"], params: Dict ) -> Float[Array, "N N"]: """For a given kernel, compute the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`. @@ -410,7 +425,7 @@ def gram( def cross_covariance( - kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"], params: dict + kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"], params: Dict ) -> Float[Array, "N M"]: """For a given kernel, compute the :math:`m \times n` gram matrix on an a pair of input matrices with shape :math:`m \times d` and :math:`n \times d` for :math:`d\geq 1`. @@ -427,7 +442,7 @@ def cross_covariance( def diagonal( - kernel: Kernel, inputs: Float[Array, "N D"], params: dict + kernel: Kernel, inputs: Float[Array, "N D"], params: Dict ) -> Float[Array, "N N"]: """For a given kernel, compute the elementwise diagonal of the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`. Args: @@ -438,3 +453,21 @@ def diagonal( Array: The computed diagonal variance matrix. """ return jnp.diag(vmap(lambda x: kernel(x, x, params))(inputs)) + + +__all__ = [ + "Kernel", + "CombinationKernel", + "SumKernel", + "ProductKernel", + "RBF" "Matern12", + "Matern32", + "Matern52", + "Polynomial", + "GraphKernel", + "squared_distance", + "euclidean_distance", + "gram", + "cross_covariance", + "diagonal", +] diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 71ae16d0..a2bec329 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import abc from typing import Any, Callable, Dict, Optional @@ -15,38 +30,65 @@ class AbstractLikelihood: """Abstract base class for likelihoods.""" - num_datapoints: int # The number of datapoints that the likelihood factorises over + num_datapoints: int # The number of datapoints that the likelihood factorises over. name: Optional[str] = "Likelihood" - def __call__(self, *args: Any, **kwargs: Any) -> Any: - """Evaluate the likelihood function at a given predictive distribution.""" + def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + """Evaluate the likelihood function at a given predictive distribution. + + Args: + *args (Any): Arguments to be passed to the likelihood's `predict` method. + **kwargs (Any): Keyword arguments to be passed to the likelihood's `predict` method. + + Returns: + dx.Distribution: The predictive distribution. + """ return self.predict(*args, **kwargs) @abc.abstractmethod - def predict(self, *args: Any, **kwargs: Any) -> Any: - """Evaluate the likelihood function at a given predictive distribution.""" + def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: + """Evaluate the likelihood function at a given predictive distribution. + + Args: + *args (Any): Arguments to be passed to the likelihood's `predict` method. + **kwargs (Any): Keyword arguments to be passed to the likelihood's `predict` method. + + Returns: + dx.Distribution: The predictive distribution. + """ raise NotImplementedError @abc.abstractmethod def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Return the parameters of the likelihood function.""" + """Return the parameters of the likelihood function. + + Args: + key (PRNGKeyType): A PRNG key. + + Returns: + Dict: The parameters of the likelihood function. + """ raise NotImplementedError @property @abc.abstractmethod def link_function(self) -> Callable: - """Return the link function of the likelihood function.""" + """Return the link function of the likelihood function. + + Returns: + Callable: The link function of the likelihood function. + """ raise NotImplementedError @dataclass class Conjugate: - """Conjugate likelihood.""" + """An abstract class for conjugate likelihoods with respect to a Gaussain process prior.""" @dataclass class NonConjugate: - """Conjugate likelihood.""" + """An abstract class for non-conjugate likelihoods with respect to a Gaussain process prior.""" @dataclass @@ -56,7 +98,14 @@ class Gaussian(AbstractLikelihood, Conjugate): name: Optional[str] = "Gaussian" def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Return the variance parameter of the likelihood function.""" + """Return the variance parameter of the likelihood function. + + Args: + key (PRNGKeyType): A PRNG key. + + Returns: + Dict: The parameters of the likelihood function. + """ return {"obs_noise": jnp.array([1.0])} @property @@ -73,7 +122,15 @@ def link_fn(x, params: Dict) -> dx.Distribution: return link_fn def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: - """Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the observation noise term to the diagonal elements of the predictive distribution's covariance matrix..""" + """Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the observation noise term to the diagonal elements of the predictive distribution's covariance matrix. + + Args: + dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. + params (Dict): The parameters of the likelihood function. + + Returns: + dx.Distribution: The predictive distribution. + """ n_data = dist.event_shape[0] noisy_cov = dist.covariance() + I(n_data) * params["likelihood"]["obs_noise"] return dx.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) @@ -84,7 +141,14 @@ class Bernoulli(AbstractLikelihood, NonConjugate): name: Optional[str] = "Bernoulli" def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Initialise the parameter set of a Bernoulli likelihood.""" + """Initialise the parameter set of a Bernoulli likelihood. + + Args: + key (PRNGKeyType): A PRNG key. + + Returns: + Dict: The parameters of the likelihood function (empty for the Bernoulli likelihood). + """ return {} @property @@ -116,16 +180,39 @@ def moment_fn( return moment_fn - def predict(self, dist: dx.Distribution, params: Dict) -> Any: + def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: + """Evaluate the pointwise predictive distribution, given a Gaussian process posterior and likelihood parameters. + + Args: + dist (dx.Distribution): The Gaussian process posterior, evaluated at a finite set of test points. + params (Dict): The parameters of the likelihood function. + + Returns: + dx.Distribution: The pointwise predictive distribution. + """ variance = jnp.diag(dist.covariance()) mean = dist.mean() return self.predictive_moment_fn(mean.ravel(), variance, params) -def inv_probit(x): +def inv_probit(x: Float[Array, "N 1"]) -> Float[Array, "N 1"]: + """Compute the inverse probit function. + + Args: + x (Float[Array, "N 1"]): A vector of values. + + Returns: + Float[Array, "N 1"]: The inverse probit of the input vector. + """ jitter = 1e-3 # ensures output is strictly between 0 and 1 return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter -NonConjugateLikelihoods = [Bernoulli] -NonConjugateLikelihoodType = Bernoulli # Union[Bernoulli] +__all__ = [ + "AbstractLikelihood", + "Conjugate", + "NonConjugate", + "Gaussian", + "Bernoulli", + "inv_probit", +] diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 960bface..9ab21d14 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import abc from typing import Dict, Optional @@ -20,10 +35,10 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. This method is required for all subclasses. Args: - x (Array): The input points at which to evaluate the mean function. + x (Float[Array, "N D"]): The input points at which to evaluate the mean function. Returns: - Array: The mean function evaluated point-wise on the inputs. + Float[Array, "N Q"]: The mean function evaluated point-wise on the inputs. """ raise NotImplementedError @@ -31,6 +46,9 @@ def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the parameters of the mean function. This method is required for all subclasses. + Args: + key (PRNGKeyType): The PRNG key to use for initialising the parameters. + Returns: Dict: The parameters of the mean function. """ @@ -46,21 +64,28 @@ class Zero(AbstractMeanFunction): output_dim: Optional[int] = 1 name: Optional[str] = "Zero mean function" - def __call__(self, x: Float[Array, "N D"], params: dict) -> Float[Array, "N Q"]: + def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. Args: - x (Array): The input points at which to evaluate the mean function. + x (Float[Array, "N D"]): The input points at which to evaluate the mean function. params (Dict): The parameters of the mean function. Returns: - Array: A vector of zeros. + Float[Array, "N Q"]: A vector of zeros. """ out_shape = (x.shape[0], self.output_dim) return jnp.zeros(shape=out_shape) def _initialise_params(self, key: PRNGKeyType) -> Dict: - """The parameters of the mean function. For the zero-mean function, this is an empty dictionary.""" + """The parameters of the mean function. For the zero-mean function, this is an empty dictionary. + + Args: + key (PRNGKeyType): The PRNG key to use for initialising the parameters. + + Returns: + Dict: The parameters of the mean function. + """ return {} @@ -78,15 +103,29 @@ def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. Args: - x (Array): The input points at which to evaluate the mean function. + x (Float[Array, "N D"]): The input points at which to evaluate the mean function. params (Dict): The parameters of the mean function. Returns: - Array: A vector of repeated constant values. + Float[Array, "N Q"]: A vector of repeated constant values. """ out_shape = (x.shape[0], self.output_dim) return jnp.ones(shape=out_shape) * params["constant"] def _initialise_params(self, key: PRNGKeyType) -> Dict: - """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value.""" + """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value. + + Args: + key (PRNGKeyType): The PRNG key to use for initialising the parameters. + + Returns: + Dict: The parameters of the mean function. + """ return {"constant": jnp.array([1.0])} + + +__all__ = [ + "AbstractMeanFunction", + "Zero", + "Constant", +] diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index e3b79ce3..ba2bd3f1 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,15 +1,29 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from copy import deepcopy from typing import Callable, Dict, Tuple import jax.numpy as jnp -import jax.random as jr import jax.scipy as jsp from jax import value_and_grad from jaxtyping import Array, Float from .config import get_defaults from .gps import AbstractPosterior -from .parameters import build_trainables, constrain, trainable_params, unconstrain +from .parameters import build_trainables, constrain, trainable_params from .types import Dataset from .utils import I from .variational_families import ( @@ -41,6 +55,7 @@ def natural_to_expectation(params: Dict, jitter: float = DEFAULT_JITTER) -> Dict Args: params: A dictionary of variational Gaussian parameters under the natural parameterisation. jitter (float): A small value to prevent numerical instability. + Returns: Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ @@ -86,11 +101,13 @@ def _expectation_elbo( ) -> Callable[[Dict, Dataset], float]: """ Construct evidence lower bound (ELBO) for variational Gaussian under the expectation parameterisation. + Args: posterior: An instance of AbstractPosterior. variational_family: An instance of AbstractVariationalFamily. + Returns: - Callable: A function that computes the ELBO. + Callable[[Dict, Dataset], float]: A function that computes the ELBO. """ expectation_vartiational_gaussian = ExpectationVariationalGaussian( prior=variational_family.prior, @@ -104,7 +121,14 @@ def _expectation_elbo( def _rename_expectation_to_natural(params: Dict) -> Dict: - """This function renames the gradient components (that have expectation parameterisation keys) to match the natural parameterisation pytree.""" + """This function renames the gradient components (that have expectation parameterisation keys) to match the natural parameterisation pytree. + + Args: + params (Dict): A dictionary of variational Gaussian parameters under the expectation parameterisation moment names. + + Returns: + Dict: A dictionary of variational Gaussian parameters under the natural parameterisation moment names. + """ params["variational_family"]["moments"] = { "natural_vector": params["variational_family"]["moments"]["expectation_vector"], "natural_matrix": params["variational_family"]["moments"]["expectation_matrix"], @@ -114,7 +138,14 @@ def _rename_expectation_to_natural(params: Dict) -> Dict: def _rename_natural_to_expectation(params: Dict) -> Dict: - """This function renames the gradient components (that have natural parameterisation keys) to match the expectation parameterisation pytree.""" + """This function renames the gradient components (that have natural parameterisation keys) to match the expectation parameterisation pytree. + + Args: + params (Dict): A dictionary of variational Gaussian parameters under the natural parameterisation moment names. + + Returns: + Dict: A dictionary of variational Gaussian parameters under the expectation parameterisation moment names. + """ params["variational_family"]["moments"] = { "expectation_vector": params["variational_family"]["moments"]["natural_vector"], "expectation_matrix": params["variational_family"]["moments"]["natural_matrix"], @@ -123,6 +154,41 @@ def _rename_natural_to_expectation(params: Dict) -> Dict: return params +def _get_moment_trainables(trainables: Dict) -> Dict: + """This function takes a trainbles dictionary, and sets non-moment parameter training to false for gradient stopping. + + Args: + trainables (Dict): A dictionary of trainables. + + Returns: + Dict: A dictionary of trainables with non-moment parameters set to False. + """ + expectation_trainables = _rename_natural_to_expectation(deepcopy(trainables)) + moment_trainables = build_trainables(expectation_trainables, False) + moment_trainables["variational_family"]["moments"] = expectation_trainables[ + "variational_family" + ]["moments"] + + return moment_trainables + + +def _get_hyperparameter_trainables(trainables: Dict) -> Dict: + """This function takes a trainbles dictionary, and sets moment parameter training to false for gradient stopping. + + Args: + trainables (Dict): A dictionary of trainables. + + Returns: + Dict: A dictionary of trainables with moment parameters set to False. + """ + hyper_trainables = deepcopy(trainables) + hyper_trainables["variational_family"]["moments"] = build_trainables( + trainables["variational_family"]["moments"], False + ) + + return hyper_trainables + + def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, @@ -131,11 +197,13 @@ def natural_gradients( ) -> Tuple[Callable[[Dict, Dataset], Dict]]: """ Computes the gradient with respect to the natural parameters. Currently only implemented for the natural variational Gaussian family. + Args: posterior: An instance of AbstractPosterior. variational_family: An instance of AbstractVariationalFamily. train_data: A Dataset. bijectors: A dictionary of bijectors. + Returns: Tuple[Callable[[Dict, Dataset], Dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. """ @@ -148,28 +216,20 @@ def natural_gradients( # The ELBO under the expectation parameterisation, L(η). expectation_elbo = _expectation_elbo(posterior, variational_family, train_data) - # Stop nonment params: - expectation_trainables = _rename_natural_to_expectation(deepcopy(trainables)) - moment_trainables = build_trainables(expectation_trainables, False) - moment_trainables["variational_family"]["moments"] = expectation_trainables[ - "variational_family" - ]["moments"] - - # Stop moment params: - hyper_trainables = deepcopy(trainables) - hyper_trainables["variational_family"]["moments"] = build_trainables( - trainables["variational_family"]["moments"], False - ) + # Trainable dictionaries for alternating gradient updates. + moment_trainables = _get_moment_trainables(trainables) + hyper_trainables = _get_hyperparameter_trainables(trainables) if isinstance(variational_family, NaturalVariationalGaussian): def nat_grads_fn(params: Dict, batch: Dataset) -> Dict: """ Computes the natural gradients of the ELBO. + Args: - params: A dictionary of parameters. - trainables: A dictionary of trainables. - batch: A Dataset. + params (Dict): A dictionary of parameters. + batch (Dataset): A Dataset. + Returns: Dict: A dictionary of natural gradients. """ @@ -198,10 +258,11 @@ def loss_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: def hyper_grads_fn(params: Dict, batch: Dataset) -> Dict: """ Computes the hyperparameter gradients of the ELBO. + Args: - params: A dictionary of parameters. - trainables: A dictionary of trainables. - batch: A Dataset. + params (Dict): A dictionary of parameters. + batch (Dataset): A Dataset. + Returns: Dict: A dictionary of hyperparameter gradients. """ @@ -218,3 +279,9 @@ def loss_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: return value, dL_dhyper return nat_grads_fn, hyper_grads_fn + + +__all__ = [ + "natural_to_expectation", + "natural_gradients", +] diff --git a/gpjax/parameters.py b/gpjax/parameters.py index c6c8fb04..0462a56b 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import warnings from copy import deepcopy from typing import Dict, Tuple @@ -29,27 +44,43 @@ class ParameterState: bijectors: Dict def unpack(self): + """Unpack the state into a tuple of parameters, trainables and bijectors. + + Returns: + Tuple[Dict, Dict, Dict]: The parameters, trainables and bijectors. + """ return self.params, self.trainables, self.bijectors def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: - """Initialise the stateful parameters of any GPJax object. This function also returns the trainability status of each parameter and set of bijectors that allow parameters to be constrained and unconstrained.""" + """Initialise the stateful parameters of any GPJax object. This function also returns the trainability status of each parameter and set of bijectors that allow parameters to be constrained and unconstrained. + + Args: + model: The GPJax object that is to be initialised. + key (PRNGKeyType, optional): The random key that is to be used for initialisation. Defaults to None. + + Returns: + ParameterState: The state of the model. This includes the parameter set, which parameters are to be trained and bijectors that allow parameters to be constrained and unconstrained. + """ + if key is None: warn("No PRNGKey specified. Defaulting to seed 123.", UserWarning, stacklevel=2) key = jr.PRNGKey(123) params = model._initialise_params(key) + if kwargs: _validate_kwargs(kwargs, params) for k, v in kwargs.items(): params[k] = merge_dictionaries(params[k], v) + bijectors = build_bijectors(params) trainables = build_trainables(params) - state = ParameterState( + + return ParameterState( params=params, trainables=trainables, bijectors=bijectors, ) - return state def _validate_kwargs(kwargs, params): @@ -144,8 +175,7 @@ def constrain(params: Dict, bijectors: Dict) -> Dict: Args: params (Dict): The parameters that are to be transformed. - transform_map (Dict): The corresponding dictionary of transforms that should be applied to the parameter set. - foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + bijectors (Dict): The bijectors that are to be used for transformation. Returns: Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. @@ -161,8 +191,7 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: Args: params (Dict): The parameters that are to be transformed. - transform_map (Dict): The corresponding dictionary of transforms that should be applied to the parameter set. - foward (bool): Whether the parameters should be constrained (foward=True) or unconstrained (foward=False). + bijectors (Dict): The corresponding dictionary of transforms that should be applied to the parameter set. Returns: Dict: A transformed parameter set. The dictionary is equal in structure to the input params dictionary. @@ -172,14 +201,6 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: return jax.tree_util.tree_map(map, params, bijectors) - if transform_map is None: - return params - - else: - return jax.tree_util.tree_map( - lambda param, trans: trans(param), params, transform_map - ) - ################################ # Priors @@ -187,6 +208,15 @@ def unconstrain(params: Dict, bijectors: Dict) -> Dict: def log_density( param: Float[Array, "D"], density: dx.Distribution ) -> Float[Array, "1"]: + """Compute the log density of a parameter given a distribution. + + Args: + param (Float[Array, "D"]): The parameter that is to be evaluated. + density (dx.Distribution): The distribution that is to be evaluated. + + Returns: + Float[Array, "1"]: The log density of the parameter. + """ if type(density) == type(None): log_prob = jnp.array(0.0) else: @@ -195,6 +225,14 @@ def log_density( def copy_dict_structure(params: Dict) -> Dict: + """Copy the structure of a dictionary. + + Args: + params (Dict): The dictionary that is to be copied. + + Returns: + Dict: A copy of the input dictionary. + """ # Copy dictionary structure prior_container = deepcopy(params) # Set all values to zero @@ -228,7 +266,8 @@ def evaluate_priors(params: Dict, priors: Dict) -> Dict: estimates. priors (Dict): Dictionary specifying the parameters' prior distributions. - Returns: Array: The log-prior density, summed over all parameters. + Returns: + Dict: The log-prior density, summed over all parameters. """ lpd = jnp.array(0.0) if priors is not None: @@ -238,7 +277,14 @@ def evaluate_priors(params: Dict, priors: Dict) -> Dict: def prior_checks(priors: Dict) -> Dict: - """Run checks on th parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian.""" + """Run checks on the parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian. + + Args: + priors (Dict): Dictionary specifying the parameters' prior distributions. + + Returns: + Dict: Dictionary specifying the parameters' prior distributions. + """ if "latent" in priors.keys(): latent_prior = priors["latent"] if latent_prior is not None: @@ -274,13 +320,47 @@ def build_trainables(params: Dict, status: bool = True) -> Dict: return prior_container -def stop_grad(param: Dict, trainable: Dict): - """When taking a gradient, we want to stop the gradient from flowing through a parameter if it is not trainable. This is achieved using the model's dictionary of parameters and the corresponding trainability status.""" +def _stop_grad(param: Dict, trainable: Dict) -> Dict: + """When taking a gradient, we want to stop the gradient from flowing through a parameter if it is not trainable. This is achieved using the model's dictionary of parameters and the corresponding trainability status. + + Args: + param (Dict): The parameter set for which trainable statuses should be derived from. + trainable (Dict): A boolean value denoting the training status the `param`. + + Returns: + Dict: The gradient is stopped for non-trainable parameters. + """ return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) def trainable_params(params: Dict, trainables: Dict) -> Dict: - """Stop the gradients flowing through parameters whose trainable status is False""" + """Stop the gradients flowing through parameters whose trainable status is False. + + Args: + params (Dict): The parameter set for which trainable statuses should be derived from. + trainables (Dict): A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. + + Returns: + Dict: A dictionary parameters. The dictionary is equal in structure to the input params dictionary. + """ return jax.tree_util.tree_map( - lambda param, trainable: stop_grad(param, trainable), params, trainables + lambda param, trainable: _stop_grad(param, trainable), params, trainables ) + + +__all__ = [ + "ParameterState", + "initialise", + "recursive_items", + "recursive_complete", + "build_bijectors", + "constrain", + "unconstrain", + "log_density", + "copy_dict_structure", + "structure_priors", + "evaluate_priors", + "prior_checks", + "build_trainables", + "trainable_params", +] diff --git a/gpjax/quadrature.py b/gpjax/quadrature.py index ff1b2ee3..8312f6e5 100644 --- a/gpjax/quadrature.py +++ b/gpjax/quadrature.py @@ -1,4 +1,19 @@ -from typing import Callable +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Callable, Optional import jax.numpy as jnp import numpy as np @@ -12,23 +27,26 @@ def gauss_hermite_quadrature( fun: Callable, mean: Float[Array, "N D"], var: Float[Array, "N D"], - deg: int = DEFAULT_NUM_GAUSS_HERMITE_POINTS, + deg: Optional[int] = DEFAULT_NUM_GAUSS_HERMITE_POINTS, *args, **kwargs -) -> Float[Array, "D"]: +) -> Float[Array, "N"]: """Compute Gaussian-Hermite quadrature for a given function. The quadrature points are adjusted through the supplied mean and variance arrays. Args: fun (Callable): The function for which quadrature should be applied to. - mean (Array): The mean of the Gaussian distribution that is used to shift quadrature points. - var (Array): The variance of the Gaussian distribution that is used to scale quadrature points. + mean (Float[Array, "N D"]): The mean of the Gaussian distribution that is used to shift quadrature points. + var (Float[Array, "N D"]): The variance of the Gaussian distribution that is used to scale quadrature points. deg (int, optional): The number of quadrature points that are to be used. Defaults to 20. Returns: - Array: The evaluated integrals value. + Float[Array, "N"]: The evaluated integrals value. """ gh_points, gh_weights = np.polynomial.hermite.hermgauss(deg) stdev = jnp.sqrt(var) X = mean + jnp.sqrt(2.0) * stdev * gh_points W = gh_weights / jnp.sqrt(jnp.pi) return jnp.sum(fun(X, *args, **kwargs) * W, axis=1) + + +__all__ = ["gauss_hermite_quadrature"] diff --git a/gpjax/types.py b/gpjax/types.py index 9cae03ca..f97cab8b 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import jax.numpy as jnp from chex import dataclass from jaxtyping import Array, Float diff --git a/gpjax/utils.py b/gpjax/utils.py index 0b1c0055..d7120626 100644 --- a/gpjax/utils.py +++ b/gpjax/utils.py @@ -1,18 +1,34 @@ -from copy import deepcopy +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + from typing import Callable, Dict, Tuple import jax import jax.numpy as jnp -import jax.random as jr -from chex import PRNGKey from jaxtyping import Array, Float def I(n: int) -> Float[Array, "N N"]: """ Compute an n x n identity matrix. - :param n: The size of of the matrix. - :return: An n x n identity matrix. + + Args: + n (int): The size of of the matrix. + + Returns: + Float[Array, "N N"]: An n x n identity matrix. """ return jnp.eye(n) @@ -21,6 +37,13 @@ def concat_dictionaries(a: Dict, b: Dict) -> Dict: """ Append one dictionary below another. If duplicate keys exist, then the key-value pair of the second supplied dictionary will be used. + + Args: + a (Dict): The first dictionary. + b (Dict): The second dictionary. + + Returns: + Dict: The merged dictionary. """ return {**a, **b} @@ -32,11 +55,14 @@ def merge_dictionaries(base_dict: Dict, in_dict: Dict) -> Dict: the base_dict will be a complete dictionary of values such that an incomplete second dictionary can be used to update specific key-value pairs. - :param base_dict: Complete dictionary of key-value pairs. - :param in_dict: Subset of key-values pairs such that values from this dictionary will take precedent. - :return: A merged single dictionary. + Args: + base_dict (Dict): Complete dictionary of key-value pairs. + in_dict (Dict): Subset of key-values pairs such that values from this dictionary will take precedent. + + Returns: + Dict: A dictionary with the same keys as the base_dict, but with values from the in_dict. """ - for k, v in base_dict.items(): + for k, _ in base_dict.items(): if k in in_dict.keys(): base_dict[k] = in_dict[k] return base_dict @@ -46,21 +72,15 @@ def sort_dictionary(base_dict: Dict) -> Dict: """ Sort a dictionary based on the dictionary's key values. - :param base_dict: The unsorted dictionary. - :return: A dictionary sorted alphabetically on the dictionary's keys. + Args: + base_dict (Dict): The dictionary to be sorted. + + Returns: + Dict: The dictionary sorted alphabetically on the dictionary's keys. """ return dict(sorted(base_dict.items())) -def as_constant(parameter_set: Dict, params: list) -> Tuple[Dict, Dict]: - base_params = deepcopy(parameter_set) - sparams = {} - for param in params: - sparams[param] = base_params[param] - del base_params[param] - return base_params, sparams - - def dict_array_coercion(params: Dict) -> Tuple[Callable, Callable]: """Construct the logic required to map a dictionary of parameters to an array of parameters. The values of the dictionary can themselves be dictionaries; the function should work recursively. @@ -79,3 +99,12 @@ def array_to_dict(parameter_array) -> Dict: return jax.tree_util.tree_unflatten(flattened_pytree[1], parameter_array) return dict_to_array, array_to_dict + + +__all__ = [ + "I", + "concat_dictionaries", + "merge_dictionaries", + "sort_dictionary", + "dict_array_coercion", +] diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 24371e2a..5e149e31 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import abc from typing import Any, Callable, Dict, Optional @@ -21,18 +36,41 @@ class AbstractVariationalFamily: """Abstract base class used to represent families of distributions that can be used within variational inference.""" - def __call__(self, *args: Any, **kwargs: Any) -> Any: - """For a given set of parameters, compute the latent function's prediction under the variational approximation.""" + def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: + """For a given set of parameters, compute the latent function's prediction under the variational approximation. + + Args: + *args (Any): Arguments of the variational family's `predict` method. + **kwargs (Any): Keyword arguments of the variational family's `predict` method. + + Returns: + Any: The output of the variational family's `predict` method. + """ return self.predict(*args, **kwargs) @abc.abstractmethod def _initialise_params(self, key: PRNGKeyType) -> Dict: - """The parameters of the distribution. For example, the multivariate Gaussian would return a mean vector and covariance matrix.""" + """The parameters of the distribution. For example, the multivariate Gaussian would return a mean vector and covariance matrix. + + Args: + key (PRNGKeyType): The PRNG key used to initialise the parameters. + + Returns: + Dict: The parameters of the distribution. + """ raise NotImplementedError @abc.abstractmethod def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: - """Predict the GP's output given the input.""" + """Predict the GP's output given the input. + + Args: + *args (Any): Arguments of the variational family's `predict` method. + **kwargs (Any): Keyword arguments of the variational family's `predict` method. + + Returns: + Any: The output of the variational family's `predict` method. + """ raise NotImplementedError @@ -60,7 +98,14 @@ class VariationalGaussian(AbstractVariationalGaussian): """ def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" + """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution. + + Args: + key (PRNGKeyType): The PRNG key used to initialise the parameters. + + Returns: + Dict: The parameters of the distribution. + """ m = self.num_inducing return concat_dictionaries( @@ -102,7 +147,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as: @@ -190,7 +235,7 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -564,7 +609,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ) def predict( - self, train_data: Dataset, params: dict + self, train_data: Dataset, params: Dict ) -> Callable[[Float[Array, "N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: @@ -635,3 +680,14 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: ) return predict_fn + + +__all__ = [ + "AbstractVariationalFamily", + "AbstractVariationalGaussian", + "VariationalGaussian", + "WhitenedVariationalGaussian", + "NaturalVariationalGaussian", + "ExpectationVariationalGaussian", + "CollapsedVariationalGaussian", +] diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 00e3740c..16e1e604 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -1,3 +1,18 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + import abc from typing import Callable, Dict @@ -8,7 +23,7 @@ from jaxtyping import Array, Float from .gps import AbstractPosterior -from .kernels import cross_covariance, diagonal, gram +from .kernels import cross_covariance, gram from .likelihoods import Gaussian from .quadrature import gauss_hermite_quadrature from .types import Dataset @@ -139,7 +154,7 @@ def __post_init__(self): def elbo( self, train_data: Dataset, negative: bool = False - ) -> Callable[[dict], Float[Array, "1"]]: + ) -> Callable[[Dict], Float[Array, "1"]]: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. Args: @@ -226,3 +241,10 @@ def elbo_fn(params: Dict) -> Float[Array, "1"]: return constant * (two_log_prob - two_trace).squeeze() / 2.0 return elbo_fn + + +__all__ = [ + "AbstractVariationalInference", + "StochasticVI", + "CollapsedVI", +] diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index 50259852..c910f9ab 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -98,7 +98,7 @@ def test_natural_gradients(ndata, nb, n_iters): q = gpx.NaturalVariationalGaussian(prior=prior, inducing_inputs=z) svgp = gpx.StochasticVI(posterior=p, variational_family=q) - params, trainable_status, bijectors = initialise(svgp, key).unpack() + training_state = initialise(svgp, key) D = Dataset(X=x, y=y) @@ -108,9 +108,7 @@ def test_natural_gradients(ndata, nb, n_iters): key = jr.PRNGKey(42) inference_state = fit_natgrads( svgp, - params, - trainable_status, - bijectors, + training_state, D, moment_optimiser, hyper_optimiser, diff --git a/tests/test_gp.py b/tests/test_gp.py index 2ea2e780..5a940f53 100644 --- a/tests/test_gp.py +++ b/tests/test_gp.py @@ -14,9 +14,11 @@ construct_posterior, ) from gpjax.kernels import RBF, Matern12, Matern32, Matern52 -from gpjax.likelihoods import Bernoulli, Gaussian, NonConjugateLikelihoods +from gpjax.likelihoods import Bernoulli, Gaussian from gpjax.parameters import ParameterState +NonConjugateLikelihoods = [Bernoulli] + @pytest.mark.parametrize("num_datapoints", [1, 10]) def test_prior(num_datapoints): diff --git a/tests/test_utilities.py b/tests/test_utilities.py index 012720ba..bcd0d776 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -3,7 +3,6 @@ from gpjax.utils import ( I, - as_constant, concat_dictionaries, dict_array_coercion, merge_dictionaries, @@ -41,16 +40,6 @@ def test_sort_dict(): assert list(sorted_dict.values()) == [2, 1] -def test_as_constant(): - base = {"a": 1, "b": 2, "c": 3} - b1, s1 = as_constant(base, ["a"]) - b2, s2 = as_constant(base, ["a", "b"]) - assert list(b1.keys()) == ["b", "c"] - assert list(s1.keys()) == ["a"] - assert list(b2.keys()) == ["c"] - assert list(s2.keys()) == ["a", "b"] - - @pytest.mark.parametrize("d", [1, 2, 10]) def test_array_coercion(d): params = { From 566f654557fce6fc5e8a8e41a0af90974e81b4ad Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 20 Sep 2022 15:47:24 +0100 Subject: [PATCH 65/66] Complete rebase. --- gpjax/abstractions.py | 5 ----- tests/test_abstractions.py | 6 +++++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index b332090b..16747738 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -118,11 +118,6 @@ def fit( optax_optim: ox.GradientTransformation, n_iters: Optional[int] = 100, log_rate: Optional[int] = 10, - objective: tp.Callable, - parameter_state: ParameterState, - optax_optim, - n_iters: int = 100, - log_rate: int = 10, ) -> InferenceState: """Abstracted method for fitting a GP model with respect to a supplied objective function. Optimisers used here should originate from Optax. diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index b4899034..a0b7fabf 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -63,7 +63,7 @@ def test_batch_fitting(n_iters, nb, ndata): svgp = gpx.StochasticVI(posterior=p, variational_family=q) parameter_state = initialise(svgp, key) - objective = svgp.elbo(D) + objective = svgp.elbo(D, negative=True) pre_mll_val = objective(parameter_state.params, D) @@ -105,6 +105,10 @@ def test_natural_gradients(ndata, nb, n_iters): hyper_optimiser = optax.adam(learning_rate=0.1) moment_optimiser = optax.sgd(learning_rate=1.0) + objective = svgp.elbo(D, negative=True) + parameter_state = initialise(svgp, key) + pre_mll_val = objective(parameter_state.params, D) + key = jr.PRNGKey(42) inference_state = fit_natgrads( svgp, From 24032c02700db000b6c5e8ed88b40d7066f55bcd Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Tue, 20 Sep 2022 17:04:35 +0100 Subject: [PATCH 66/66] One Cholesky is better than two. --- examples/natgrads.ipynb | 6 ++--- gpjax/variational_families.py | 41 +++++++++++++++++++++-------------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/natgrads.ipynb b/examples/natgrads.ipynb index 35cdc7cb..b86fe975 100644 --- a/examples/natgrads.ipynb +++ b/examples/natgrads.ipynb @@ -87,7 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "z = jnp.linspace(-5.0, 5.0, 5000).reshape(-1, 1)\n", + "z = jnp.linspace(-5.0, 5.0, 20).reshape(-1, 1)\n", "\n", "fig, ax = plt.subplots(figsize=(12, 5))\n", "ax.plot(x, y, \"o\", alpha=0.3)\n", @@ -151,8 +151,8 @@ "inference_state = gpx.fit_natgrads(natural_svgp,\n", " parameter_state=parameter_state,\n", " train_data = D,\n", - " n_iters = 4000,\n", - " batch_size=128,\n", + " n_iters = 5000,\n", + " batch_size=100,\n", " key = jr.PRNGKey(42),\n", " moment_optim = ox.sgd(1.0),\n", " hyper_optim = ox.adam(1e-3),\n", diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 5e149e31..0a939fd8 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -337,16 +337,23 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: z = params["variational_family"]["inducing_inputs"] m = self.num_inducing + # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter - L_inv = jnp.linalg.cholesky(S_inv) - C = jsp.linalg.solve_triangular(L_inv, I(m), lower=True) - S = jnp.matmul(C.T, C) - mu = jnp.matmul(S, natural_vector) + # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril: + sqrt_inv = jnp.swapaxes( + jnp.linalg.cholesky(S_inv[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1 + ) - S += I(m) * self.jitter - sqrt = jnp.linalg.cholesky(S) + # L = (L⁻¹)⁻¹I + sqrt = jsp.linalg.solve_triangular(sqrt_inv, I(m), lower=True) + + # S = LLᵀ: + S = jnp.matmul(sqrt, sqrt.T) + + # μ = Sθ₁ + mu = jnp.matmul(S, natural_vector) μz = self.prior.mean_function(z, params["mean_function"]) Kzz = gram(self.prior.kernel, z, params["kernel"]) @@ -382,14 +389,16 @@ def predict(self, params: Dict) -> Callable[[Float[Array, "N D"]], dx.Distributi S_inv = -2 * natural_matrix S_inv += I(m) * self.jitter - # S⁻¹ = LLᵀ - L = jnp.linalg.cholesky(S_inv) + # Compute L⁻¹, where LLᵀ = S, via a trick found in the NumPyro source code and https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril: + sqrt_inv = jnp.swapaxes( + jnp.linalg.cholesky(S_inv[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1 + ) - # C = L⁻¹I - C = jsp.linalg.solve_triangular(L, I(m), lower=True) + # L = (L⁻¹)⁻¹I + sqrt = jsp.linalg.solve_triangular(sqrt_inv, I(m), lower=True) - # S = CᵀC - S = jnp.matmul(C.T, C) + # S = LLᵀ: + S = jnp.matmul(sqrt, sqrt.T) # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) @@ -411,17 +420,17 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: # Kzz⁻¹ Kzt Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False) - # Ktz Kzz⁻¹ Cᵀ - Ktz_Kzz_inv_CT = jnp.matmul(Kzz_inv_Kzt.T, C.T) + # Ktz Kzz⁻¹ L + Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt) # μt + Ktz Kzz⁻¹ (μ - μz) mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = CᵀC] + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = LLᵀ] covariance = ( Ktt - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) - + jnp.matmul(Ktz_Kzz_inv_CT, Ktz_Kzz_inv_CT.T) + + jnp.matmul(Ktz_Kzz_inv_L, Ktz_Kzz_inv_L.T) ) return dx.MultivariateNormalFullCovariance(