
Merge pull request #370 from JaxGaussianProcesses/cola
CoLA integration
daniel-dodd authored Sep 6, 2023
2 parents 4f0e31c + da3cd57 commit c61ff2f
Showing 49 changed files with 4,428 additions and 5,908 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
@@ -19,7 +19,7 @@ jobs:
     strategy:
       matrix:
         os: ["ubuntu-latest"]
-        python-version: ["3.8"]
+        python-version: ["3.10"]
 
     steps:
       # Grab the latest commit from the branch
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
@@ -13,7 +13,7 @@ jobs:
       matrix:
         # Select the Python versions to test against
         os: ["ubuntu-latest", "macos-latest"]
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]
       fail-fast: true
     steps:
       - name: Check out the code
2 changes: 1 addition & 1 deletion .github/workflows/test_docs.yml
@@ -17,7 +17,7 @@ jobs:
     strategy:
       matrix:
         os: ["ubuntu-latest"]
-        python-version: ["3.8"]
+        python-version: ["3.10"]
 
     steps:
       # Grab the latest commit from the branch
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -13,7 +13,7 @@ jobs:
       matrix:
         # Select the Python versions to test against
         os: ["ubuntu-latest", "macos-latest"]
-        python-version: ["3.8", "3.9", "3.10", "3.11"]
+        python-version: ["3.10", "3.11"]
       fail-fast: true
     steps:
       - name: Check out the code
779 changes: 779 additions & 0 deletions docs/examples/classification.ipynb

Large diffs are not rendered by default.

15 changes: 8 additions & 7 deletions docs/examples/classification.py
@@ -207,13 +207,17 @@
 # datapoints below.
 
 # %%
+import cola
+from gpjax.lower_cholesky import lower_cholesky
+
 gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
 jitter = 1e-6
 
 # Compute (latent) function value map estimates at training points:
 Kxx = opt_posterior.prior.kernel.gram(x)
 Kxx += identity_matrix(D.n) * jitter
-Lx = Kxx.to_root()
+Kxx = cola.PSD(Kxx)
+Lx = lower_cholesky(Kxx)
 f_hat = Lx @ opt_posterior.latent
 
 # Negative Hessian, H = -∇²p_tilde(y|f):
@@ -250,16 +254,13 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
     Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
     Kxx = opt_posterior.prior.kernel.gram(x)
     Kxx += identity_matrix(D.n) * jitter
-    Lx = Kxx.to_root()
-
-    # Lx⁻¹ Kxt
-    Lx_inv_Ktx = Lx.solve(Kxt)
+    Kxx = cola.PSD(Kxx)
 
     # Kxx⁻¹ Kxt
-    Kxx_inv_Ktx = Lx.T.solve(Lx_inv_Ktx)
+    Kxx_inv_Kxt = cola.solve(Kxx, Kxt)
 
     # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
-    laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Ktx.T, H_inv), Kxx_inv_Ktx)
+    laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)
 
     mean = map_latent_dist.mean()
     covariance = map_latent_dist.covariance() + laplace_cov_term
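For readers skimming the two hunks above, here is a minimal, self-contained sketch of the pattern the new code adopts: annotate the covariance as PSD, take its lower Cholesky root with gpjax.lower_cholesky, and solve against it with cola.solve. The inputs X, Xt and the jitter value below are toy stand-ins invented for this sketch, not the notebook's kernel Gram matrices.

import cola
import jax.numpy as jnp
import jax.random as jr
from cola.ops import Dense

from gpjax.lower_cholesky import lower_cholesky

# Toy stand-ins for the notebook's training/test inputs (illustrative only).
X = jr.normal(jr.PRNGKey(0), (10, 3))
Xt = jr.normal(jr.PRNGKey(1), (5, 3))
jitter = 1e-6

Kxx = X @ X.T + jitter * jnp.eye(10)  # a dense PSD matrix
Kxt = X @ Xt.T                        # a cross-covariance block

# Tag the operator as PSD so CoLA can choose Cholesky-based routines.
Kxx = cola.PSD(Dense(Kxx))

# Lower Cholesky root L with L Lᵀ = Kxx (replaces the old `Kxx.to_root()`).
Lx = lower_cholesky(Kxx)
f_sample = Lx @ jr.normal(jr.PRNGKey(2), (10,))

# Kxx⁻¹ Kxt in a single call (replaces the pair of triangular solves).
Kxx_inv_Kxt = cola.solve(Kxx, Kxt)

The PSD annotation is what lets CoLA pick Cholesky-based algorithms for the solve, in place of the explicit root/triangular-solve pair used before this PR.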
2 changes: 1 addition & 1 deletion docs/examples/constructing_new_kernels.py
@@ -108,7 +108,7 @@
 # like our RBF kernel to act on the first, second and fourth dimensions.
 
 # %%
-slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3], lengthscale = jnp.ones((3,)))
+slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3], lengthscale=jnp.ones((3,)))
 
 # %% [markdown]
 #
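As a quick illustration of the active_dims behaviour referenced in this hunk, the sketch below evaluates the sliced RBF kernel on toy four-dimensional inputs; only dimensions 0, 1 and 3 enter the distance computation. The random inputs are invented for the example, and kernel.gram is used as in the other examples touched by this PR.

import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr

# Toy inputs: 5 points in 4 dimensions (illustrative only).
X = jr.normal(jr.PRNGKey(0), (5, 4))

# RBF acting on dimensions 0, 1 and 3, so the lengthscale has three entries.
slice_kernel = gpx.kernels.RBF(active_dims=[0, 1, 3], lengthscale=jnp.ones((3,)))

# Gram matrix over the selected dimensions; shape (5, 5).
K = slice_kernel.gram(X)
print(K.shape)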
2 changes: 1 addition & 1 deletion gpjax/__init__.py
@@ -29,8 +29,8 @@
     RFF,
     AbstractKernel,
     BasisFunctionComputation,
-    ConstantDiagonalKernelComputation,
     CatKernel,
+    ConstantDiagonalKernelComputation,
     DenseKernelComputation,
     DiagonalKernelComputation,
     EigenKernelComputation,
44 changes: 23 additions & 21 deletions gpjax/citation.py
@@ -2,13 +2,13 @@
     dataclass,
     fields,
 )
+from functools import singledispatch
 
 from beartype.typing import (
     Dict,
     Union,
 )
 from jaxlib.xla_extension import PjitFunction
-from plum import dispatch
 
 from gpjax.kernels import (
     RFF,
@@ -26,8 +26,6 @@
     NonConjugateMLL,
 )
 
-MaternKernels = Union[Matern12, Matern32, Matern52]
-MLLs = Union[ConjugateMLL, NonConjugateMLL, LogPosteriorDensity]
 CitationType = Union[str, Dict[str, str]]
 
 
@@ -89,24 +87,26 @@ class BookCitation(AbstractCitation):
 ####################
 # Default citation
 ####################
-@dispatch
-def cite(tree) -> NullCitation:
+@singledispatch
+def cite(tree) -> AbstractCitation:
     return NullCitation()
 
 
 ####################
 # Default citation
 ####################
-@dispatch
-def cite(tree: PjitFunction) -> JittedFnCitation:
+@cite.register(PjitFunction)
+def _(tree):
     return JittedFnCitation()
 
 
 ####################
 # Kernel citations
 ####################
-@dispatch
-def cite(tree: MaternKernels) -> PhDThesisCitation:
+@cite.register(Matern12)
+@cite.register(Matern32)
+@cite.register(Matern52)
+def _(tree) -> PhDThesisCitation:
     citation = PhDThesisCitation(
         citation_key="matern1960SpatialV",
         authors="Bertil Matérn",
@@ -121,8 +121,8 @@ def cite(tree: MaternKernels) -> PhDThesisCitation:
     return citation
 
 
-@dispatch
-def cite(tree: ArcCosine) -> PaperCitation:
+@cite.register(ArcCosine)
+def _(_) -> PaperCitation:
     return PaperCitation(
         citation_key="cho2009kernel",
         authors="Cho, Youngmin and Saul, Lawrence",
@@ -132,8 +132,8 @@ def cite(tree: ArcCosine) -> PaperCitation:
     )
 
 
-@dispatch
-def cite(tree: GraphKernel) -> PaperCitation:
+@cite.register(GraphKernel)
+def _(tree) -> PaperCitation:
     return PaperCitation(
         citation_key="borovitskiy2021matern",
         title="Matérn Gaussian Processes on Graphs",
@@ -146,8 +146,8 @@ def cite(tree: GraphKernel) -> PaperCitation:
     )
 
 
-@dispatch
-def cite(tree: RFF) -> PaperCitation:
+@cite.register(RFF)
+def _(tree) -> PaperCitation:
     return PaperCitation(
         citation_key="rahimi2007random",
         authors="Rahimi, Ali and Recht, Benjamin",
@@ -161,8 +161,10 @@ def cite(tree: RFF) -> PaperCitation:
 ####################
 # Objective citations
 ####################
-@dispatch
-def cite(tree: MLLs) -> BookCitation:
+@cite.register(ConjugateMLL)
+@cite.register(NonConjugateMLL)
+@cite.register(LogPosteriorDensity)
+def _(tree) -> BookCitation:
     return BookCitation(
         citation_key="rasmussen2006gaussian",
         title="Gaussian Processes for Machine Learning",
@@ -173,8 +175,8 @@ def cite(tree: MLLs) -> BookCitation:
     )
 
 
-@dispatch
-def cite(tree: CollapsedELBO) -> PaperCitation:
+@cite.register(CollapsedELBO)
+def _(tree) -> PaperCitation:
     return PaperCitation(
         citation_key="titsias2009variational",
         title="Variational learning of inducing variables in sparse Gaussian processes",
@@ -184,8 +186,8 @@ def cite(tree: CollapsedELBO) -> PaperCitation:
     )
 
 
-@dispatch
-def cite(tree: ELBO) -> PaperCitation:
+@cite.register(ELBO)
+def _(tree) -> PaperCitation:
     return PaperCitation(
         citation_key="hensman2013gaussian",
         title="Gaussian Processes for Big Data",
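Since this file replaces plum's @dispatch with the standard library's functools.singledispatch, a small standalone sketch of the registration pattern may help. The classes below are hypothetical stand-ins, not GPJax types; the point is only that several register calls can be stacked on one implementation, which is how the new cite handles the Matérn kernels and the marginal log-likelihood objectives.

from functools import singledispatch


class Matern12Like: ...
class Matern32Like: ...
class RBFLike: ...


@singledispatch
def cite(obj) -> str:
    # Fallback for unregistered types.
    return "No citation available."


@cite.register(Matern12Like)
@cite.register(Matern32Like)
def _(obj) -> str:
    # One implementation serves every registered Matérn-like type.
    return "Matérn (1960), 'Spatial Variation'."


@cite.register(RBFLike)
def _(obj) -> str:
    return "Rasmussen & Williams (2006), 'Gaussian Processes for Machine Learning'."


print(cite(Matern32Like()))  # Matérn thesis entry
print(cite(object()))        # falls through to the default

Unregistered types fall through to the @singledispatch default, mirroring the NullCitation fallback above.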
58 changes: 34 additions & 24 deletions gpjax/gaussian_distribution.py
@@ -19,6 +19,11 @@
     Optional,
     Tuple,
 )
+import cola
+from cola.ops import (
+    Dense,
+    Identity,
+)
 from jax import vmap
 import jax.numpy as jnp
 import jax.random as jr
@@ -28,10 +33,7 @@
 )
 import tensorflow_probability.substrates.jax as tfp
 
-from gpjax.linops import (
-    IdentityLinearOperator,
-    LinearOperator,
-)
+from gpjax.lower_cholesky import lower_cholesky
 from gpjax.typing import (
     Array,
     KeyArray,
@@ -49,15 +51,15 @@ def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
     if loc is not None and loc.ndim < 1:
         raise ValueError("The parameter `loc` must have at least one dimension.")
 
-    if scale is not None and scale.ndim < 2:
+    if scale is not None and len(scale.shape) < 2:  # scale.ndim < 2:
         raise ValueError(
             "The `scale` must have at least two dimensions, but "
             f"`scale.shape = {scale.shape}`."
         )
 
-    if scale is not None and not isinstance(scale, LinearOperator):
+    if scale is not None and not isinstance(scale, cola.LinearOperator):
         raise ValueError(
-            f"scale must be a LinearOperator or a JAX array, but got {type(scale)}"
+            f"The `scale` must be a cola.LinearOperator but got {type(scale)}"
         )
 
     if scale is not None and (scale.shape[-1] != scale.shape[-2]):
@@ -79,7 +81,7 @@ class GaussianDistribution(tfd.Distribution):
     Args:
         loc (Optional[Float[Array, " N"]]): The mean of the distribution. Defaults to None.
-        scale (Optional[LinearOperator]): The scale matrix of the distribution. Defaults to None.
+        scale (Optional[cola.LinearOperator]): The scale matrix of the distribution. Defaults to None.
     Returns
     -------
@@ -94,7 +96,7 @@ class GaussianDistribution(tfd.Distribution):
     def __init__(
         self,
         loc: Optional[Float[Array, " N"]] = None,
-        scale: Optional[LinearOperator] = None,
+        scale: Optional[cola.LinearOperator] = None,
     ) -> None:
         r"""Initialises the distribution."""
         _check_loc_scale(loc, scale)
@@ -112,10 +114,10 @@ def __init__(

         # If not specified, set the scale to the identity matrix.
         if scale is None:
-            scale = IdentityLinearOperator(num_dims)
+            scale = Identity(shape=(num_dims, num_dims), dtype=loc.dtype)
 
         self.loc = loc
-        self.scale = scale
+        self.scale = cola.PSD(scale)
 
     def mean(self) -> Float[Array, " N"]:
         r"""Calculates the mean."""
@@ -135,11 +137,11 @@ def covariance(self) -> Float[Array, "N N"]:

     def variance(self) -> Float[Array, " N"]:
         r"""Calculates the variance."""
-        return self.scale.diagonal()
+        return cola.diag(self.scale)
 
     def stddev(self) -> Float[Array, " N"]:
         r"""Calculates the standard deviation."""
-        return jnp.sqrt(self.scale.diagonal())
+        return jnp.sqrt(cola.diag(self.scale))
 
     @property
     def event_shape(self) -> Tuple:
@@ -149,7 +151,10 @@ def event_shape(self) -> Tuple:
     def entropy(self) -> ScalarFloat:
         r"""Calculates the entropy of the distribution."""
         return 0.5 * (
-            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + self.scale.log_det()
+            self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+            + cola.logdet(
+                self.scale, method="dense"
+            )  # <--- Seems to be an issue with CoLA!
         )
 
     def log_prob(
@@ -168,20 +173,23 @@ def log_prob(
         mu = self.loc
         sigma = self.scale
         n = mu.shape[-1]
 
         if mask is not None:
             y = jnp.where(mask, 0.0, y)
             mu = jnp.where(mask, 0.0, mu)
-            sigma_masked = jnp.where(mask[None] + mask[:, None], 0.0, sigma.matrix)
-            sigma = sigma.replace(
-                matrix=jnp.where(jnp.diag(mask), 1 / (2 * jnp.pi), sigma_masked)
+            sigma_masked = jnp.where(mask[None] + mask[:, None], 0.0, sigma.to_dense())
+            sigma = cola.PSD(
+                Dense(jnp.where(jnp.diag(mask), 1 / (2 * jnp.pi), sigma_masked))
             )
 
         # diff, y - µ
         diff = y - mu
 
         # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
         return -0.5 * (
-            n * jnp.log(2.0 * jnp.pi) + sigma.log_det() + diff.T @ sigma.solve(diff)
+            n * jnp.log(2.0 * jnp.pi)
+            + cola.logdet(sigma, method="dense")  # <--- Seems to be an issue with CoLA!
+            + diff.T @ cola.solve(sigma, diff)
        )
 
     def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -195,7 +203,7 @@ def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
             Float[Array, "n N"]: The samples.
         """
         # Obtain covariance root.
-        sqrt = self.scale.to_root()
+        sqrt = lower_cholesky(self.scale)
 
         # Gather n samples from standard normal distribution Z = [z₁, ..., zₙ]ᵀ.
         Z = jr.normal(key, shape=(n, *self.event_shape))
@@ -263,24 +271,26 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
     sigma_p = p.scale
 
     # Find covariance roots.
-    sqrt_p = sigma_p.to_root()
-    sqrt_q = sigma_q.to_root()
+    sqrt_p = lower_cholesky(sigma_p)
+    sqrt_q = lower_cholesky(sigma_q)
 
     # diff, μp - μq
     diff = mu_p - mu_q
 
     # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
     trace = _frobenius_norm_squared(
-        sqrt_p.solve(sqrt_q.to_dense())
+        cola.solve(sqrt_p, sqrt_q.to_dense())
     )  # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
     # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
     mahalanobis = jnp.sum(
-        jnp.square(sqrt_p.solve(diff))
+        jnp.square(cola.solve(sqrt_p, diff))
     )  # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s.
 
     # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
-    return (mahalanobis - n_dim - sigma_q.log_det() + sigma_p.log_det() + trace) / 2.0
+    return (
+        mahalanobis - n_dim - cola.logdet(sigma_q) + cola.logdet(sigma_p) + trace
+    ) / 2.0
 
 
 __all__ = [
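To make the new CoLA-backed API concrete, here is a hedged, self-contained sketch that builds a GaussianDistribution from a toy dense covariance and re-computes its log-density with the same CoLA primitives used in the hunks above. The data, sizes and jitter are invented for the example.

import cola
import jax.numpy as jnp
import jax.random as jr
from cola.ops import Dense

from gpjax.gaussian_distribution import GaussianDistribution

n = 4
A = jr.normal(jr.PRNGKey(0), (n, n))
cov = A @ A.T + 1e-6 * jnp.eye(n)  # toy covariance matrix
loc = jnp.zeros(n)

dist = GaussianDistribution(loc=loc, scale=cola.PSD(Dense(cov)))

y = jr.normal(jr.PRNGKey(1), (n,))

# Manual evaluation of the same density via CoLA primitives,
# -1/2 [ n log(2π) + log|Σ| + (y - µ)ᵀ Σ⁻¹ (y - µ) ].
sigma = dist.scale
diff = y - loc
manual = -0.5 * (
    n * jnp.log(2.0 * jnp.pi)
    + cola.logdet(sigma, method="dense")
    + diff.T @ cola.solve(sigma, diff)
)

print(dist.log_prob(y), manual)  # the two values should agree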