Skip to content

Commit

Permalink
Merge pull request #405 from JaxGaussianProcesses/bump_cola
Browse files Browse the repository at this point in the history
bump cola to v0.0.5
  • Loading branch information
thomaspinder authored Oct 28, 2023
2 parents 70838a8 + 392b3da commit ae9cfa3
Show file tree
Hide file tree
Showing 15 changed files with 1,802 additions and 1,811 deletions.
4 changes: 2 additions & 2 deletions docs/examples/collapsed_vi.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@
full_rank_model = gpx.Prior(mean_function=gpx.Zero(), kernel=gpx.RBF()) * gpx.Gaussian(
num_datapoints=D.n
)
negative_mll = jit(gpx.ConjugateMLL(negative=True))
negative_mll = jit(gpx.ConjugateMLL(negative=True).step)
# %timeit negative_mll(full_rank_model, D).block_until_ready()

# %%
negative_elbo = jit(gpx.CollapsedELBO(negative=True))
negative_elbo = jit(gpx.CollapsedELBO(negative=True).step)
# %timeit negative_elbo(q, D).block_until_ready()

# %% [markdown]
Expand Down
8 changes: 5 additions & 3 deletions docs/examples/graph_kernels.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
# %% [markdown]
# # Graph Kernels
#
Expand Down Expand Up @@ -119,7 +120,8 @@
cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
)
sm.set_array([])
cbar = plt.colorbar(sm)
ax = plt.gca()
cbar = plt.colorbar(sm, ax=ax)

# %% [markdown]
#
Expand Down Expand Up @@ -201,8 +203,8 @@
sm = plt.cm.ScalarMappable(
cmap=plt.cm.inferno, norm=plt.Normalize(vmin=vmin, vmax=vmax)
)
sm.set_array([])
cbar = plt.colorbar(sm)
ax = plt.gca()
cbar = plt.colorbar(sm, ax=ax)

# %% [markdown]
#
Expand Down
35 changes: 15 additions & 20 deletions gpjax/citation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,15 @@
NonConjugateMLL,
)

CitationType = Union[str, Dict[str, str]]
CitationType = Union[None, str, Dict[str, str]]


@dataclass(repr=False)
class AbstractCitation:
citation_key: str = None
authors: str = None
title: str = None
year: str = None
citation_key: Union[str, None] = None
authors: Union[str, None] = None
title: Union[str, None] = None
year: Union[str, None] = None

def as_str(self) -> str:
citation_str = f"@{self.citation_type}{{{self.citation_key},"
Expand All @@ -64,29 +64,24 @@ def __str__(self) -> str:
)


class JittedFnCitation(AbstractCitation):
def __str__(self) -> str:
return "Citation not available for jitted objects."


@dataclass
class PhDThesisCitation(AbstractCitation):
school: str = None
institution: str = None
citation_type: str = "phdthesis"
school: Union[str, None] = None
institution: Union[str, None] = None
citation_type: CitationType = "phdthesis"


@dataclass
class PaperCitation(AbstractCitation):
booktitle: str = None
citation_type: str = "inproceedings"
booktitle: Union[str, None] = None
citation_type: CitationType = "inproceedings"


@dataclass
class BookCitation(AbstractCitation):
publisher: str = None
volume: str = None
citation_type: str = "book"
publisher: Union[str, None] = None
volume: Union[str, None] = None
citation_type: CitationType = "book"


####################
Expand All @@ -101,8 +96,8 @@ def cite(tree) -> AbstractCitation:
# Default citation
####################
@cite.register(PjitFunction)
def _(tree):
return JittedFnCitation()
def _(tree) -> None:
raise RuntimeError("Citation not available for jitted objects.")


####################
Expand Down
31 changes: 17 additions & 14 deletions gpjax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from cola.ops import (
Dense,
Identity,
LinearOperator,
)
from jax import vmap
import jax.numpy as jnp
Expand All @@ -45,6 +46,8 @@

tfd = tfp.distributions

from cola.linalg.decompositions.decompositions import Cholesky


def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
r"""Checks that the inputs are correct."""
Expand All @@ -60,9 +63,9 @@ def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
f"`scale.shape = {scale.shape}`."
)

if scale is not None and not isinstance(scale, cola.LinearOperator):
if scale is not None and not isinstance(scale, LinearOperator):
raise ValueError(
f"The `scale` must be a cola.LinearOperator but got {type(scale)}"
f"The `scale` must be a CoLA LinearOperator but got {type(scale)}"
)

if scale is not None and (scale.shape[-1] != scale.shape[-2]):
Expand All @@ -84,7 +87,7 @@ class GaussianDistribution(tfd.Distribution):
Args:
loc (Optional[Float[Array, " N"]]): The mean of the distribution. Defaults to None.
scale (Optional[cola.LinearOperator]): The scale matrix of the distribution. Defaults to None.
scale (Optional[LinearOperator]): The scale matrix of the distribution. Defaults to None.
Returns
-------
Expand All @@ -99,7 +102,7 @@ class GaussianDistribution(tfd.Distribution):
def __init__(
self,
loc: Optional[Float[Array, " N"]] = None,
scale: Optional[cola.LinearOperator] = None,
scale: Optional[LinearOperator] = None,
) -> None:
r"""Initialises the distribution."""
_check_loc_scale(loc, scale)
Expand Down Expand Up @@ -155,9 +158,7 @@ def entropy(self) -> ScalarFloat:
r"""Calculates the entropy of the distribution."""
return 0.5 * (
self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
+ cola.logdet(
self.scale, method="dense"
) # <--- Seems to be an issue with CoLA!
+ cola.logdet(self.scale, Cholesky(), Cholesky())
)

