diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index 066e1ec0..d6d03707 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -1,4 +1,4 @@ -import typing as tp +from typing import Callable, Dict, Optional import jax import jax.numpy as jnp @@ -16,7 +16,7 @@ @dataclass(frozen=True) class InferenceState: - params: tp.Dict + params: Dict history: f64["n_iters"] def unpack(self): @@ -96,9 +96,9 @@ def wrapper_progress_bar(carry, x): def fit( - objective: tp.Callable, - params: tp.Dict, - trainables: tp.Dict, + objective: Callable, + params: Dict, + trainables: Dict, optax_optim, n_iters: int = 100, log_rate: int = 10, @@ -106,14 +106,14 @@ def fit( """Abstracted method for fitting a GP model with respect to a supplied objective function. Optimisers used here should originate from Optax. Args: - objective (tp.Callable): The objective function that we are optimising with respect to. - params (dict): The parameters for which we would like to minimise our objective function with. - trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. + objective (Callable): The objective function that we are optimising with respect to. + params (Dict): The parameters for which we would like to minimise our objective function with. + trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. Returns: - tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively. + tp.Tuple[Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively. """ opt_state = optax_optim.init(params) @@ -138,22 +138,22 @@ def step(carry, iter_num): def fit_batches( - objective: tp.Callable, - params: tp.Dict, - trainables: tp.Dict, + objective: Callable, + params: Dict, + trainables: Dict, train_data: Dataset, optax_optim, key: PRNGKeyType, batch_size: int, - n_iters: tp.Optional[int] = 100, - log_rate: tp.Optional[int] = 10, + n_iters: Optional[int] = 100, + log_rate: Optional[int] = 10, ) -> InferenceState: """Abstracted method for fitting a GP model with mini-batches respect to a supplied objective function. Optimisers used here should originate from Optax. Args: - objective (tp.Callable): The objective function that we are optimising with respect to. - params (dict): The parameters for which we would like to minimise our objective function with. - trainables (dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. + objective (Callable): The objective function that we are optimising with respect to. + params (Dict): The parameters for which we would like to minimise our objective function with. + trainables (Dict): Boolean dictionary of same structure as 'params' that determines which parameters should be trained. train_data (Dataset): The training dataset. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. key (PRNGKeyType): The PRNG key for the mini-batch sampling. @@ -161,7 +161,7 @@ def fit_batches( n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. Returns: - tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively. + tp.Tuple[Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively. """ opt_state = optax_optim.init(params) diff --git a/gpjax/gps.py b/gpjax/gps.py index 551d1a3c..9047c840 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -1,6 +1,5 @@ -import typing as tp -from abc import abstractmethod, abstractproperty -from typing import Dict +from abc import abstractmethod +from typing import Any, Callable, Dict, Optional import distrax as dx import jax.numpy as jnp @@ -21,7 +20,7 @@ ) from .mean_functions import AbstractMeanFunction, Zero from .parameters import copy_dict_structure, evaluate_priors, transform -from .types import Dataset +from .types import Dataset, PRNGKeyType from .utils import I, concat_dictionaries DEFAULT_JITTER = get_defaults()["jitter"] @@ -31,7 +30,7 @@ class AbstractGP: """Abstract Gaussian process object.""" - def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: + def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the Gaussian process at the given points. Returns: @@ -40,11 +39,11 @@ def __call__(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: return self.predict(*args, **kwargs) @abstractmethod - def predict(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Compute the latent function's multivariate normal distribution.""" raise NotImplementedError - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the GP's parameter set""" raise NotImplementedError @@ -57,9 +56,9 @@ class Prior(AbstractGP): """A Gaussian process prior object. The GP is parameterised by a mean and kernel function.""" kernel: Kernel - mean_function: tp.Optional[AbstractMeanFunction] = Zero() - name: tp.Optional[str] = "GP prior" - jitter: tp.Optional[float] = DEFAULT_JITTER + mean_function: Optional[AbstractMeanFunction] = Zero() + name: Optional[str] = "GP prior" + jitter: Optional[float] = DEFAULT_JITTER def __mul__(self, other: AbstractLikelihood): """The product of a prior and likelihood is proportional to the posterior distribution. By computing the product of a GP prior and a likelihood object, a posterior GP object will be returned. @@ -74,12 +73,12 @@ def __rmul__(self, other: AbstractLikelihood): """Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior.""" return self.__mul__(other) - def predict(self, params: dict) -> tp.Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the GP's prior mean and variance. Args: - params (dict): The specific set of parameters for which the mean function should be defined for. + params (Dict): The specific set of parameters for which the mean function should be defined for. Returns: - tp.Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. + Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned. """ def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: @@ -95,7 +94,7 @@ def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution: return predict_fn - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the GP prior's parameter set""" return { "kernel": self.kernel._initialise_params(key), @@ -112,15 +111,15 @@ class AbstractPosterior(AbstractGP): prior: Prior likelihood: AbstractLikelihood - name: tp.Optional[str] = "GP posterior" - jitter: tp.Optional[float] = DEFAULT_JITTER + name: Optional[str] = "GP posterior" + jitter: Optional[float] = DEFAULT_JITTER @abstractmethod - def predict(self, *args: tp.Any, **kwargs: tp.Any) -> dx.Distribution: + def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Predict the GP's output given the input.""" raise NotImplementedError - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a GP posterior.""" return concat_dictionaries( self.prior._initialise_params(key), @@ -134,20 +133,20 @@ class ConjugatePosterior(AbstractPosterior): prior: Prior likelihood: Gaussian - name: tp.Optional[str] = "Conjugate posterior" - jitter: tp.Optional[float] = DEFAULT_JITTER + name: Optional[str] = "Conjugate posterior" + jitter: Optional[float] = DEFAULT_JITTER def predict( - self, train_data: Dataset, params: dict - ) -> tp.Callable[[f64["N D"]], dx.Distribution]: + self, train_data: Dataset, params: Dict + ) -> Callable[[f64["N D"]], dx.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Args: train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - params (dict): A dictionary of parameters that should be used to compute the posterior. + params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -193,24 +192,24 @@ def marginal_log_likelihood( self, train_data: Dataset, transformations: Dict, - priors: dict = None, + priors: Dict = None, negative: bool = False, - ) -> tp.Callable[[dict], f64["1"]]: + ) -> Callable[[Dict], f64["1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values. Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. transformations (Dict): A dictionary of transformations that should be applied to the training dataset to unconstrain the parameters. - priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. + priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. Returns: - tp.Callable[[dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. + Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ x, y, n = train_data.X, train_data.y, train_data.n def mll( - params: dict, + params: Dict, ): params = transform(params=params, transform_map=transformations) @@ -247,10 +246,10 @@ class NonConjugatePosterior(AbstractPosterior): prior: Prior likelihood: NonConjugateLikelihoodType - name: tp.Optional[str] = "Non-conjugate posterior" - jitter: tp.Optional[float] = DEFAULT_JITTER + name: Optional[str] = "Non-conjugate posterior" + jitter: Optional[float] = DEFAULT_JITTER - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a non-conjugate GP posterior.""" parameters = concat_dictionaries( self.prior._initialise_params(key), @@ -260,16 +259,16 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict: return parameters def predict( - self, train_data: Dataset, params: dict - ) -> tp.Callable[[f64["N D"]], dx.Distribution]: + self, train_data: Dataset, params: Dict + ) -> Callable[[f64["N D"]], dx.Distribution]: """Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function. Args: train_data (Dataset): A `gpx.Dataset` object that contains the input and output data used for training dataset. - params (dict): A dictionary of parameters that should be used to compute the posterior. + params (Dict): A dictionary of parameters that should be used to compute the posterior. Returns: - tp.Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. + Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a `distrax.MultivariateNormalFullCovariance`. """ x, n = train_data.X, train_data.n @@ -304,19 +303,19 @@ def marginal_log_likelihood( self, train_data: Dataset, transformations: Dict, - priors: dict = None, + priors: Dict = None, negative: bool = False, - ) -> tp.Callable[[dict], f64["1"]]: + ) -> Callable[[Dict], f64["1"]]: """Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax. Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. transformations (Dict): A dictionary of transformations that should be applied to the training dataset to unconstrain the parameters. - priors (dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. + priors (Dict, optional): _description_. Optional argument that contains the priors placed on the model's parameters. Defaults to None. negative (bool, optional): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. Returns: - tp.Callable[[dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. + Callable[[Dict], Array]: A functional representation of the marginal log-likelihood that can be evaluated at a given parameter set. """ x, y, n = train_data.X, train_data.y, train_data.n @@ -324,7 +323,7 @@ def marginal_log_likelihood( priors = copy_dict_structure(self._initialise_params(jr.PRNGKey(0))) priors["latent"] = dx.Normal(loc=0.0, scale=1.0) - def mll(params: dict): + def mll(params: Dict): params = transform(params=params, transform_map=transformations) Kxx = gram(self.prior.kernel, x, params["kernel"]) Kxx += I(n) * self.jitter diff --git a/gpjax/kernels.py b/gpjax/kernels.py index ce541a38..4d94e001 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -6,6 +6,8 @@ from jax import vmap from jaxtyping import f64 +from .types import PRNGKeyType + ########################################## # Abtract classes @@ -23,12 +25,12 @@ def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) @abc.abstractmethod - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the kernel on a pair of inputs. Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)`. """ @@ -59,7 +61,7 @@ def ard(self): return True if self.ndims > 1 else False @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """A template dictionary of the kernel's parameter set.""" raise NotImplementedError @@ -97,11 +99,11 @@ def _set_kernels(self, kernels: Sequence[Kernel]) -> None: kernels_list.append(k) self.kernel_set = kernels_list - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """A template dictionary of the kernel's parameter set.""" return [kernel._initialise_params(key) for kernel in self.kernel_set] - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: return self.combination_fn( jnp.stack([k(x, y, p) for k, p in zip(self.kernel_set, params)]) ) @@ -135,7 +137,7 @@ class RBF(Kernel): def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` .. math:: @@ -144,7 +146,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -154,7 +156,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -170,7 +172,7 @@ class Matern12(Kernel): def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` .. math:: @@ -179,7 +181,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` """ @@ -188,7 +190,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: K = params["variance"] * jnp.exp(-0.5 * euclidean_distance(x, y)) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -204,7 +206,7 @@ class Matern32(Kernel): def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma` .. math:: @@ -213,7 +215,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -228,7 +230,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: ) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -244,7 +246,7 @@ class Matern52(Kernel): def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma` .. math:: @@ -253,7 +255,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -268,7 +270,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: ) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -286,7 +288,7 @@ def __post_init__(self): self.ndims = 1 if not self.active_dims else len(self.active_dims) self.name = f"Polynomial Degree: {self.degree}" - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\alpha` and variance :math:`\sigma` through .. math:: @@ -295,7 +297,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: Args: x (jnp.DeviceArray): The left hand argument of the kernel function's call. y (jnp.DeviceArray): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of :math:`k(x, y)` @@ -305,7 +307,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) return K.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "shift": jnp.array([1.0]), "variance": jnp.array([1.0] * self.ndims), @@ -330,13 +332,13 @@ def __post_init__(self): self.evals = evals.reshape(-1, 1) self.num_vertex = self.laplacian.shape[0] - def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: + def __call__(self, x: f64["1 D"], y: f64["1 D"], params: Dict) -> f64["1"]: """Evaluate the graph kernel on a pair of vertices v_i, v_j. Args: x (jnp.DeviceArray): Index of the ith vertex y (jnp.DeviceArray): Index of the jth vertex - params (dict): Parameter set for which the kernel should be evaluated on. + params (Dict): Parameter set for which the kernel should be evaluated on. Returns: Array: The value of k(v_i, v_j). @@ -353,7 +355,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]: ) return kxy.squeeze() - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: return { "lengthscale": jnp.array([1.0] * self.ndims), "variance": jnp.array([1.0]), @@ -371,13 +373,13 @@ def euclidean_distance(x: f64["1 D"], y: f64["1 D"]) -> f64["1"]: return jnp.sqrt(jnp.maximum(jnp.sum((x - y) ** 2), 1e-36)) -def gram(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]: +def gram(kernel: Kernel, inputs: f64["N D"], params: Dict) -> f64["N N"]: """For a given kernel, compute the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`. Args: kernel (Kernel): The kernel for which the Gram matrix should be computed for. inputs (Array): The input matrix. - params (dict): The kernel's parameter set. + params (Dict): The kernel's parameter set. Returns: Array: The computed square Gram matrix. @@ -386,7 +388,7 @@ def gram(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]: def cross_covariance( - kernel: Kernel, x: f64["N D"], y: f64["M D"], params: dict + kernel: Kernel, x: f64["N D"], y: f64["M D"], params: Dict ) -> f64["N M"]: """For a given kernel, compute the :math:`m \times n` gram matrix on an a pair of input matrices with shape :math:`m \times d` and :math:`n \times d` for :math:`d\geq 1`. @@ -394,7 +396,7 @@ def cross_covariance( kernel (Kernel): The kernel for which the cross-covariance matrix should be computed for. x (Array): The first input matrix. y (Array): The second input matrix. - params (dict): The kernel's parameter set. + params (Dict): The kernel's parameter set. Returns: Array: The computed square Gram matrix. @@ -402,12 +404,12 @@ def cross_covariance( return vmap(lambda x1: vmap(lambda y1: kernel(x1, y1, params))(y))(x) -def diagonal(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]: +def diagonal(kernel: Kernel, inputs: f64["N D"], params: Dict) -> f64["N N"]: """For a given kernel, compute the elementwise diagonal of the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`. Args: kernel (Kernel): The kernel for which the variance vector should be computed for. inputs (Array): The input matrix. - params (dict): The kernel's parameter set. + params (Dict): The kernel's parameter set. Returns: Array: The computed diagonal variance matrix. """ diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 4f113d21..7760c138 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -7,6 +7,7 @@ from chex import dataclass from jaxtyping import f64 +from .types import PRNGKeyType from .utils import I @@ -27,7 +28,7 @@ def predict(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the parameters of the likelihood function.""" raise NotImplementedError @@ -54,7 +55,7 @@ class Gaussian(AbstractLikelihood, Conjugate): name: Optional[str] = "Gaussian" - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the variance parameter of the likelihood function.""" return {"obs_noise": jnp.array([1.0])} @@ -66,12 +67,12 @@ def link_function(self) -> Callable: Callable: A link function that maps the predictive distribution to the likelihood function. """ - def link_fn(x, params: dict) -> dx.Distribution: + def link_fn(x, params: Dict) -> dx.Distribution: return dx.Normal(loc=x, scale=params["obs_noise"]) return link_fn - def predict(self, dist: dx.Distribution, params: dict) -> dx.Distribution: + def predict(self, dist: dx.Distribution, params: Dict) -> dx.Distribution: """Evaluate the Gaussian likelihood function at a given predictive distribution. Computationally, this is equivalent to summing the observation noise term to the diagonal elements of the predictive distribution's covariance matrix..""" n_data = dist.event_shape[0] noisy_cov = dist.covariance() + I(n_data) * params["likelihood"]["obs_noise"] @@ -82,7 +83,7 @@ def predict(self, dist: dx.Distribution, params: dict) -> dx.Distribution: class Bernoulli(AbstractLikelihood, NonConjugate): name: Optional[str] = "Bernoulli" - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Initialise the parameter set of a Bernoulli likelihood.""" return {} @@ -113,7 +114,7 @@ def moment_fn(mean: f64["N D"], variance: f64["N D"], params: Dict): return moment_fn - def predict(self, dist: dx.Distribution, params: dict) -> Any: + def predict(self, dist: dx.Distribution, params: Dict) -> Any: variance = jnp.diag(dist.covariance()) mean = dist.mean() return self.predictive_moment_fn(mean.ravel(), variance, params) diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 33e47b0d..77e98380 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -5,6 +5,8 @@ from chex import dataclass from jaxtyping import f64 +from .types import PRNGKeyType + @dataclass(repr=False) class AbstractMeanFunction: @@ -26,11 +28,11 @@ def __call__(self, x: f64["N D"]) -> f64["N Q"]: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """Return the parameters of the mean function. This method is required for all subclasses. Returns: - dict: The parameters of the mean function. + Dict: The parameters of the mean function. """ raise NotImplementedError @@ -44,12 +46,12 @@ class Zero(AbstractMeanFunction): output_dim: Optional[int] = 1 name: Optional[str] = "Zero mean function" - def __call__(self, x: f64["N D"], params: dict) -> f64["N Q"]: + def __call__(self, x: f64["N D"], params: Dict) -> f64["N Q"]: """Evaluate the mean function at the given points. Args: x (Array): The input points at which to evaluate the mean function. - params (dict): The parameters of the mean function. + params (Dict): The parameters of the mean function. Returns: Array: A vector of zeros. @@ -57,7 +59,7 @@ def __call__(self, x: f64["N D"], params: dict) -> f64["N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.zeros(shape=out_shape) - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """The parameters of the mean function. For the zero-mean function, this is an empty dictionary.""" return {} @@ -85,6 +87,6 @@ def __call__(self, x: f64["N D"], params: Dict) -> f64["N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.ones(shape=out_shape) * params["constant"] - def _initialise_params(self, key: jnp.DeviceArray) -> Dict: + def _initialise_params(self, key: PRNGKeyType) -> Dict: """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value.""" return {"constant": jnp.array([1.0])} diff --git a/gpjax/natural_gradients.py b/gpjax/natural_gradients.py index 1f68ac0a..f67dc52b 100644 --- a/gpjax/natural_gradients.py +++ b/gpjax/natural_gradients.py @@ -1,5 +1,5 @@ -import typing as tp from copy import deepcopy +from typing import Callable, Dict, Optional, Tuple import jax import jax.numpy as jnp @@ -32,15 +32,15 @@ def natural_to_expectation( - natural_moments: dict, jitter: float = DEFAULT_JITTER -) -> dict: + natural_moments: Dict, jitter: float = DEFAULT_JITTER +) -> Dict: """ Converts natural parameters to expectation parameters. Args: natural_moments: A dictionary of natural parameters. jitter (float): A small value to prevent numerical instability. Returns: - tp.Dict: A dictionary of Gaussian moments under the expectation parameterisation. + Dict: A dictionary of Gaussian moments under the expectation parameterisation. """ natural_matrix = natural_moments["natural_matrix"] @@ -79,7 +79,7 @@ def _expectation_elbo( posterior: AbstractPosterior, variational_family: AbstractVariationalFamily, train_data: Dataset, -) -> tp.Callable[[dict, Dataset], float]: +) -> Callable[[Dict, Dataset], float]: """ Construct evidence lower bound (ELBO) for variational Gaussian under the expectation parameterisation. Args: @@ -100,13 +100,13 @@ def _expectation_elbo( return svgp.elbo(train_data, identity_transformation, negative=True) -def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: +def _stop_gradients_nonmoments(params: Dict) -> Dict: """ Stops gradients for non-moment parameters. Args: params: A dictionary of parameters. Returns: - tp.Dict: A dictionary of parameters with stopped gradients. + Dict: A dictionary of parameters with stopped gradients. """ trainables = build_trainables_false(params) moment_trainables = build_trainables_true(params["variational_family"]["moments"]) @@ -115,13 +115,13 @@ def _stop_gradients_nonmoments(params: tp.Dict) -> tp.Dict: return params -def _stop_gradients_moments(params: tp.Dict) -> tp.Dict: +def _stop_gradients_moments(params: Dict) -> Dict: """ Stops gradients for moment parameters. Args: params: A dictionary of parameters. Returns: - tp.Dict: A dictionary of parameters with stopped gradients. + Dict: A dictionary of parameters with stopped gradients. """ trainables = build_trainables_true(params) moment_trainables = build_trainables_false(params["variational_family"]["moments"]) @@ -133,8 +133,8 @@ def _stop_gradients_moments(params: tp.Dict) -> tp.Dict: def natural_gradients( stochastic_vi: StochasticVI, train_data: Dataset, - transformations: dict, -) -> tp.Tuple[tp.Callable[[dict, Dataset], dict]]: + transformations: Dict, +) -> Tuple[Callable[[Dict, Dataset], Dict]]: """ Computes the gradient with respect to the natural parameters. Currently only implemented for the natural variational Gaussian family. Args: @@ -143,7 +143,7 @@ def natural_gradients( train_data: A Dataset. transformations: A dictionary of transformations. Returns: - Tuple[tp.Callable[[dict, Dataset], dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. + Tuple[Callable[[Dict, Dataset], Dict]]: Functions that compute natural gradients and hyperparameter gradients respectively. """ posterior = stochastic_vi.posterior variational_family = stochastic_vi.variational_family @@ -156,7 +156,7 @@ def natural_gradients( if isinstance(variational_family, NaturalVariationalGaussian): - def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + def nat_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: """ Computes the natural gradients of the ELBO. Args: @@ -164,7 +164,7 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: trainables: A dictionary of trainables. batch: A Dataset. Returns: - dict: A dictionary of natural gradients. + Dict: A dictionary of natural gradients. """ # Transform parameters to constrained space. params = transform(params, transformations) @@ -180,7 +180,7 @@ def nat_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: expectation_params["variational_family"]["moments"] = expectation_moments # Compute gradient ∂L/∂η: - def loss_fn(params: dict, batch: Dataset) -> f64["1"]: + def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. trains = deepcopy(trainables) trains["variational_family"]["moments"] = build_trainables_true( @@ -211,7 +211,7 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: else: raise NotImplementedError - def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: + def hyper_grads_fn(params: Dict, trainables: Dict, batch: Dataset) -> Dict: """ Computes the hyperparameter gradients of the ELBO. Args: @@ -219,10 +219,10 @@ def hyper_grads_fn(params: dict, trainables: dict, batch: Dataset) -> dict: trainables: A dictionary of trainables. batch: A Dataset. Returns: - dict: A dictionary of hyperparameter gradients. + Dict: A dictionary of hyperparameter gradients. """ - def loss_fn(params: dict, batch: Dataset) -> f64["1"]: + def loss_fn(params: Dict, batch: Dataset) -> f64["1"]: # Determine hyperparameters that should be trained. params = trainable_params(params, trainables) @@ -240,17 +240,17 @@ def loss_fn(params: dict, batch: Dataset) -> f64["1"]: def fit_natgrads( stochastic_vi: StochasticVI, - params: tp.Dict, - trainables: tp.Dict, - transformations: tp.Dict, + params: Dict, + trainables: Dict, + transformations: Dict, train_data: Dataset, batch_size: int, moment_optim, hyper_optim, key: PRNGKeyType, - n_iters: tp.Optional[int] = 100, - log_rate: tp.Optional[int] = 10, -) -> tp.Dict: + n_iters: Optional[int] = 100, + log_rate: Optional[int] = 10, +) -> Dict: hyper_state = hyper_optim.init(params) moment_state = moment_optim.init(params) diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 53be3c78..2c2aa55e 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -1,6 +1,6 @@ -import typing as tp import warnings from copy import deepcopy +from typing import Dict, Tuple from warnings import warn import distrax as dx @@ -24,10 +24,10 @@ class ParameterState: """The state of the model. This includes the parameter set and the functions that allow parameters to be constrained and unconstrained.""" - params: tp.Dict - trainables: tp.Dict - constrainers: tp.Dict - unconstrainers: tp.Dict + params: Dict + trainables: Dict + constrainers: Dict + unconstrainers: Dict def unpack(self): return self.params, self.trainables, self.constrainers, self.unconstrainers @@ -60,7 +60,7 @@ def _validate_kwargs(kwargs, params): raise ValueError(f"Parameter {k} is not a valid parameter.") -def recursive_items(d1: tp.Dict, d2: tp.Dict): +def recursive_items(d1: Dict, d2: Dict): """Recursive loop over pair of dictionaries whereby the value of a given key in either dictionary can be itself a dictionary. Args: @@ -77,15 +77,15 @@ def recursive_items(d1: tp.Dict, d2: tp.Dict): yield (key, value, d2[key]) -def recursive_complete(d1: tp.Dict, d2: tp.Dict) -> tp.Dict: +def recursive_complete(d1: Dict, d2: Dict) -> Dict: """Recursive loop over pair of dictionaries whereby the value of a given key in either dictionary can be itself a dictionary. If the value of the key in the second dictionary is None, the value of the key in the first dictionary is used. Args: - d1 (tp.Dict): The reference dictionary. - d2 (tp.Dict): The potentially incomplete dictionary. + d1 (Dict): The reference dictionary. + d2 (Dict): The potentially incomplete dictionary. Returns: - tp.Dict: A completed form of the second dictionary. + Dict: A completed form of the second dictionary. """ for key, value in d1.items(): if type(value) is dict: @@ -102,14 +102,14 @@ def recursive_complete(d1: tp.Dict, d2: tp.Dict) -> tp.Dict: ################################ # Parameter transformation ################################ -def build_bijectors(params: tp.Dict) -> tp.Dict: +def build_bijectors(params: Dict) -> Dict: """For each parameter, build the bijection pair that allows the parameter to be constrained and unconstrained. Args: - params (tp.Dict): _description_ + params (Dict): _description_ Returns: - tp.Dict: A dictionary that maps each parameter to a bijection. + Dict: A dictionary that maps each parameter to a bijection. """ bijectors = copy_dict_structure(params) config = get_defaults() @@ -118,7 +118,7 @@ def build_bijectors(params: tp.Dict) -> tp.Dict: def recursive_bijectors_list(ps, bs): return [recursive_bijectors(ps[i], bs[i]) for i in range(len(bs))] - def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: + def recursive_bijectors(ps, bs) -> Tuple[Dict, Dict]: if type(ps) is list: bs = recursive_bijectors_list(ps, bs) @@ -143,14 +143,14 @@ def recursive_bijectors(ps, bs) -> tp.Tuple[tp.Dict, tp.Dict]: return recursive_bijectors(params, bijectors) -def build_transforms(params: tp.Dict) -> tp.Tuple[tp.Dict, tp.Dict]: +def build_transforms(params: Dict) -> Tuple[Dict, Dict]: """Using the bijector that is associated with each parameter, construct a pair of functions from the bijector that allow the parameter to be constrained and unconstrained. Args: - params (tp.Dict): The parameter set for which transformations should be derived from. + params (Dict): The parameter set for which transformations should be derived from. Returns: - tp.Tuple[tp.Dict, tp.Dict]: A pair of dictionaries. The first dictionary maps each parameter to a function that constrains the parameter. The second dictionary maps each parameter to a function that unconstrains the parameter. + Tuple[Dict, Dict]: A pair of dictionaries. The first dictionary maps each parameter to a function that constrains the parameter. The second dictionary maps each parameter to a function that unconstrains the parameter. """ def forward(bijector): @@ -172,28 +172,28 @@ def inverse(bijector): return constrainers, unconstrainers -def transform(params: tp.Dict, transform_map: tp.Dict) -> tp.Dict: +def transform(params: Dict, transform_map: Dict) -> Dict: """Transform the parameters according to the constraining or unconstraining function dictionary. Args: - params (tp.Dict): The parameters that are to be transformed. - transform_map (tp.Dict): The corresponding dictionary of transforms that should be applied to the parameter set. + params (Dict): The parameters that are to be transformed. + transform_map (Dict): The corresponding dictionary of transforms that should be applied to the parameter set. Returns: - tp.Dict: A transformed parameter set.s The dictionary is equal in structure to the input params dictionary. + Dict: A transformed parameter set.s The dictionary is equal in structure to the input params dictionary. """ return jax.tree_util.tree_map( lambda param, trans: trans(param), params, transform_map ) -def build_identity(params: tp.Dict) -> tp.Dict: +def build_identity(params: Dict) -> Dict: """ " Args: - params (tp.Dict): The parameter set for which trainable statuses should be derived from. + params (Dict): The parameter set for which trainable statuses should be derived from. Returns: - tp.Dict: A dictionary of identity forward/backward bijectors. The dictionary is equal in structure to the input params dictionary. + Dict: A dictionary of identity forward/backward bijectors. The dictionary is equal in structure to the input params dictionary. """ # Copy dictionary structure prior_container = deepcopy(params) @@ -212,7 +212,7 @@ def log_density(param: f64["D"], density: dx.Distribution) -> f64["1"]: return log_prob -def copy_dict_structure(params: dict) -> dict: +def copy_dict_structure(params: Dict) -> Dict: # Copy dictionary structure prior_container = deepcopy(params) # Set all values to zero @@ -220,15 +220,15 @@ def copy_dict_structure(params: dict) -> dict: return prior_container -def structure_priors(params: dict, priors: dict) -> dict: +def structure_priors(params: Dict, priors: Dict) -> Dict: """First create a dictionary with equal structure to the parameters. Then, for each supplied prior, overwrite the None value if it exists. Args: - params (dict): [description] - priors (dict): [description] + params (Dict): [description] + priors (Dict): [description] Returns: - dict: [description] + Dict: [description] """ prior_container = copy_dict_structure(params) # Where a prior has been supplied, override the None value by the prior distribution. @@ -236,14 +236,14 @@ def structure_priors(params: dict, priors: dict) -> dict: return complete_prior -def evaluate_priors(params: dict, priors: dict) -> dict: +def evaluate_priors(params: Dict, priors: Dict) -> Dict: """Recursive loop over pair of dictionaries that correspond to a parameter's current value and the parameter's respective prior distribution. For parameters where a prior distribution is specified, the log-prior density is evaluated at the parameter's current value. - Args: params (dict): Dictionary containing the current set of parameter - estimates. priors (dict): Dictionary specifying the parameters' prior + Args: params (Dict): Dictionary containing the current set of parameter + estimates. priors (Dict): Dictionary specifying the parameters' prior distributions. Returns: Array: The log-prior density, summed over all parameters. @@ -255,7 +255,7 @@ def evaluate_priors(params: dict, priors: dict) -> dict: return lpd -def prior_checks(priors: dict) -> dict: +def prior_checks(priors: Dict) -> Dict: """Run checks on th parameters' prior distributions. This checks that for Gaussian processes that are constructed with non-conjugate likelihoods, the prior distribution on the function's latent values is a unit Gaussian.""" if "latent" in priors.keys(): latent_prior = priors["latent"] @@ -275,14 +275,14 @@ def prior_checks(priors: dict) -> dict: return priors -def build_trainables_true(params: tp.Dict) -> tp.Dict: +def build_trainables_true(params: Dict) -> Dict: """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is trainable. Args: - params (tp.Dict): The parameter set for which trainable statuses should be derived from. + params (Dict): The parameter set for which trainable statuses should be derived from. Returns: - tp.Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. + Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. """ # Copy dictionary structure prior_container = deepcopy(params) @@ -291,14 +291,14 @@ def build_trainables_true(params: tp.Dict) -> tp.Dict: return prior_container -def build_trainables_false(params: tp.Dict) -> tp.Dict: +def build_trainables_false(params: Dict) -> Dict: """Construct a dictionary of trainable statuses for each parameter. By default, every parameter within the model is NOT trainable. Args: - params (tp.Dict): The parameter set for which trainable statuses should be derived from. + params (Dict): The parameter set for which trainable statuses should be derived from. Returns: - tp.Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. + Dict: A dictionary of boolean trainability statuses. The dictionary is equal in structure to the input params dictionary. """ # Copy dictionary structure prior_container = deepcopy(params) @@ -307,12 +307,12 @@ def build_trainables_false(params: tp.Dict) -> tp.Dict: return prior_container -def stop_grad(param: tp.Dict, trainable: tp.Dict): +def stop_grad(param: Dict, trainable: Dict): """When taking a gradient, we want to stop the gradient from flowing through a parameter if it is not trainable. This is achieved using the model's dictionary of parameters and the corresponding trainability status.""" return jax.lax.cond(trainable, lambda x: x, jax.lax.stop_gradient, param) -def trainable_params(params: tp.Dict, trainables: tp.Dict) -> tp.Dict: +def trainable_params(params: Dict, trainables: Dict) -> Dict: """Stop the gradients flowing through parameters whose trainable status is False""" return jax.tree_util.tree_map( lambda param, trainable: stop_grad(param, trainable), params, trainables diff --git a/gpjax/types.py b/gpjax/types.py index afa7e1a9..d1871040 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -1,9 +1,9 @@ import jax.numpy as jnp -from chex import PRNGKey as PRNGKeyType from chex import dataclass from jaxtyping import f64 NoneType = type(None) +from chex import PRNGKey as PRNGKeyType @dataclass diff --git a/gpjax/utils.py b/gpjax/utils.py index dbfbc823..cf2cbf53 100644 --- a/gpjax/utils.py +++ b/gpjax/utils.py @@ -1,10 +1,8 @@ -import typing as tp from copy import deepcopy +from typing import Callable, Dict, Tuple import jax import jax.numpy as jnp -import jax.random as jr -from chex import PRNGKey from jaxtyping import f64 @@ -17,7 +15,7 @@ def I(n: int) -> f64["N N"]: return jnp.eye(n) -def concat_dictionaries(a: dict, b: dict) -> dict: +def concat_dictionaries(a: Dict, b: Dict) -> Dict: """ Append one dictionary below another. If duplicate keys exist, then the key-value pair of the second supplied dictionary will be used. @@ -25,7 +23,7 @@ def concat_dictionaries(a: dict, b: dict) -> dict: return {**a, **b} -def merge_dictionaries(base_dict: dict, in_dict: dict) -> dict: +def merge_dictionaries(base_dict: Dict, in_dict: Dict) -> Dict: """ This will return a complete dictionary based on the keys of the first matrix. If the same key should exist in the second matrix, then the key-value pair from the first dictionary will be overwritten. The purpose of this is that @@ -42,7 +40,7 @@ def merge_dictionaries(base_dict: dict, in_dict: dict) -> dict: return base_dict -def sort_dictionary(base_dict: dict) -> dict: +def sort_dictionary(base_dict: Dict) -> Dict: """ Sort a dictionary based on the dictionary's key values. @@ -52,7 +50,7 @@ def sort_dictionary(base_dict: dict) -> dict: return dict(sorted(base_dict.items())) -def as_constant(parameter_set: dict, params: list) -> tp.Tuple[dict, dict]: +def as_constant(parameter_set: Dict, params: list) -> Tuple[Dict, Dict]: base_params = deepcopy(parameter_set) sparams = {} for param in params: @@ -61,21 +59,21 @@ def as_constant(parameter_set: dict, params: list) -> tp.Tuple[dict, dict]: return base_params, sparams -def dict_array_coercion(params: tp.Dict) -> tp.Tuple[tp.Callable, tp.Callable]: +def dict_array_coercion(params: Dict) -> Tuple[Callable, Callable]: """Construct the logic required to map a dictionary of parameters to an array of parameters. The values of the dictionary can themselves be dictionaries; the function should work recursively. Args: - params (tp.Dict): The dictionary of parameters that we would like to map into an array. + params (Dict): The dictionary of parameters that we would like to map into an array. Returns: - tp.Tuple[tp.Callable, tp.Callable]: A pair of functions, the first of which maps a dictionary to an array, and the second of which maps an array to a dictionary. The remapped dictionary is equal in structure to the original dictionary. + Tuple[Callable, Callable]: A pair of functions, the first of which maps a dictionary to an array, and the second of which maps an array to a dictionary. The remapped dictionary is equal in structure to the original dictionary. """ flattened_pytree = jax.tree_util.tree_flatten(params) def dict_to_array(parameter_dict) -> jnp.DeviceArray: return jax.tree_util.tree_flatten(parameter_dict)[0] - def array_to_dict(parameter_array) -> tp.Dict: + def array_to_dict(parameter_array) -> Dict: return jax.tree_util.tree_unflatten(flattened_pytree[1], parameter_array) return dict_to_array, array_to_dict diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index 6e194811..4c7c855c 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -102,7 +102,7 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as: @@ -110,7 +110,7 @@ def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: N[f(t); μt + Ktz Kzz⁻¹ (μ - μz), Ktt - Ktz Kzz⁻¹ Kzt + Ktz Kzz⁻¹ S Kzz⁻¹ Kzt ]. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -190,7 +190,7 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs t. This is the integral q(f(t)) = ∫ p(f(t)|u) q(u) du, which can be computed in closed form as @@ -198,7 +198,7 @@ def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: N[f(t); μt + Ktz Lz⁻ᵀ μ, Ktt - Ktz Kzz⁻¹ Kzt + Ktz Lz⁻ᵀ S Lz⁻¹ Kzt]. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -302,11 +302,11 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -428,11 +428,11 @@ def prior_kl(self, params: Dict) -> f64["1"]: return qu.kl_divergence(pu) - def predict(self, params: dict) -> Callable[[f64["N D"]], dx.Distribution]: + def predict(self, params: Dict) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. @@ -525,11 +525,11 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ) def predict( - self, train_data: Dataset, params: dict + self, train_data: Dataset, params: Dict ) -> Callable[[f64["N D"]], dx.Distribution]: """Compute the predictive distribution of the GP at the test inputs. Args: - params (dict): The set of parameters that are to be used to parameterise our variational approximation and GP. + params (Dict): The set of parameters that are to be used to parameterise our variational approximation and GP. Returns: Callable[[f64["N D"]], dx.Distribution]: A function that accepts a set of test points and will return the predictive distribution at those points. """ diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index e24b6939..ffa09769 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -141,7 +141,7 @@ def __post_init__(self): def elbo( self, train_data: Dataset, transformations: Dict, negative: bool = False - ) -> Callable[[dict], f64["1"]]: + ) -> Callable[[Dict], f64["1"]]: """Compute the evidence lower bound under this model. In short, this requires evaluating the expectation of the model's log-likelihood under the variational approximation. To this, we sum the KL divergence from the variational posterior to the prior. When batching occurs, the result is scaled by the batch size relative to the full dataset size. Args: