From 9ee096a36772dd5a8f2de020841a9603db439bb3 Mon Sep 17 00:00:00 2001
From: Daniel Dodd
Date: Sun, 6 Nov 2022 18:36:43 +0000
Subject: [PATCH] Improve readability and add comments.

---
 gpjax/abstractions.py | 220 ++++++++++++++++++++-------------
 gpjax/config.py | 4 +-
 gpjax/gps.py | 183 +++++++++++++++++----------
 gpjax/kernels.py | 22 ++--
 gpjax/mean_functions.py | 3 +-
 gpjax/variational_families.py | 200 ++++++++++++++++++----------
 gpjax/variational_inference.py | 57 +++++----
 7 files changed, 429 insertions(+), 260 deletions(-)

diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py
index cf50d06c..d8d610fe 100644
--- a/gpjax/abstractions.py
+++ b/gpjax/abstractions.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Callable, Dict, Optional
+from typing import Callable, Dict, Optional, Tuple, Any, Union
 
 import jax
 import jax.numpy as jnp
@@ -33,83 +33,18 @@
 
 @dataclass(frozen=True)
 class InferenceState:
+    """Immutable dataclass for storing optimised parameters and training history."""
+
     params: Dict
     history: Float[Array, "n_iters"]
 
-    def unpack(self):
-        return self.params, self.history
-
-
-def progress_bar_scan(n_iters: int, log_rate: int):
-    """Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)."""
-
-    tqdm_bars = {}
-    remainder = n_iters % log_rate
-
-    def _define_tqdm(args, transform):
-        tqdm_bars[0] = tqdm(range(n_iters))
-
-    def _update_tqdm(args, transform):
-        loss_val, arg = args
-        tqdm_bars[0].update(arg)
-        tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"})
-
-    def _update_progress_bar(loss_val, i):
-        """Updates tqdm progress bar of a JAX scan or loop."""
-        _ = lax.cond(
-            i == 0,
-            lambda _: host_callback.id_tap(_define_tqdm, None, result=i),
-            lambda _: i,
-            operand=None,
-        )
-
-        _ = lax.cond(
-            # update tqdm every multiple of `print_rate` except at the end
-            (i % log_rate == 0) & (i != n_iters - remainder),
-            lambda _: host_callback.id_tap(
-                _update_tqdm, (loss_val, log_rate), result=i
-            ),
-            lambda _: i,
-            operand=None,
-        )
-
-        _ = lax.cond(
-            # update tqdm by `remainder`
-            i == n_iters - remainder,
-            lambda _: host_callback.id_tap(
-                _update_tqdm, (loss_val, remainder), result=i
-            ),
-            lambda _: i,
-            operand=None,
-        )
-
-    def _close_tqdm(args, transform):
-        tqdm_bars[0].close()
-
-    def close_tqdm(result, i):
-        return lax.cond(
-            i == n_iters - 1,
-            lambda _: host_callback.id_tap(_close_tqdm, None, result=result),
-            lambda _: result,
-            operand=None,
-        )
-
-    def _progress_bar_scan(func):
-        """Decorator that adds a progress bar to `body_fun` used in `lax.scan`."""
-
-        def wrapper_progress_bar(carry, x):
-            if type(x) is tuple:
-                iter_num, *_ = x
-            else:
-                iter_num = x
-            result = func(carry, x)
-            *_, loss_val = result
-            _update_progress_bar(loss_val, iter_num)
-            return close_tqdm(result, iter_num)
-
-        return wrapper_progress_bar
+    def unpack(self) -> Tuple[Dict, Float[Array, "n_iters"]]:
+        """Unpack parameters and training history into a tuple.
 
-    return _progress_bar_scan
+        Returns:
+            Tuple[Dict, Float[Array, "n_iters"]]: Tuple of parameters and training history.
+        """
+        return self.params, self.history
 
 
 def fit(
@@ -137,18 +72,23 @@ def fit(
 
     params, trainables, bijectors = parameter_state.unpack()
 
-    def loss(params):
+    # Define the optimisation loss function on the unconstrained space, with a stop-gradient rule for trainables that are set to False
+    def loss(params: Dict) -> Float[Array, "1"]:
         params = trainable_params(params, trainables)
         params = constrain(params, bijectors)
         return objective(params)
 
-    iter_nums = jnp.arange(n_iters)
-
-    # Tranform params to unconstrained space:
+    # Transform params to unconstrained space
     params = unconstrain(params, bijectors)
+
+    # Initialise optimiser state
     opt_state = optax_optim.init(params)
 
-    def step(carry, iter_num):
+    # Iteration loop numbers to scan over
+    iter_nums = jnp.arange(n_iters)
+
+    # Optimisation step
+    def step(carry, iter_num: int):
         params, opt_state = carry
         loss_val, loss_gradient = jax.value_and_grad(loss)(params)
         updates, opt_state = optax_optim.update(loss_gradient, opt_state, params)
@@ -156,17 +96,17 @@ def step(carry, iter_num):
         carry = params, opt_state
         return carry, loss_val
 
+    # Display progress bar if verbose is True
     if verbose:
        step = progress_bar_scan(n_iters, log_rate)(step)
 
+    # Run the optimisation loop
     (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums)
 
-    # Tranform params to constrained space:
+    # Transform final params to constrained space
     params = constrain(params, bijectors)
 
-    inf_state = InferenceState(params=params, history=history)
-
-    return inf_state
+    return InferenceState(params=params, history=history)
 
 
 def fit_batches(
@@ -200,17 +140,23 @@ def fit_batches(
 
     params, trainables, bijectors = parameter_state.unpack()
 
-    def loss(params, batch):
+    # Define the optimisation loss function on the unconstrained space, with a stop-gradient rule for trainables that are set to False
+    def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]:
         params = trainable_params(params, trainables)
         params = constrain(params, bijectors)
         return objective(params, batch)
 
+    # Transform params to unconstrained space
     params = unconstrain(params, bijectors)
 
+    # Initialise optimiser state
     opt_state = optax_optim.init(params)
+
+    # Mini-batch random keys and iteration loop numbers to scan over
     keys = jr.split(key, n_iters)
     iter_nums = jnp.arange(n_iters)
 
+    # Optimisation step
     def step(carry, iter_num__and__key):
         iter_num, key = iter_num__and__key
         params, opt_state = carry
@@ -224,19 +170,21 @@ def step(carry, iter_num__and__key):
         carry = params, opt_state
         return carry, loss_val
 
+    # Display progress bar if verbose is True
     if verbose:
         step = progress_bar_scan(n_iters, log_rate)(step)
 
+    # Run the optimisation loop
     (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys))
 
+    # Transform final params to constrained space
     params = constrain(params, bijectors)
 
-    inf_state = InferenceState(params=params, history=history)
-    return inf_state
+    return InferenceState(params=params, history=history)
 
 
 def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset:
-    """Batch the data into mini-batches.
+    """Batch the data into mini-batches. Sampling is done with replacement.
 
     Args:
         train_data (Dataset): The training dataset.
@@ -247,6 +195,7 @@ def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset
     """
     x, y, n = train_data.X, train_data.y, train_data.n
 
+    # Subsample data indices with replacement to get the mini-batch
     indicies = jr.choice(key, n, (batch_size,), replace=True)
 
     return Dataset(X=x[indicies], y=y[indicies])
@@ -285,18 +234,23 @@ def fit_natgrads(
 
     params, trainables, bijectors = parameter_state.unpack()
 
+    # Transform params to unconstrained space
     params = unconstrain(params, bijectors)
 
+    # Initialise optimiser states
     hyper_state = hyper_optim.init(params)
     moment_state = moment_optim.init(params)
 
+    # Build natural and hyperparameter gradient functions
     nat_grads_fn, hyper_grads_fn = natural_gradients(
         stochastic_vi, train_data, bijectors, trainables
     )
 
+    # Mini-batch random keys and iteration loop numbers to scan over
     keys = jax.random.split(key, n_iters)
     iter_nums = jnp.arange(n_iters)
 
+    # Optimisation step
     def step(carry, iter_num__and__key):
         iter_num, key = iter_num__and__key
         params, hyper_state, moment_state = carry
@@ -316,15 +270,103 @@ def step(carry, iter_num__and__key):
         carry = params, hyper_state, moment_state
         return carry, loss_val
 
+    # Display progress bar if verbose is True
     if verbose:
         step = progress_bar_scan(n_iters, log_rate)(step)
 
+    # Run the optimisation loop
     (params, _, _), history = jax.lax.scan(
         step, (params, hyper_state, moment_state), (iter_nums, keys)
     )
+
+    # Transform final params to constrained space
     params = constrain(params, bijectors)
-    inf_state = InferenceState(params=params, history=history)
-    return inf_state
+
+    return InferenceState(params=params, history=history)
+
+
+def progress_bar_scan(n_iters: int, log_rate: int) -> Callable:
+    """Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)."""
+
+    tqdm_bars = {}
+    remainder = n_iters % log_rate
+
+    def _define_tqdm(args: Any, transform: Any) -> None:
+        """Define a tqdm progress bar."""
+        tqdm_bars[0] = tqdm(range(n_iters))
+
+    def _update_tqdm(args: Any, transform: Any) -> None:
+        """Update the tqdm progress bar with the latest objective value."""
+        loss_val, arg = args
+        tqdm_bars[0].update(arg)
+        tqdm_bars[0].set_postfix({"Objective": f"{loss_val: .2f}"})
+
+    def _close_tqdm(args: Any, transform: Any) -> None:
+        """Close the tqdm progress bar."""
+        tqdm_bars[0].close()
+
+    def _callback(cond: bool, func: Callable, arg: Any) -> None:
+        """Call a function on a given argument if a condition is true."""
+        dummy_result = 0
+
+        def _do_callback(_) -> int:
+            """Perform the callback."""
+            return host_callback.id_tap(func, arg, result=dummy_result)
+
+        def _not_callback(_) -> int:
+            """Do nothing."""
+            return dummy_result
+
+        _ = lax.cond(cond, _do_callback, _not_callback, operand=None)
+
+
+    def _update_progress_bar(loss_val: Float[Array, "1"], iter_num: int) -> None:
+        """Updates tqdm progress bar of a JAX scan or loop."""
+
+        # Conditions for iteration number
+        is_first: bool = iter_num == 0
+        is_multiple: bool = (iter_num % log_rate == 0) & (iter_num != n_iters - remainder)
+        is_remainder: bool = iter_num == n_iters - remainder
+        is_last: bool = iter_num == n_iters - 1
+
+        # Define progress bar, if first iteration
+        _callback(is_first, _define_tqdm, None)
+
+        # Update progress bar, if multiple of log_rate
+        _callback(is_multiple, _update_tqdm, (loss_val, log_rate))
+
+        # Update progress bar, if remainder
+        _callback(is_remainder, _update_tqdm, (loss_val, remainder))
+
+        # Close progress bar, if last iteration
+        _callback(is_last, 
_close_tqdm, None) + + + def _progress_bar_scan(body_fun: Callable) -> Callable: + """Decorator that adds a progress bar to `body_fun` used in `lax.scan`.""" + + def wrapper_progress_bar(carry: Any, x: Union[tuple, int]) -> Any: + + # Get iteration number + if type(x) is tuple: + iter_num, *_ = x + else: + iter_num = x + + # Compute iteration step + result = body_fun(carry, x) + + # Get loss value + *_, loss_val = result + + # Update progress bar + _update_progress_bar(loss_val, iter_num) + + return result + + return wrapper_progress_bar + + return _progress_bar_scan __all__ = [ diff --git a/gpjax/config.py b/gpjax/config.py index ea6ea560..e8bb528a 100644 --- a/gpjax/config.py +++ b/gpjax/config.py @@ -21,7 +21,7 @@ __config = None -FillTriangular = dx.Chain([tfb.FillTriangular()]) # TODO: Dan to chain methods. +FillTriangular = dx.Chain([tfb.FillTriangular(), ]) # TODO: Dan to chain methods. Identity = dx.Lambda(forward=lambda x: x, inverse=lambda x: x) Softplus = dx.Lambda( forward=lambda x: jnp.log(1 + jnp.exp(x)), @@ -75,7 +75,7 @@ def add_parameter(param_name: str, bijection: dx.Bijector) -> None: Args: param_name (str): The name of the parameter that is to be added. - bijection (tfb.Bijector): The bijection that should be used to unconstrain the parameter's value. + bijection (dx.Bijector): The bijection that should be used to unconstrain the parameter's value. """ lookup_name = f"{param_name}_transform" get_defaults() diff --git a/gpjax/gps.py b/gpjax/gps.py index c6ca7f39..0fb135ac 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -203,15 +203,24 @@ def predict( should be evaluated at. The mean function's value at these points is then returned. """ - gram = self.kernel.gram jitter = get_defaults()["jitter"] + # Unpack mean function and kernel + mean_function = self.mean_function + kernel = self.kernel + + # Unpack kernel computation + gram = kernel.gram + def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: - t = test_inputs - n_test = t.shape[0] - μt = self.mean_function(t, params["mean_function"]) - Ktt = gram(self.kernel, t, params["kernel"]) - Ktt += jitter * I(n_test) + + # Unpack test inputs + t = test_inputs + n_test = test_inputs.shape[0] + + μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, t, params["kernel"]) + Ktt += I(n_test) * jitter Lt = Ktt.triangular_lower() return dx.MultivariateNormalTri(jnp.atleast_1d(μt.squeeze()), Lt) @@ -375,29 +384,37 @@ def predict( """ jitter = get_defaults()["jitter"] + # Unpack training data x, y, n = train_data.X, train_data.y, train_data.n - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + cross_covariance = kernel.cross_covariance # Observation noise σ² obs_noise = params["likelihood"]["obs_noise"] - μx = self.prior.mean_function(x, params["mean_function"]) + μx = mean_function(x, params["mean_function"]) - # Precompute covariance matrices - Kxx = gram(self.prior.kernel, x, params["kernel"]) + # Precompute Gram matrix, Kxx, at training inputs, x + Kxx = gram(kernel, x, params["kernel"]) Kxx += I(n) * jitter # Σ = Kxx + Iσ² Sigma = Kxx + I(n) * obs_noise def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution: - t = test_inputs - n_test = t.shape[0] - μt = self.prior.mean_function(t, params["mean_function"]) - Ktt = gram(self.prior.kernel, t, params["kernel"]) - Kxt = 
cross_covariance(self.prior.kernel, x, t, params["kernel"]) + + # Unpack test inputs + t = test_inputs + n_test = test_inputs.shape[0] + + μt = mean_function(t, params["mean_function"]) + Ktt = gram(kernel, t, params["kernel"]) + Kxt = cross_covariance(kernel, x, t, params["kernel"]) # TODO: Investigate lower triangular solves for general covariance operators # this is more efficient than the full solve for dense matrices in the current implimentation. @@ -480,25 +497,33 @@ def marginal_log_likelihood( Callable[[Dict], Float[Array, "1"]]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ jitter = get_defaults()["jitter"] + + # Unpack training data x, y, n = train_data.X, train_data.y, train_data.n - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + + # The sign of the marginal log-likelihood depends on whether we are maximising or minimising + constant = jnp.array(-1.0) if negative else jnp.array(1.0) def mll( params: Dict, ): # Observation noise σ² obs_noise = params["likelihood"]["obs_noise"] - μx = self.prior.mean_function(x, params["mean_function"]) - Kxx = gram(self.prior.kernel, x, params["kernel"]) - Kxx += I(n) * jitter + μx = mean_function(x, params["mean_function"]) # TODO: This implementation does not take advantage of the covariance operator structure. # Future work concerns implementation of a custom Gaussian distribution / measure object that accepts a covariance operator. # Σ = (Kxx + Iσ²) = LLᵀ + Kxx = gram(kernel, x, params["kernel"]) + Kxx += I(n) * jitter Sigma = Kxx + I(n) * obs_noise L = Sigma.triangular_lower() @@ -511,7 +536,6 @@ def mll( # log p(θ) log_prior_density = evaluate_priors(params, priors) - constant = jnp.array(-1.0) if negative else jnp.array(1.0) return constant * ( marginal_likelihood.log_prob(jnp.atleast_1d(y.squeeze())).squeeze() + log_prior_density @@ -543,7 +567,14 @@ class NonConjugatePosterior(AbstractPosterior): name: Optional[str] = "Non-conjugate posterior" def _initialise_params(self, key: PRNGKeyType) -> Dict: - """Initialise the parameter set of a non-conjugate GP posterior.""" + """Initialise the parameter set of a non-conjugate GP posterior. + + Args: + key (PRNGKeyType): A PRNG key used to initialise the parameters. + + Returns: + Dict: A dictionary containing the default parameter set. + """ parameters = concat_dictionaries( self.prior._initialise_params(key), {"likelihood": self.likelihood._initialise_params(key)}, @@ -570,37 +601,48 @@ def predict( tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `dx.Distribution`. 
""" jitter = get_defaults()["jitter"] + + # Unpack training data x, n = train_data.X, train_data.n - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) - Kxx = gram(self.prior.kernel, x, params["kernel"]) + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + cross_covariance = kernel.cross_covariance + + # Precompute lower triangular of Gram matrix, Lx, at training inputs, x + Kxx = gram(kernel, x, params["kernel"]) Kxx += I(n) * jitter + Lx = Kxx.triangular_lower() def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution: - t = test_inputs - n_test = t.shape[0] - Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"]) - Ktt = gram(self.prior.kernel, t, params["kernel"]) + I(n_test) * jitter - μt = self.prior.mean_function(t, params["mean_function"]) - Lx = Kxx.triangular_lower() + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] + + # Compute terms of the posterior predictive distribution + Ktx = cross_covariance(kernel, t, x, params["kernel"]) + Ktt = gram(kernel, t, params["kernel"]) + I(n_test) * jitter + μt = mean_function(t, params["mean_function"]) # Lx⁻¹ Kxt Lx_inv_Kxt = jsp.linalg.solve_triangular(Lx, Ktx.T, lower=True) - # μt + Ktx Lx⁻¹ latent - mean = μt + jnp.matmul(Lx_inv_Kxt.T, params["latent"]) + # Whitened function values, wx, correponding to the inputs, x + wx = params["latent"] + + # μt + Ktx Lx⁻¹ wx + mean = μt + jnp.matmul(Lx_inv_Kxt.T, wx) # Ktt - Ktx Kxx⁻¹ Kxt - covariance = Ktt + covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) covariance += I(n_test) * jitter - covariance = covariance.to_dense() - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt) return dx.MultivariateNormalFullCovariance( - jnp.atleast_1d(mean.squeeze()), covariance + jnp.atleast_1d(mean.squeeze()), covariance.to_dense() ) return predict_fn @@ -624,28 +666,50 @@ def marginal_log_likelihood( Callable[[Dict], Float[Array, "1"]]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. 
""" jitter = get_defaults()["jitter"] + + # Unpack dataset x, y, n = train_data.X, train_data.y, train_data.n - gram = self.prior.kernel.gram + + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + + # Link function of the likelihood + link_function = self.likelihood.link_function + + # We induce whitened prior on the latent function if not priors: priors = copy_dict_structure(self._initialise_params(jr.PRNGKey(0))) priors["latent"] = dx.Normal(loc=0.0, scale=1.0) + # The sign of the marginal log-likelihood depends on whether we are maximising or minimising + constant = jnp.array(-1.0) if negative else jnp.array(1.0) + def mll(params: Dict): - Kxx = gram(self.prior.kernel, x, params["kernel"]) + + # Compute lower triangular of the kernel Gram matrix + Kxx = gram(kernel, x, params["kernel"]) Kxx += I(n) * jitter Lx = Kxx.triangular_lower() - μx = self.prior.mean_function(x, params["mean_function"]) - # f(x) = μx + Lx latent - fx = μx + jnp.matmul(Lx, params["latent"]) + # Compute the prior mean function + μx = mean_function(x, params["mean_function"]) - # p(y | f(x), θ), where θ are the model hyperparameters: - likelihood = self.likelihood.link_function(fx, params) + # Whitened function values, wx, correponding to the inputs, x + wx = params["latent"] + + # f(x) = μx + Lx wx + fx = μx + jnp.matmul(Lx, wx) + + # p(y | f(x), θ), where θ are the model hyperparameters + likelihood = link_function(fx, params) # log p(θ) log_prior_density = evaluate_priors(params, priors) - constant = jnp.array(-1.0) if negative else jnp.array(1.0) return constant * (likelihood.log_prob(y).sum() + log_prior_density) return mll @@ -670,26 +734,13 @@ def construct_posterior( elif isinstance(likelihood, NonConjugate): PosteriorGP = NonConjugatePosterior + else: raise NotImplementedError( f"No posterior implemented for {likelihood.name} likelihood" ) - return PosteriorGP(prior=prior, likelihood=likelihood) - - -def euclidean_distance( - x: Float[Array, "N D"], y: Float[Array, "N D"] -) -> Float[Array, "N"]: - """Compute the Euclidean distance between two arrays of points. - Args: - x (Float[Array, "N D"]): An array of points. - y (Float[Array, "N D"]): An array of points. - - Returns: - Float[Array, "N"]: An array of distances. - """ - return jnp.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1) + return PosteriorGP(prior=prior, likelihood=likelihood) __all__ = [ diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 349d434c..70e4a384 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -307,11 +307,11 @@ def __call__( Args: x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. params (Dict): Parameter set for which the kernel should be evaluated on. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)`. """ x = self.slice_input(x) / params["lengthscale"] y = self.slice_input(y) / params["lengthscale"] @@ -385,11 +385,11 @@ def __call__( Args: x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. params (Dict): Parameter set for which the kernel should be evaluated on. 
Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)`. """ x = self.slice_input(x) / params["lengthscale"] y = self.slice_input(y) / params["lengthscale"] @@ -427,11 +427,11 @@ def __call__( Args: x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. params (Dict): Parameter set for which the kernel should be evaluated on. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)`. """ x = self.slice_input(x) / params["lengthscale"] y = self.slice_input(y) / params["lengthscale"] @@ -475,7 +475,7 @@ def __call__( params (Dict): Parameter set for which the kernel should be evaluated on. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)`. """ x = self.slice_input(x).squeeze() y = self.slice_input(y).squeeze() @@ -504,11 +504,11 @@ def __call__( Args: x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. params (Dict): Parameter set for which the kernel should be evaluated on. Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` + Float[Array, "1"]: The value of :math:`k(x, y)`. """ K = jnp.all(jnp.equal(x, y)) * params["variance"] return K.squeeze() @@ -549,8 +549,8 @@ def __call__( """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. Args: - x (Float[Array, "1 D"]): Index of the ith vertex - y (Float[Array, "1 D"]): Index of the jth vertex + x (Float[Array, "1 D"]): Index of the ith vertex. + y (Float[Array, "1 D"]): Index of the jth vertex. params (Dict): Parameter set for which the kernel should be evaluated on. Returns: diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 9ab21d14..a6a58072 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -31,11 +31,12 @@ class AbstractMeanFunction: name: Optional[str] = "Mean function" @abc.abstractmethod - def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: + def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. This method is required for all subclasses. Args: x (Float[Array, "N D"]): The input points at which to evaluate the mean function. + params (Dict): The parameters of the mean function. Returns: Float[Array, "N Q"]: The mean function evaluated point-wise on the inputs. diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 62ffce34..29081f02 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -133,18 +133,28 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. 
""" - gram = self.prior.kernel.gram + + # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] - m = self.num_inducing z = params["variational_family"]["inducing_inputs"] - μz = self.prior.mean_function(z, params["mean_function"]) - Kzz = gram(self.prior.kernel, z, params["kernel"]) + m = self.num_inducing + + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + + μz = mean_function(z, params["mean_function"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) pu = dx.MultivariateNormalTri(jnp.atleast_1d(μz.squeeze()), Lz) + return kld_dense_dense(qu, pu) def predict( @@ -162,29 +172,36 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. """ + + # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel - Kzz = gram(self.prior.kernel, z, params["kernel"]) + # Unpack kernel computation + gram = kernel.gram + cross_covariance = kernel.cross_covariance + + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() - μz = self.prior.mean_function(z, params["mean_function"]) + μz = mean_function(z, params["mean_function"]) def predict_fn( test_inputs: Float[Array, "N D"] ) -> dx.MultivariateNormalFullCovariance: - t = test_inputs - n_test = t.shape[0] - Ktt = gram(self.prior.kernel, t, params["kernel"]) - Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) - μt = self.prior.mean_function(t, params["mean_function"]) + + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] + + Ktt = gram(kernel, t, params["kernel"]) + Kzt = cross_covariance(kernel, z, t, params["kernel"]) + μt = mean_function(t, params["mean_function"]) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -198,7 +215,7 @@ def predict_fn( # μt + Ktz Kzz⁻¹ (μ - μz) mean = μt + jnp.matmul(Kzz_inv_Kzt.T, mu - μz) - # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] + # Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt [recall S = sqrt sqrtᵀ] covariance = ( Ktt - jnp.matmul(Lz_inv_Kzt.T, Lz_inv_Kzt) @@ -235,9 +252,12 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. 
""" + + # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] + # Compute whitened KL divergence qu = dx.MultivariateNormalTri(jnp.atleast_1d(mu.squeeze()), sqrt) return kld_dense_white(qu) @@ -256,28 +276,35 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. """ + + # Unpack variational parameters mu = params["variational_family"]["moments"]["variational_mean"] sqrt = params["variational_family"]["moments"]["variational_root_covariance"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + cross_covariance = kernel.cross_covariance - Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() def predict_fn( test_inputs: Float[Array, "N D"] ) -> dx.MultivariateNormalFullCovariance: - t = test_inputs - n_test = t.shape[0] - Ktt = gram(self.prior.kernel, t, params["kernel"]) - Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) - μt = self.prior.mean_function(t, params["mean_function"]) + + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] + + Ktt = gram(kernel, t, params["kernel"]) + Kzt = cross_covariance(kernel, z, t, params["kernel"]) + μt = mean_function(t, params["mean_function"]) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -295,6 +322,7 @@ def predict_fn( + jnp.matmul(Ktz_Lz_invT_sqrt, Ktz_Lz_invT_sqrt.T) ) covariance += I(n_test) * self.jitter + return dx.MultivariateNormalFullCovariance( jnp.atleast_1d(mean.squeeze()), covariance.to_dense() ) @@ -346,11 +374,19 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ + + # Unpack variational parameters natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - gram = self.prior.kernel.gram + + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix @@ -370,8 +406,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) - μz = self.prior.mean_function(z, params["mean_function"]) - Kzz = gram(self.prior.kernel, z, params["kernel"]) + μz = mean_function(z, params["mean_function"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() @@ -397,15 +433,20 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" + + # Unpack variational parameters natural_vector = params["variational_family"]["moments"]["natural_vector"] natural_matrix = params["variational_family"]["moments"]["natural_matrix"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram = kernel.gram + cross_covariance = kernel.cross_covariance # S⁻¹ = -2θ₂ S_inv = -2 * natural_matrix @@ -425,17 +466,19 @@ def predict( # μ = Sθ₁ mu = jnp.matmul(S, natural_vector) - Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() - μz = self.prior.mean_function(z, params["mean_function"]) + μz = mean_function(z, params["mean_function"]) def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: - t = test_inputs - n_test = t.shape[0] - Ktt = gram(self.prior.kernel, t, params["kernel"]) - Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) - μt = self.prior.mean_function(t, params["mean_function"]) + + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] + + Ktt = gram(kernel, t, params["kernel"]) + Kzt = cross_covariance(kernel, z, t, params["kernel"]) + μt = mean_function(t, params["mean_function"]) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -510,6 +553,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: Returns: Float[Array, "1"]: The KL-divergence between our variational approximation and the GP prior. """ + + # Unpack variational parameters expectation_vector = params["variational_family"]["moments"][ "expectation_vector" ] @@ -518,7 +563,13 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: ] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - gram = self.prior.kernel.gram + + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel compuation + gram = kernel.gram # μ = η₁ mu = expectation_vector @@ -530,8 +581,8 @@ def prior_kl(self, params: Dict) -> Float[Array, "1"]: # S = sqrt sqrtᵀ sqrt = jnp.linalg.cholesky(S) - μz = self.prior.mean_function(z, params["mean_function"]) - Kzz = gram(self.prior.kernel, z, params["kernel"]) + μz = mean_function(z, params["mean_function"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() @@ -557,6 +608,8 @@ def predict( Returns: Callable[[Float[Array, "N D"]], dx.MultivariateNormalTri]: A function that accepts a set of test points and will return the predictive distribution at those points. 
""" + + # Unpack variational parameters expectation_vector = params["variational_family"]["moments"][ "expectation_vector" ] @@ -566,10 +619,13 @@ def predict( z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel compuation + gram = kernel.gram + cross_covariance = kernel.cross_covariance # μ = η₁ mu = expectation_vector @@ -581,19 +637,21 @@ def predict( # S = sqrt sqrtᵀ sqrt = jnp.linalg.cholesky(S) - Kzz = gram(self.prior.kernel, z, params["kernel"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter Lz = Kzz.triangular_lower() - μz = self.prior.mean_function(z, params["mean_function"]) + μz = mean_function(z, params["mean_function"]) def predict_fn( test_inputs: Float[Array, "N D"] ) -> dx.MultivariateNormalFullCovariance: - t = test_inputs - n_test = t.shape[0] - Ktt = gram(self.prior.kernel, t, params["kernel"]) - Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) - μt = self.prior.mean_function(t, params["mean_function"]) + + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] + + Ktt = gram(kernel, t, params["kernel"]) + Kzt = cross_covariance(kernel, z, t, params["kernel"]) + μt = mean_function(t, params["mean_function"]) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) @@ -667,19 +725,28 @@ def predict_fn( test_inputs: Float[Array, "N D"] ) -> dx.MultivariateNormalFullCovariance: # TODO - can we cache some of this? - x, y = train_data.X, train_data.y - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + # Unpack test inputs + t, n_test = test_inputs, test_inputs.shape[0] + # Unpack training data + x, y = train_data.X, train_data.y + + # Unpack variational parameters noise = params["likelihood"]["obs_noise"] z = params["variational_family"]["inducing_inputs"] m = self.num_inducing - Kzx = cross_covariance(self.prior.kernel, z, x, params["kernel"]) - Kzz = gram(self.prior.kernel, z, params["kernel"]) + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel compuation + gram = kernel.gram + cross_covariance = kernel.cross_covariance + + Kzx = cross_covariance(kernel, z, x, params["kernel"]) + Kzz = gram(kernel, z, params["kernel"]) Kzz += I(m) * self.jitter # Lz Lzᵀ = Kzz @@ -697,7 +764,7 @@ def predict_fn( # LLᵀ = I + AAᵀ L = jnp.linalg.cholesky(jnp.eye(m) + AAT) - μx = self.prior.mean_function(x, params["mean_function"]) + μx = mean_function(x, params["mean_function"]) diff = y - μx # Lz⁻¹ Kzx (y - μx) @@ -709,11 +776,10 @@ def predict_fn( Kzz_inv_Kzx_diff = jsp.linalg.solve_triangular( Lz.T, Lz_inv_Kzx_diff, lower=False ) - t = test_inputs - n_test = t.shape[0] - Ktt = gram(self.prior.kernel, t, params["kernel"]) - Kzt = cross_covariance(self.prior.kernel, z, t, params["kernel"]) - μt = self.prior.mean_function(t, params["mean_function"]) + + Ktt = gram(kernel, t, params["kernel"]) + Kzt = cross_covariance(kernel, z, t, params["kernel"]) + μt = mean_function(t, params["mean_function"]) # Lz⁻¹ Kzt Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True) diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index c95ddac6..07ca4a17 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -26,7 +26,7 
@@ from .gps import AbstractPosterior from .likelihoods import Gaussian from .quadrature import gauss_hermite_quadrature -from .types import Dataset +from .types import Dataset, PRNGKeyType from .utils import concat_dictionaries from .variational_families import ( AbstractVariationalFamily, @@ -45,7 +45,7 @@ def __post_init__(self): self.prior = self.posterior.prior self.likelihood = self.posterior.likelihood - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Construct the parameter set used within the variational scheme adopted.""" hyperparams = concat_dictionaries( {"likelihood": self.posterior.likelihood._initialise_params(key)}, @@ -90,13 +90,15 @@ def elbo( Returns: Callable[[Dict, Dataset], Array]: A callable function that accepts a current parameter estimate and batch of data for which gradients should be computed. """ + + # Constant for whether or not to negate the elbo for optimisation purposes constant = jnp.array(-1.0) if negative else jnp.array(1.0) def elbo_fn(params: Dict, batch: Dataset) -> Float[Array, "1"]: # KL[q(f(·)) || p(f(·))] kl = self.variational_family.prior_kl(params) - # ∫[log(p(y|f(x))) q(f(x))] df(x) + # ∫[log(p(y|f(·))) q(f(·))] df(·) var_exp = self.variational_expectation(params, batch) # For batch size b, we compute n/b * Σᵢ[ ∫log(p(y|f(xᵢ))) q(f(xᵢ)) df(xᵢ)] - KL[q(f(·)) || p(f(·))] @@ -116,19 +118,21 @@ def variational_expectation( Returns: Array: The expectation of the model's log-likelihood under our variational distribution. """ + + # Unpack training batch x, y = batch.X, batch.y - # q(f(x)) - predictive_dist = vmap(self.variational_family.predict(params))(x[:, None]) - mean = predictive_dist.mean().val.reshape(-1, 1) - variance = predictive_dist.variance().val.reshape(-1, 1) + # Variational distribution q(f(·)) = N(f(·); μ(·), Σ(·, ·)) + q = self.variational_family + + # Compute variational mean, μ(x), and variance, √diag(Σ(x, x)), at training inputs, x + qx = vmap(q(params))(x[:, None]) + mean = qx.mean().val.reshape(-1, 1) + variance = qx.variance().val.reshape(-1, 1) # log(p(y|f(x))) - log_prob = vmap( - lambda f, y: self.likelihood.link_function( - f, params["likelihood"] - ).log_prob(y) - ) + link_function = self.likelihood.link_function + log_prob = vmap(lambda f, y: link_function(f, params["likelihood"]).log_prob(y)) # ≈ ∫[log(p(y|f(x))) q(f(x))] df(x) expectation = gauss_hermite_quadrature(log_prob, mean, variance, y=y) @@ -164,26 +168,31 @@ def elbo( Returns: Callable[[Dict, Dataset], Array]: A callable function that accepts a current parameter estimate for which gradients should be computed. 
""" - constant = jnp.array(-1.0) if negative else jnp.array(1.0) + # Unpack training data x, y, n = train_data.X, train_data.y, train_data.n + # Unpack mean function and kernel + mean_function = self.prior.mean_function + kernel = self.prior.kernel + + # Unpack kernel computation + gram, cross_covariance = kernel.gram, kernel.cross_covariance + m = self.num_inducing - gram, cross_covariance = ( - self.prior.kernel.gram, - self.prior.kernel.cross_covariance, - ) + jitter = self.variational_family.jitter + + # Constant for whether or not to negate the elbo for optimisation purposes + constant = jnp.array(-1.0) if negative else jnp.array(1.0) def elbo_fn(params: Dict) -> Float[Array, "1"]: noise = params["likelihood"]["obs_noise"] z = params["variational_family"]["inducing_inputs"] - Kzz = gram(self.prior.kernel, z, params["kernel"]) - Kzz += I(m) * self.variational_family.jitter - Kzx = cross_covariance(self.prior.kernel, z, x, params["kernel"]) - Kxx_diag = vmap(self.prior.kernel, in_axes=(0, 0, None))( - x, x, params["kernel"] - ) - μx = self.prior.mean_function(x, params["mean_function"]) + Kzz = gram(kernel, z, params["kernel"]) + Kzz += I(m) * jitter + Kzx = cross_covariance(kernel, z, x, params["kernel"]) + Kxx_diag = vmap(kernel, in_axes=(0, 0, None))(x, x, params["kernel"]) + μx = mean_function(x, params["mean_function"]) Lz = Kzz.triangular_lower()