def log_prob(
Expand Down Expand Up @@ -191,8 +192,8 @@ def log_prob(
# compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
return -0.5 * (
n * jnp.log(2.0 * jnp.pi)
+ cola.logdet(sigma, method="dense") # <--- Seems to be an issue with CoLA!
+ diff.T @ cola.solve(sigma, diff)
+ cola.logdet(sigma, Cholesky(), Cholesky())
+ diff.T @ cola.solve(sigma, diff, Cholesky())
)

def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
Expand Down Expand Up @@ -347,17 +348,19 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl

# trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[Lp⁻¹Lq])²
trace = _frobenius_norm_squared(
cola.solve(sqrt_p, sqrt_q.to_dense())
cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky())
) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.

# Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
mahalanobis = jnp.sum(
jnp.square(cola.solve(sqrt_p, diff))
) # TODO: Need to improve this. Perhaps add a Mahalanobis method to ``LinearOperator``s.
mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky())))

# KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
return (
mahalanobis - n_dim - cola.logdet(sigma_q) + cola.logdet(sigma_p) + trace
mahalanobis
- n_dim
- cola.logdet(sigma_q, Cholesky(), Cholesky())
+ cola.logdet(sigma_p, Cholesky(), Cholesky())
+ trace
) / 2.0


Expand Down
49 changes: 26 additions & 23 deletions gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
# ==============================================================================

# from __future__ import annotations
from abc import abstractmethod
from dataclasses import (
dataclass,
Expand All @@ -25,8 +24,10 @@
Any,
Callable,
Optional,
Union,
)
import cola
from cola.linalg.decompositions.decompositions import Cholesky
from cola.ops import Dense
import jax.numpy as jnp
from jax.random import (
Expand Down Expand Up @@ -152,17 +153,17 @@ class Prior(AbstractPrior):
```
"""

@overload
def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
...
# @overload
# def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
# ...

@overload
def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
...
# @overload
# def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
# ...

@overload
def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
...
# @overload
# def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
# ...

def __mul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Expand Down Expand Up @@ -198,17 +199,17 @@ def __mul__(self, other):
"""
return construct_posterior(prior=self, likelihood=other)

@overload
def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
...
# @overload
# def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
# ...

@overload
def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
...
# @overload
# def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
# ...

@overload
def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
...
# @overload
# def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
# ...

def __rmul__(self, other):
r"""Combine the prior with a likelihood to form a posterior distribution.
Expand Down Expand Up @@ -540,7 +541,7 @@ def predict(
# Σ⁻¹ Kxt
if mask is not None:
Kxt = jnp.where(mask * jnp.ones((1, n_train), dtype=bool), 0.0, Kxt)
Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky())

# μt + Ktx (Kxx + Iσ²)⁻¹ (y - μx)
mean = mean_t.flatten() + Sigma_inv_Kxt.T @ (y - mx).flatten()
Expand Down Expand Up @@ -618,7 +619,9 @@ def sample_approx(
y = train_data.y - self.prior.mean_function(train_data.X) # account for mean
Phi = fourier_feature_fn(train_data.X)
canonical_weights = cola.solve(
Sigma, y + eps - jnp.inner(Phi, fourier_weights)
Sigma,
y + eps - jnp.inner(Phi, fourier_weights),
Cholesky(),
) # [N, B]

def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
Expand Down Expand Up @@ -656,7 +659,7 @@ class NonConjugatePosterior(AbstractPosterior):
from, or optimise an approximation to, the posterior distribution.
"""

latent: Float[Array, "N 1"] = param_field(None)
latent: Union[Float[Array, "N 1"], None] = param_field(None)
key: KeyArray = static_field(PRNGKey(42))

def __post_init__(self):
Expand Down Expand Up @@ -707,7 +710,7 @@ def predict(
mean_t = mean_function(t)

# Lx⁻¹ Kxt
Lx_inv_Kxt = cola.solve(Lx, Ktx.T)
Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())

# Whitened function values, wx, corresponding to the inputs, x
wx = self.latent
Expand Down
21 changes: 14 additions & 7 deletions gpjax/integrators.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import (
TypeVar,
Union,
)

from beartype.typing import Callable
import jax.numpy as jnp
from jaxtyping import Float
import numpy as np

import gpjax
from gpjax.typing import Array

if TYPE_CHECKING:
import gpjax.likelihoods
Likelihood = TypeVar(
"Likelihood",
bound=Union["gpjax.likelihoods.AbstractLikelihood", None], # noqa: F821
)
Gaussian = TypeVar("Gaussian", bound="gpjax.likelihoods.Gaussian") # noqa: F821


@dataclass
Expand All @@ -24,7 +31,7 @@ def integrate(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
likelihood: Likelihood,
) -> Float[Array, " N"]:
r"""Integrate a function with respect to a Gaussian distribution.
Expand All @@ -47,7 +54,7 @@ def __call__(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
likelihood: Likelihood,
) -> Float[Array, " N"]:
r"""Integrate a function with respect to a Gaussian distribution.
Expand Down Expand Up @@ -86,7 +93,7 @@ def integrate(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.AbstractLikelihood" = None,
likelihood: Likelihood,
) -> Float[Array, " N"]:
r"""Compute a quadrature integral.
Expand Down Expand Up @@ -127,7 +134,7 @@ def integrate(
y: Float[Array, "N D"],
mean: Float[Array, "N D"],
variance: Float[Array, "N D"],
likelihood: "gpjax.likelihoods.Gaussian" = None,
likelihood: Gaussian,
) -> Float[Array, " N"]:
r"""Compute a Gaussian integral.
Expand Down
Loading

0 comments on commit ae9cfa3

Please sign in to comment.