
Commit

Update typing.
daniel-dodd committed Aug 23, 2022
1 parent d84ede9 commit e0b60d4
Showing 11 changed files with 190 additions and 188 deletions.
36 changes: 18 additions & 18 deletions gpjax/abstractions.py
@@ -1,4 +1,4 @@
-import typing as tp
+from typing import Callable, Dict, Optional

import jax
import jax.numpy as jnp
@@ -16,7 +16,7 @@

@dataclass(frozen=True)
class InferenceState:
-params: tp.Dict
+params: Dict
history: f64["n_iters"]

def unpack(self):
@@ -96,24 +96,24 @@ def wrapper_progress_bar(carry, x):


def fit(
-objective: tp.Callable,
-params: tp.Dict,
-trainables: tp.Dict,
+objective: Callable,
+params: Dict,
+trainables: Dict,
optax_optim,
n_iters: int = 100,
log_rate: int = 10,
) -> InferenceState:
"""Abstracted method for fitting a GP model with respect to a supplied objective function.
Optimisers used here should originate from Optax.
Args:
-objective (tp.Callable): The objective function that we are optimising with respect to.
-params (dict): The parameters for which we would like to minimise our objective function with.
-trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained.
+objective (Callable): The objective function that we are optimising with respect to.
+params (Dict): The parameters for which we would like to minimise our objective function with.
+trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained.
optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set.
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
-tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
+tp.Tuple[Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
"""
opt_state = optax_optim.init(params)

@@ -138,30 +138,30 @@ def step(carry, iter_num):


def fit_batches(
-objective: tp.Callable,
-params: tp.Dict,
-trainables: tp.Dict,
+objective: Callable,
+params: Dict,
+trainables: Dict,
train_data: Dataset,
optax_optim,
key: PRNGKeyType,
batch_size: int,
-n_iters: tp.Optional[int] = 100,
-log_rate: tp.Optional[int] = 10,
+n_iters: Optional[int] = 100,
+log_rate: Optional[int] = 10,
) -> InferenceState:
"""Abstracted method for fitting a GP model with mini-batches respect to a supplied objective function.
Optimisers used here should originate from Optax.
Args:
-objective (tp.Callable): The objective function that we are optimising with respect to.
-params (dict): The parameters for which we would like to minimise our objective function with.
-trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained.
+objective (Callable): The objective function that we are optimising with respect to.
+params (Dict): The parameters for which we would like to minimise our objective function with.
+trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained.
train_data (Dataset): The training dataset.
optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set.
key (PRNGKeyType): The PRNG key for the mini-batch sampling.
batch_size(int): The batch_size.
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
-tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
+tp.Tuple[Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
"""

opt_state = optax_optim.init(params)
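For context, the retyped fit utility is typically driven as sketched below. This is a minimal, hypothetical usage sketch against the 2022-era GPJax API: the toy parameter dictionary, the quadratic stand-in objective, and the optax.adam settings are illustrative only, and it is assumed that the objective is called with the parameter dictionary as its sole argument; only the fit(objective, params, trainables, optax_optim, n_iters, log_rate) signature and the InferenceState fields come from the diff above.

import jax.numpy as jnp
import optax

from gpjax.abstractions import fit

# Hypothetical parameter set with a matching boolean "trainables" dictionary
# of the same structure, as the docstring above requires.
params = {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)}
trainables = {"lengthscale": True, "variance": True}

def objective(params):
    # Stand-in objective: a simple quadratic with a known minimiser.
    return (params["lengthscale"] - 2.0) ** 2 + (params["variance"] - 0.5) ** 2

inference_state = fit(
    objective=objective,
    params=params,
    trainables=trainables,
    optax_optim=optax.adam(learning_rate=0.01),
    n_iters=500,
    log_rate=50,
)
learned_params, history = inference_state.params, inference_state.history

fit_batches follows the same pattern but, per its signature in this diff, additionally takes train_data, key, and batch_size for mini-batch sampling.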
81 changes: 40 additions & 41 deletions gpjax/gps.py
@@ -1,6 +1,5 @@
-import typing as tp
-from abc import abstractmethod, abstractproperty
-from typing import Dict
+from abc import abstractmethod
+from typing import Any, Callable, Dict, Optional

import distrax as dx
import jax.numpy as jnp
@@ -21,7 +20,7 @@
)
from .mean_functions import AbstractMeanFunction, Zero
from .parameters import copy_dict_structure, evaluate_priors, transform
-from .types import Dataset
+from .types import Dataset, PRNGKeyType
from .utils import I, concat_dictionaries

DEFAULT_JITTER = get_defaults()["jitter"]
@@ -31,7 +30,7 @@
class AbstractGP:
"""Abstract Gaussian process object."""

-def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution:
+def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution:
"""Evaluate the Gaussian process at the given points.
Returns:
@@ -40,11 +39,11 @@ def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution:
return self.predict(*args, **kwargs)

@abstractmethod
-def predict(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution:
+def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution:
"""Compute the latent function's multivariate normal distribution."""
raise NotImplementedError

-def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
+def _initialise_params(self, key: PRNGKeyType) -> Dict:
"""Initialise the GP's parameter set"""
raise NotImplementedError

@@ -57,9 +56,9 @@ class Prior(AbstractGP):
"""A Gaussian process prior object. The GP is parameterised by a mean and kernel function."""

kernel: Kernel
-mean_function: tp.Optional[AbstractMeanFunction] = Zero()
-name: tp.Optional[str] = "GP prior"
-jitter: tp.Optional[float] = DEFAULT_JITTER
+mean_function: Optional[AbstractMeanFunction] = Zero()
+name: Optional[str] = "GP prior"
+jitter: Optional[float] = DEFAULT_JITTER

def __mul__(self, other: AbstractLikelihood):
"""The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned.
@@ -74,12 +73,12 @@ def __rmul__(self, other: AbstractLikelihood):
"""Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior."""
return self.__mul__(other)

-def predict(self, params: dict) -> tp.Callable[[f64["N D"]], dx.Distribution]:
+def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]:
"""Compute the GP's prior mean and variance.
Args:
-params (dict): The specific set of parameters for which the mean function should be defined for.
+params (Dict): The specific set of parameters for which the mean function should be defined for.
Returns:
-tp.Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned.
+Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned.
"""

def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution:
@@ -95,7 +94,7 @@ def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution:

return predict_fn

-def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
+def _initialise_params(self, key: PRNGKeyType) -> Dict:
"""Initialise the GP prior's parameter set"""
return {
"kernel": self.kernel._initialise_params(key),
@@ -112,15 +111,15 @@ class AbstractPosterior(AbstractGP):

prior: Prior
likelihood: AbstractLikelihood
-name: tp.Optional[str] = "GP posterior"
-jitter: tp.Optional[float] = DEFAULT_JITTER
+name: Optional[str] = "GP posterior"
+jitter: Optional[float] = DEFAULT_JITTER

@abstractmethod
-def predict(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution:
+def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution:
"""Predict the GP's output given the input."""
raise NotImplementedError

-def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
+def _initialise_params(self, key: PRNGKeyType) -> Dict:
"""Initialise the parameter set of a GP posterior."""
return concat_dictionaries(
self.prior._initialise_params(key),
@@ -134,20 +133,20 @@ class ConjugatePosterior(AbstractPosterior):

prior: Prior
likelihood: Gaussian
-name: tp.Optional[str] = "Conjugate posterior"
-jitter: tp.Optional[float] = DEFAULT_JITTER
+name: Optional[str] = "Conjugate posterior"
+jitter: Optional[float] = DEFAULT_JITTER

def predict(
-self, train_data: Dataset, params: dict
-) -> tp.Callable[[f64["N D"]], dx.Distribution]:
+self, train_data: Dataset, params: Dict
+) -> Callable[[f64["N D"]], dx.Distribution]:
"""Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density.
Args:
train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset.
-params (dict): A dictionary of parameters that should be used to compute the posterior.
+params (Dict): A dictionary of parameters that should be used to compute the posterior.
Returns:
-tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`.
+Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`.
"""
x, y, n = train_data.X, train_data.y, train_data.n

@@ -193,24 +192,24 @@ def marginal_log_likelihood(
self,
train_data: Dataset,
transformations: Dict,
-priors: dict = None,
+priors: Dict = None,
negative: bool = False,
-) -> tp.Callable[[dict], f64["1"]]:
+) -> Callable[[Dict], f64["1"]]:
"""Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.
Args:
train_data (Dataset): The training dataset used to compute the marginal log-likelihood.
transformations (Dict): A dictionary of transformations that should be applied to the training dataset to unconstrain the parameters.
-priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None.
+priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None.
negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False.
Returns:
-tp.Callable[[dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set.
+Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set.
"""
x, y, n = train_data.X, train_data.y, train_data.n

def mll(
-params: dict,
+params: Dict,
):
params = transform(params=params, transform_map=transformations)

@@ -247,10 +246,10 @@ class NonConjugatePosterior(AbstractPosterior):

prior: Prior
likelihood: NonConjugateLikelihoodType
-name: tp.Optional[str] = "Non-conjugate posterior"
-jitter: tp.Optional[float] = DEFAULT_JITTER
+name: Optional[str] = "Non-conjugate posterior"
+jitter: Optional[float] = DEFAULT_JITTER

-def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
+def _initialise_params(self, key: PRNGKeyType) -> Dict:
"""Initialise the parameter set of a non-conjugate GP posterior."""
parameters = concat_dictionaries(
self.prior._initialise_params(key),
@@ -260,16 +259,16 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
return parameters

def predict(
-self, train_data: Dataset, params: dict
-) -> tp.Callable[[f64["N D"]], dx.Distribution]:
+self, train_data: Dataset, params: Dict
+) -> Callable[[f64["N D"]], dx.Distribution]:
"""Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function.
Args:
train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset.
-params (dict): A dictionary of parameters that should be used to compute the posterior.
+params (Dict): A dictionary of parameters that should be used to compute the posterior.
Returns:
-tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`.
+Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`.
"""
x, n = train_data.X, train_data.n

@@ -304,27 +303,27 @@ def marginal_log_likelihood(
self,
train_data: Dataset,
transformations: Dict,
-priors: dict = None,
+priors: Dict = None,
negative: bool = False,
-) -> tp.Callable[[dict], f64["1"]]:
+) -> Callable[[Dict], f64["1"]]:
"""Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax.
Args:
train_data (Dataset): The training dataset used to compute the marginal log-likelihood.
transformations (Dict): A dictionary of transformations that should be applied to the training dataset to unconstrain the parameters.
-priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None.
+priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None.
negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False.
Returns:
-tp.Callable[[dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set.
+Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set.
"""
x, y, n = train_data.X, train_data.y, train_data.n

if not priors:
priors = copy_dict_structure(self._initialise_params(jr.PRNGKey(0)))
priors["latent"] = dx.Normal(loc=0.0, scale=1.0)

-def mll(params: dict):
+def mll(params: Dict):
params = transform(params=params, transform_map=transformations)
Kxx = gram(self.prior.kernel, x, params["kernel"])
Kxx += I(n) * self.jitter
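For orientation, the retyped prior/posterior interface is used roughly as follows. This is a hedged sketch against the 2022-era API: the RBF kernel, the Gaussian(num_datapoints=...) constructor, their import paths, and the Dataset(X=..., y=...) keywords are assumptions not shown in this commit, while Prior, the prior * likelihood product, _initialise_params (now keyed by a PRNGKeyType), and the predict signature are taken from the diff above.

import jax.numpy as jnp
import jax.random as jr

from gpjax.gps import Prior
from gpjax.kernels import RBF            # assumed import path
from gpjax.likelihoods import Gaussian   # assumed import path and constructor
from gpjax.types import Dataset

key = jr.PRNGKey(123)

# A small toy dataset; Dataset exposes .X, .y and .n as used in gps.py.
x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)
y = jnp.sin(x)
D = Dataset(X=x, y=y)

prior = Prior(kernel=RBF())
posterior = prior * Gaussian(num_datapoints=D.n)  # a ConjugatePosterior

# The key is now annotated as PRNGKeyType rather than jnp.DeviceArray.
params = posterior._initialise_params(key)

# Conditional on the training data, build and evaluate the predictive
# distribution at a set of test inputs.
predict_fn = posterior.predict(train_data=D, params=params)
test_inputs = jnp.linspace(-4.0, 4.0, 20).reshape(-1, 1)
predictive_dist = predict_fn(test_inputs)
mean, cov = predictive_dist.mean(), predictive_dist.covariance()

The retyped marginal_log_likelihood(train_data, transformations, priors, negative) returns a Callable over a parameter Dict, so it can be passed directly to the fit utility sketched earlier once a transformation map for the parameters is available.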