
Commit

bump jaxtyping
thomaspinder committed Sep 8, 2022
1 parent 3b60094 commit 8ac2427
Showing 14 changed files with 106 additions and 74 deletions.
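The change is mechanical throughout the diff: every shape annotation written with jaxtyping's older dtype-indexed syntax, f64["N D"], is rewritten in the newer Float[Array, "N D"] form, where a dtype class is parameterised by the array type and a shape string. Below is a minimal before/after sketch of the pattern; the row_means function is a hypothetical illustration, not code from this repository, and the note that the f64 import is unavailable after the bump is an assumption based on this diff.

import jax.numpy as jnp
from jaxtyping import Array, Float

# Old style (before this commit), shown as a comment because the f64 import is
# assumed to be unavailable in the jaxtyping version being bumped to:
#     from jaxtyping import f64
#     def row_means(x: f64["N D"]) -> f64["N"]: ...


def row_means(x: Float[Array, "N D"]) -> Float[Array, "N"]:
    """Mean of each row of an N x D matrix, annotated in the new style."""
    return x.mean(axis=-1)


print(row_means(jnp.ones((3, 2))))  # [1. 1. 1.]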
4 changes: 2 additions & 2 deletions examples/classification.ipynb
@@ -27,7 +27,7 @@
"import distrax as dx\n",
"from gpjax.utils import I\n",
"import jax.scipy as jsp\n",
"from jaxtyping import f64\n",
"from jaxtyping import Float, Array\n",
"\n",
"key = jr.PRNGKey(123)"
]
@@ -287,7 +287,7 @@
"from gpjax.kernels import gram, cross_covariance\n",
"\n",
"\n",
"def predict(laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: f64[\"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n",
"def predict(laplace_at_data: dx.Distribution, train_data: Dataset, test_inputs: Float[Array, \"N D\"], jitter: int = 1e-6) -> dx.Distribution:\n",
" \"\"\"Compute the predictive distribution of the Laplace approximation at novel inputs.\n",
"\n",
" Args:\n",
8 changes: 4 additions & 4 deletions examples/kernels.ipynb
@@ -17,6 +17,8 @@
"metadata": {},
"outputs": [],
"source": [
"import gpjax as gpx\n",
"\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import matplotlib.pyplot as plt\n",
@@ -25,8 +27,7 @@
"import jax\n",
"from optax import adam\n",
"import distrax as dx\n",
"\n",
"import gpjax as gpx\n",
"from jaxtyping import Float, Array\n",
"\n",
"key = jr.PRNGKey(123)"
]
@@ -261,7 +262,6 @@
"outputs": [],
"source": [
"from chex import dataclass\n",
"from jaxtyping import f64\n",
"\n",
"\n",
"def angular_distance(x, y, c):\n",
@@ -275,7 +275,7 @@
" def __post_init__(self):\n",
" self.c = self.period / 2.0 # in [0, \\pi]\n",
"\n",
" def __call__(self, x: f64[\"1 D\"], y: f64[\"1 D\"], params: dict) -> f64[\"1\"]:\n",
" def __call__(self, x: Float[Array, \"1 D\"], y: Float[Array, \"1 D\"], params: dict) -> Float[Array, \"1\"]:\n",
" tau = params[\"tau\"]\n",
" t = angular_distance(x, y, self.c)\n",
" K = (1 + tau * t / self.c) * jnp.clip(1 - t / self.c, 0, jnp.inf) ** tau\n",
8 changes: 4 additions & 4 deletions gpjax/abstractions.py
@@ -7,7 +7,7 @@
from chex import dataclass
from jax import lax
from jax.experimental import host_callback
from jaxtyping import f64
from jaxtyping import Array, Float
from tqdm.auto import tqdm

from .parameters import trainable_params
@@ -17,7 +17,7 @@
@dataclass(frozen=True)
class InferenceState:
params: tp.Dict
history: f64["n_iters"]
history: Float[Array, "n_iters"]

def unpack(self):
return self.params, self.history
@@ -113,7 +113,7 @@ def fit(
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
InferenceState: An InferenceState object comprising the optimised parameters and training history respectively.
"""
opt_state = optax_optim.init(params)

@@ -161,7 +161,7 @@ def fit_batches(
n_iters (int, optional): The number of optimisation steps to run. Defaults to 100.
log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10.
Returns:
tp.Tuple[tp.Dict, f64["n_iters"]]: A tuple comprising optimised parameters and training history respectively.
InferenceState: An InferenceState object comprising the optimised parameters and training history respectively.
"""

opt_state = optax_optim.init(params)
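The docstring fixes above bring fit and fit_batches in line with what they already return: an InferenceState rather than a bare (params, history) tuple. A small self-contained sketch of that container follows, mirroring the dataclass at the top of this file; the example values are placeholders and the optimisation loop itself is elided.

import typing as tp

import jax.numpy as jnp
from chex import dataclass
from jaxtyping import Array, Float


@dataclass(frozen=True)
class InferenceState:
    params: tp.Dict
    history: Float[Array, "n_iters"]

    def unpack(self):
        return self.params, self.history


# Callers that previously destructured a tuple now unpack the state explicitly.
state = InferenceState(params={"lengthscale": jnp.array(1.0)}, history=jnp.zeros(100))
params, history = state.unpack()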
20 changes: 11 additions & 9 deletions gpjax/gps.py
@@ -7,7 +7,7 @@
import jax.random as jr
import jax.scipy as jsp
from chex import dataclass
from jaxtyping import f64
from jaxtyping import Array, Float

from .config import get_defaults
from .kernels import Kernel, cross_covariance, gram
@@ -74,15 +74,17 @@ def __rmul__(self, other: AbstractLikelihood):
"""Reimplement the multiplication operator to allow for order-invariant product of a likelihood and a prior i.e., likelihood * prior."""
return self.__mul__(other)

def predict(self, params: dict) -> tp.Callable[[f64["N D"]], dx.Distribution]:
def predict(
self, params: dict
) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
"""Compute the GP's prior mean and variance.
Args:
params (dict): The specific set of parameters for which the mean function should be defined for.
Returns:
tp.Callable[[Array], Array]: A mean function that accepts an input array for where the mean function should be evaluated at. The mean function's value at these points is then returned.
"""

def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution:
def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
t = test_inputs
n_test = t.shape[0]
μt = self.mean_function(t, params["mean_function"])
@@ -139,7 +141,7 @@ class ConjugatePosterior(AbstractPosterior):

def predict(
self, train_data: Dataset, params: dict
) -> tp.Callable[[f64["N D"]], dx.Distribution]:
) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
"""Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density.
Args:
@@ -166,7 +168,7 @@ def predict(
# w = L⁻¹ (y - μx)
w = jsp.linalg.solve_triangular(L, y - μx, lower=True)

def predict(test_inputs: f64["N D"]) -> dx.Distribution:
def predict(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
t = test_inputs
n_test = t.shape[0]
μt = self.prior.mean_function(t, params["mean_function"])
@@ -195,7 +197,7 @@ def marginal_log_likelihood(
transformations: Dict,
priors: dict = None,
negative: bool = False,
) -> tp.Callable[[dict], f64["1"]]:
) -> tp.Callable[[dict], Float[Array, "1"]]:
"""Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here enables exact estimation of the Gaussian process' latent function values.
Args:
@@ -261,7 +263,7 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:

def predict(
self, train_data: Dataset, params: dict
) -> tp.Callable[[f64["N D"]], dx.Distribution]:
) -> tp.Callable[[Float[Array, "N D"]], dx.Distribution]:
"""Conditional on a set of training data, compute the GP's posterior predictive distribution for a given set of parameters. The returned function can be evaluated at a set of test inputs to compute the corresponding predictive density. Note, to gain predictions on the scale of the original data, the returned distribution will need to be transformed through the likelihood function's inverse link function.
Args:
@@ -277,7 +279,7 @@ def predict(
Kxx += I(n) * self.jitter
Lx = jnp.linalg.cholesky(Kxx)

def predict_fn(test_inputs: f64["N D"]) -> dx.Distribution:
def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
t = test_inputs
n_test = t.shape[0]
Ktx = cross_covariance(self.prior.kernel, t, x, params["kernel"])
@@ -306,7 +308,7 @@ def marginal_log_likelihood(
transformations: Dict,
priors: dict = None,
negative: bool = False,
) -> tp.Callable[[dict], f64["1"]]:
) -> tp.Callable[[dict], Float[Array, "1"]]:
"""Compute the marginal log-likelihood function of the Gaussian process. The returned function can then be used for gradient based optimisation of the model's parameters or for model comparison. The implementation given here is general and will work for any likelihood support by GPJax.
Args:
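Both predict methods in this file keep the curried shape visible in the hunks above: they bind the data and parameters once and return a function from a Float[Array, "N D"] test matrix to a distrax distribution. The toy stand-in below illustrates only that calling convention — the zero-mean, diagonal-covariance distribution is purely illustrative and is not GPJax's actual prior or posterior.

import distrax as dx
import jax.numpy as jnp
from jaxtyping import Array, Float


def predict(params: dict):
    """Bind parameters once, then evaluate at arbitrary test inputs."""

    def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.Distribution:
        n_test = test_inputs.shape[0]
        loc = jnp.zeros(n_test)
        scale = jnp.sqrt(params["variance"]) * jnp.ones(n_test)
        return dx.MultivariateNormalDiag(loc=loc, scale_diag=scale)

    return predict_fn


xtest = jnp.linspace(-1.0, 1.0, 10).reshape(-1, 1)
dist = predict({"variance": 1.0})(xtest)
print(dist.mean().shape)  # (10,)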
58 changes: 41 additions & 17 deletions gpjax/kernels.py
@@ -4,7 +4,7 @@
import jax.numpy as jnp
from chex import dataclass
from jax import vmap
from jaxtyping import f64
from jaxtyping import Array, Float


##########################################
@@ -23,7 +23,9 @@ def __post_init__(self):
self.ndims = 1 if not self.active_dims else len(self.active_dims)

@abc.abstractmethod
def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs.
Args:
x (jnp.DeviceArray): The left hand argument of the kernel function's call.
@@ -34,7 +36,7 @@ def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
"""
raise NotImplementedError

def slice_input(self, x: f64["N D"]) -> f64["N Q"]:
def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]:
"""Select the relevant columns of the supplied matrix to be used within the kernel's evaluation.
Args:
x (Array): The matrix or vector that is to be sliced.
@@ -101,7 +103,9 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
"""A template dictionary of the kernel's parameter set."""
return [kernel._initialise_params(key) for kernel in self.kernel_set]

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
return self.combination_fn(
jnp.stack([k(x, y, p) for k, p in zip(self.kernel_set, params)])
)
@@ -135,7 +139,9 @@ class RBF(Kernel):
def __post_init__(self):
self.ndims = 1 if not self.active_dims else len(self.active_dims)

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma`
.. math::
@@ -170,7 +176,9 @@ class Matern12(Kernel):
def __post_init__(self):
self.ndims = 1 if not self.active_dims else len(self.active_dims)

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma`
.. math::
@@ -204,7 +212,9 @@ class Matern32(Kernel):
def __post_init__(self):
self.ndims = 1 if not self.active_dims else len(self.active_dims)

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma`
.. math::
@@ -244,7 +254,9 @@ class Matern52(Kernel):
def __post_init__(self):
self.ndims = 1 if not self.active_dims else len(self.active_dims)

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with lengthscale parameter :math:`\ell` and variance :math:`\sigma`
.. math::
@@ -286,7 +298,9 @@ def __post_init__(self):
self.ndims = 1 if not self.active_dims else len(self.active_dims)
self.name = f"Polynomial Degree: {self.degree}"

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\alpha` and variance :math:`\sigma` through
.. math::
@@ -317,7 +331,7 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
##########################################
@dataclass
class _EigenKernel:
laplacian: f64["N N"]
laplacian: Float[Array, "N N"]


@dataclass
@@ -330,7 +344,9 @@ def __post_init__(self):
self.evals = evals.reshape(-1, 1)
self.num_vertex = self.laplacian.shape[0]

def __call__(self, x: f64["1 D"], y: f64["1 D"], params: dict) -> f64["1"]:
def __call__(
self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: dict
) -> Float[Array, "1"]:
"""Evaluate the graph kernel on a pair of vertices v_i, v_j.
Args:
@@ -361,17 +377,23 @@ def _initialise_params(self, key: jnp.DeviceArray) -> Dict:
}


def squared_distance(x: f64["1 D"], y: f64["1 D"]) -> f64["1"]:
def squared_distance(
x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
"""Compute the squared distance between a pair of inputs."""
return jnp.sum((x - y) ** 2)


def euclidean_distance(x: f64["1 D"], y: f64["1 D"]) -> f64["1"]:
def euclidean_distance(
x: Float[Array, "1 D"], y: Float[Array, "1 D"]
) -> Float[Array, "1"]:
"""Compute the l1 norm between a pair of inputs."""
return jnp.sqrt(jnp.maximum(jnp.sum((x - y) ** 2), 1e-36))


def gram(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]:
def gram(
kernel: Kernel, inputs: Float[Array, "N D"], params: dict
) -> Float[Array, "N N"]:
"""For a given kernel, compute the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`.
Args:
@@ -386,8 +408,8 @@ def gram(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]:


def cross_covariance(
kernel: Kernel, x: f64["N D"], y: f64["M D"], params: dict
) -> f64["N M"]:
kernel: Kernel, x: Float[Array, "N D"], y: Float[Array, "M D"], params: dict
) -> Float[Array, "N M"]:
"""For a given kernel, compute the :math:`m \times n` gram matrix on an a pair of input matrices with shape :math:`m \times d` and :math:`n \times d` for :math:`d\geq 1`.
Args:
@@ -402,7 +424,9 @@ def cross_covariance(
return vmap(lambda x1: vmap(lambda y1: kernel(x1, y1, params))(y))(x)


def diagonal(kernel: Kernel, inputs: f64["N D"], params: dict) -> f64["N N"]:
def diagonal(
kernel: Kernel, inputs: Float[Array, "N D"], params: dict
) -> Float[Array, "N N"]:
"""For a given kernel, compute the elementwise diagonal of the :math:`n \times n` gram matrix on an input matrix of shape :math:`n \times d` for :math:`d\geq 1`.
Args:
kernel (Kernel): The kernel for which the variance vector should be computed for.
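With the new annotations, gram and cross_covariance read directly as shape contracts: "N D" inputs in, "N N" or "N M" covariance blocks out. A hedged usage sketch follows; it assumes RBF() constructs with its defaults and that _initialise_params (shown above) yields the kernel's parameter dictionary.

import jax.numpy as jnp
import jax.random as jr
from gpjax.kernels import RBF, cross_covariance, gram

key = jr.PRNGKey(123)
kernel = RBF()
params = kernel._initialise_params(key)

x = jnp.linspace(-3.0, 3.0, 50).reshape(-1, 1)  # Float[Array, "N D"], N=50, D=1
y = jnp.linspace(-1.0, 1.0, 20).reshape(-1, 1)  # Float[Array, "M D"], M=20, D=1

Kxx = gram(kernel, x, params)                 # Float[Array, "N N"]: (50, 50)
Kxy = cross_covariance(kernel, x, y, params)  # Float[Array, "N M"]: (50, 20)
print(Kxx.shape, Kxy.shape)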
6 changes: 4 additions & 2 deletions gpjax/likelihoods.py
@@ -5,7 +5,7 @@
import jax.numpy as jnp
import jax.scipy as jsp
from chex import dataclass
from jaxtyping import f64
from jaxtyping import Array, Float

from .utils import I

@@ -107,7 +107,9 @@ def predictive_moment_fn(self) -> Callable:
Callable: A callable object that accepts a mean and variance term from which the predictive random variable is computed.
"""

def moment_fn(mean: f64["N D"], variance: f64["N D"], params: Dict):
def moment_fn(
mean: Float[Array, "N D"], variance: Float[Array, "N D"], params: Dict
):
rv = self.link_function(mean / jnp.sqrt(1 + variance), params)
return rv

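The moment_fn hunk is the usual moment-matching shortcut: the latent mean is scaled by 1 / sqrt(1 + variance) before being pushed through the link function. A one-line numerical sketch of that step, assuming a probit link (the link function itself is not shown in this hunk):

import jax.numpy as jnp
import jax.scipy as jsp

mean, variance = jnp.array([0.3]), jnp.array([0.5])
# Scale the latent mean, then apply the assumed probit link (standard normal CDF).
predictive_prob = jsp.stats.norm.cdf(mean / jnp.sqrt(1.0 + variance))
print(predictive_prob)  # approximately [0.597]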
8 changes: 4 additions & 4 deletions gpjax/mean_functions.py
@@ -3,7 +3,7 @@

import jax.numpy as jnp
from chex import dataclass
from jaxtyping import f64
from jaxtyping import Array, Float


@dataclass(repr=False)
@@ -14,7 +14,7 @@ class AbstractMeanFunction:
name: Optional[str] = "Mean function"

@abc.abstractmethod
def __call__(self, x: f64["N D"]) -> f64["N Q"]:
def __call__(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]:
"""Evaluate the mean function at the given points. This method is required for all subclasses.
Args:
@@ -44,7 +44,7 @@ class Zero(AbstractMeanFunction):
output_dim: Optional[int] = 1
name: Optional[str] = "Zero mean function"

def __call__(self, x: f64["N D"], params: dict) -> f64["N Q"]:
def __call__(self, x: Float[Array, "N D"], params: dict) -> Float[Array, "N Q"]:
"""Evaluate the mean function at the given points.
Args:
@@ -72,7 +72,7 @@ class Constant(AbstractMeanFunction):
output_dim: Optional[int] = 1
name: Optional[str] = "Constant mean function"

def __call__(self, x: f64["N D"], params: Dict) -> f64["N Q"]:
def __call__(self, x: Float[Array, "N D"], params: Dict) -> Float[Array, "N Q"]:
"""Evaluate the mean function at the given points.
Args: