set all solve and logdeterminants to use Cholesky decomposition
daniel-dodd committed Oct 27, 2023
1 parent 999f08c commit 88a9667
Showing 4 changed files with 32 additions and 20 deletions.
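Every `cola.solve` and `cola.logdet` call on a positive semi-definite operator now passes an explicit `Cholesky()` algorithm instead of leaving the choice to CoLA's dispatch. A minimal sketch of the before/after pattern, assuming CoLA's dense operator and PSD annotation as used elsewhere in gpjax (the 2x2 matrix and names are illustrative, not from the diff):

    import cola
    from cola.linalg.decompositions.decompositions import Cholesky
    import jax.numpy as jnp

    # Illustrative PSD operator and right-hand side.
    A = cola.PSD(cola.ops.Dense(jnp.array([[2.0, 0.5], [0.5, 1.0]])))
    b = jnp.array([1.0, -1.0])

    x_before = cola.solve(A, b)                            # algorithm chosen by CoLA
    x_after = cola.solve(A, b, Cholesky())                 # forced dense Cholesky solve
    logdet_after = cola.logdet(A, Cholesky(), Cholesky())  # log|A| via the factor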
18 changes: 12 additions & 6 deletions gpjax/distributions.py
@@ -46,6 +46,8 @@
 
 tfd = tfp.distributions
 
+from cola.linalg.decompositions.decompositions import Cholesky
+
 
 def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None:
     r"""Checks that the inputs are correct."""
@@ -156,7 +158,7 @@ def entropy(self) -> ScalarFloat:
         r"""Calculates the entropy of the distribution."""
         return 0.5 * (
             self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi))
-            + cola.logdet(self.scale)
+            + cola.logdet(self.scale, Cholesky(), Cholesky())
         )
 
     def log_prob(
@@ -190,8 +192,8 @@ def log_prob(
         # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ]
         return -0.5 * (
             n * jnp.log(2.0 * jnp.pi)
-            + cola.logdet(sigma)
-            + diff.T @ cola.solve(sigma, diff)
+            + cola.logdet(sigma, Cholesky(), Cholesky())
+            + diff.T @ cola.solve(sigma, diff, Cholesky())
         )
 
     def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]:
@@ -346,15 +348,19 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl
 
     # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])²
     trace = _frobenius_norm_squared(
-        cola.solve(sqrt_p, sqrt_q.to_dense())
+        cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky())
     )  # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator.
 
     # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])²
-    mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff)))
+    mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky())))
 
     # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2
     return (
-        mahalanobis - n_dim - cola.logdet(sigma_q) + cola.logdet(sigma_p) + trace
+        mahalanobis
+        - n_dim
+        - cola.logdet(sigma_q, Cholesky(), Cholesky())
+        + cola.logdet(sigma_p, Cholesky(), Cholesky())
+        + trace
     ) / 2.0
 
 
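The trace-term identity in the comment above, tr[Σp⁻¹ Σq] equals the squared Frobenius norm of Lp⁻¹Lq, can be checked numerically. A small sketch with illustrative Cholesky factors:

    import jax.numpy as jnp
    import jax.scipy as jsp

    # Illustrative lower Cholesky factors of sigma_q and sigma_p.
    Lq = jnp.array([[1.0, 0.0], [0.3, 0.8]])
    Lp = jnp.array([[1.2, 0.0], [0.1, 0.9]])
    sigma_q, sigma_p = Lq @ Lq.T, Lp @ Lp.T

    Lp_inv_Lq = jsp.linalg.solve_triangular(Lp, Lq, lower=True)
    trace_frobenius = jnp.sum(jnp.square(Lp_inv_Lq))              # (fr[Lp⁻¹Lq])²
    trace_direct = jnp.trace(jnp.linalg.solve(sigma_p, sigma_q))  # tr[Σp⁻¹ Σq]

    assert jnp.allclose(trace_frobenius, trace_direct, atol=1e-6)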
9 changes: 6 additions & 3 deletions gpjax/gps.py
@@ -27,6 +27,7 @@
     Union,
 )
 import cola
+from cola.linalg.decompositions.decompositions import Cholesky
 from cola.ops import Dense
 import jax.numpy as jnp
 from jax.random import (
@@ -540,7 +541,7 @@ def predict(
         # Σ⁻¹ Kxt
         if mask is not None:
             Kxt = jnp.where(mask * jnp.ones((1, n_train), dtype=bool), 0.0, Kxt)
-        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
+        Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky())
 
         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t.flatten() + Sigma_inv_Kxt.T @ (y - mx).flatten()
@@ -618,7 +619,9 @@ def sample_approx(
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
         canonical_weights = cola.solve(
-            Sigma, y + eps - jnp.inner(Phi, fourier_weights)
+            Sigma,
+            y + eps - jnp.inner(Phi, fourier_weights),
+            Cholesky(),
         )  # [N, B]
 
         def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]:
@@ -707,7 +710,7 @@ def predict(
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = cola.solve(Lx, Ktx.T)
+        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent
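These hunks route the posterior solves, against Σ = Kxx + Io² in the conjugate case and against the factor Lx in the non-conjugate case, through the same Cholesky algorithm. A plain-JAX sketch of the predictive-mean step, with toy matrices standing in for the kernel quantities (names illustrative, zero prior mean assumed for brevity):

    import jax.numpy as jnp
    import jax.scipy as jsp

    Kxx = jnp.array([[1.0, 0.4], [0.4, 1.0]])  # illustrative train-train Gram matrix
    Kxt = jnp.array([[0.8], [0.3]])            # train-test cross-covariance
    y_minus_mx = jnp.array([0.5, -0.2])        # centred targets, y - mx
    noise = 0.1

    Sigma = Kxx + noise * jnp.eye(2)                      # Kxx + Io²
    L = jnp.linalg.cholesky(Sigma)
    Sigma_inv_Kxt = jsp.linalg.cho_solve((L, True), Kxt)  # role of cola.solve(Sigma, Kxt, Cholesky())

    mean = Sigma_inv_Kxt.T @ y_minus_mx  # Ktx (Kxx + Io²)⁻¹ (y - μx)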
4 changes: 3 additions & 1 deletion gpjax/objectives.py
@@ -37,6 +37,8 @@
     bound="gpjax.variational_families.AbstractVariationalFamily",  # noqa: F821
 )
 
+from cola.linalg.decompositions.decompositions import Cholesky
+
 
 @dataclass
 class AbstractObjective(Module):
@@ -384,7 +386,7 @@ def step(
         #
         # with A and B defined as above.
 
-        A = cola.solve(Lz, Kzx) / jnp.sqrt(noise)
+        A = cola.solve(Lz, Kzx, Cholesky()) / jnp.sqrt(noise)
 
         # AAᵀ
         AAT = jnp.matmul(A, A.T)
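Lz here is the lower Cholesky factor of the inducing-point Gram matrix Kzz, so the solve it feeds is triangular. A sketch of what A amounts to in plain JAX, with illustrative shapes and values:

    import jax.numpy as jnp
    import jax.scipy as jsp

    Kzz = jnp.array([[1.0, 0.3], [0.3, 1.0]])            # illustrative m x m Gram matrix
    Kzx = jnp.array([[0.9, 0.2, 0.4], [0.1, 0.8, 0.5]])  # m x n cross-covariance
    noise = 0.25

    Lz = jnp.linalg.cholesky(Kzz)
    A = jsp.linalg.solve_triangular(Lz, Kzx, lower=True) / jnp.sqrt(noise)  # Lz⁻¹ Kzx / o
    AAT = A @ A.T  # AAᵀ, as in the hunk above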
21 changes: 11 additions & 10 deletions gpjax/variational_families.py
@@ -21,6 +21,7 @@
     Union,
 )
 import cola
+from cola.linalg.decompositions.decompositions import Cholesky
 import jax.numpy as jnp
 import jax.scipy as jsp
 from jaxtyping import Float
@@ -202,10 +203,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         mut = mean_function(t)
 
         # Lz⁻¹ Kzt
-        Lz_inv_Kzt = cola.solve(Lz, Kzt)
+        Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky())
 
         # Kzz⁻¹ Kzt
-        Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt)
+        Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt, Cholesky())
 
         # Ktz Kzz⁻¹ sqrt
         Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
@@ -305,7 +306,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         mut = mean_function(t)
 
         # Lz⁻¹ Kzt
-        Lz_inv_Kzt = cola.solve(Lz, Kzt)
+        Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky())
 
         # Ktz Lz⁻ᵀ sqrt
         Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt)
@@ -459,10 +460,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         mut = mean_function(test_inputs)
 
         # Lz⁻¹ Kzt
-        Lz_inv_Kzt = cola.solve(Lz, Kzt)
+        Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky())
 
         # Kzz⁻¹ Kzt
-        Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt)
+        Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt, Cholesky())
 
         # Ktz Kzz⁻¹ L
         Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt)
@@ -608,10 +609,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution:
         mut = mean_function(t)
 
         # Lz⁻¹ Kzt
-        Lz_inv_Kzt = cola.solve(Lz, Kzt)
+        Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky())
 
         # Kzz⁻¹ Kzt
-        Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt)
+        Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt, Cholesky())
 
         # Ktz Kzz⁻¹ sqrt
         Ktz_Kzz_inv_sqrt = Kzz_inv_Kzt.T @ sqrt
@@ -683,7 +684,7 @@ def predict(
         Lz = lower_cholesky(Kzz)
 
         # Lz⁻¹ Kzx
-        Lz_inv_Kzx = cola.solve(Lz, Kzx)
+        Lz_inv_Kzx = cola.solve(Lz, Kzx, Cholesky())
 
         # A = Lz⁻¹ Kzt / o
         A = Lz_inv_Kzx / self.posterior.likelihood.obs_stddev
@@ -701,14 +702,14 @@
         Lz_inv_Kzx_diff = jsp.linalg.cho_solve((L, True), jnp.matmul(Lz_inv_Kzx, diff))
 
         # Kzz⁻¹ Kzx (y - μx)
-        Kzz_inv_Kzx_diff = cola.solve(Lz.T, Lz_inv_Kzx_diff)
+        Kzz_inv_Kzx_diff = cola.solve(Lz.T, Lz_inv_Kzx_diff, Cholesky())
 
         Ktt = kernel.gram(t)
         Kzt = kernel.cross_covariance(z, t)
         mut = mean_function(t)
 
         # Lz⁻¹ Kzt
-        Lz_inv_Kzt = cola.solve(Lz, Kzt)
+        Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky())
 
         # L⁻¹ Lz⁻¹ Kzt
         L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True)
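The pattern repeated across these hunks computes Kzz⁻¹ Kzt as Lz⁻ᵀ(Lz⁻¹ Kzt): two successive solves against the Cholesky factor rather than one solve against Kzz. A plain-JAX sketch with illustrative matrices:

    import jax.numpy as jnp
    import jax.scipy as jsp

    Kzz = jnp.array([[1.0, 0.2], [0.2, 1.0]])  # illustrative Gram matrix over inducing inputs z
    Kzt = jnp.array([[0.7], [0.4]])            # cross-covariance to a test input t

    Lz = jnp.linalg.cholesky(Kzz)
    Lz_inv_Kzt = jsp.linalg.solve_triangular(Lz, Kzt, lower=True)             # Lz⁻¹ Kzt
    Kzz_inv_Kzt = jsp.linalg.solve_triangular(Lz.T, Lz_inv_Kzt, lower=False)  # Lz⁻ᵀ Lz⁻¹ Kzt

    assert jnp.allclose(Kzz_inv_Kzt, jnp.linalg.solve(Kzz, Kzt), atol=1e-6)  # matches direct solve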
