diff --git a/gpjax/distributions.py b/gpjax/distributions.py index 027ece78..f638cef0 100644 --- a/gpjax/distributions.py +++ b/gpjax/distributions.py @@ -46,6 +46,8 @@ tfd = tfp.distributions +from cola.linalg.decompositions.decompositions import Cholesky + def _check_loc_scale(loc: Optional[Any], scale: Optional[Any]) -> None: r"""Checks that the inputs are correct.""" @@ -156,7 +158,7 @@ def entropy(self) -> ScalarFloat: r"""Calculates the entropy of the distribution.""" return 0.5 * ( self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) - + cola.logdet(self.scale) + + cola.logdet(self.scale, Cholesky(), Cholesky()) ) def log_prob( @@ -190,8 +192,8 @@ def log_prob( # compute the pdf, -1/2[ n log(2π) + log|Σ| + (y - µ)ᵀΣ⁻¹(y - µ) ] return -0.5 * ( n * jnp.log(2.0 * jnp.pi) - + cola.logdet(sigma) - + diff.T @ cola.solve(sigma, diff) + + cola.logdet(sigma, Cholesky(), Cholesky()) + + diff.T @ cola.solve(sigma, diff, Cholesky()) ) def _sample_n(self, key: KeyArray, n: int) -> Float[Array, "n N"]: @@ -346,15 +348,19 @@ def _kl_divergence(q: GaussianDistribution, p: GaussianDistribution) -> ScalarFl # trace term, tr[Σp⁻¹ Σq] = tr[(LpLpᵀ)⁻¹(LqLqᵀ)] = tr[(Lp⁻¹Lq)(Lp⁻¹Lq)ᵀ] = (fr[LqLp⁻¹])² trace = _frobenius_norm_squared( - cola.solve(sqrt_p, sqrt_q.to_dense()) + cola.solve(sqrt_p, sqrt_q.to_dense(), Cholesky()) ) # TODO: Not most efficient, given the `to_dense()` call (e.g., consider diagonal p and q). Need to abstract solving linear operator against another linear operator. # Mahalanobis term, (μp - μq)ᵀ Σp⁻¹ (μp - μq) = tr [(μp - μq)ᵀ [LpLpᵀ]⁻¹ (μp - μq)] = (fr[Lp⁻¹(μp - μq)])² - mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff))) + mahalanobis = jnp.sum(jnp.square(cola.solve(sqrt_p, diff, Cholesky()))) # KL[q(x)||p(x)] = [ [(μp - μq)ᵀ Σp⁻¹ (μp - μq)] - n - log|Σq| + log|Σp| + tr[Σp⁻¹ Σq] ] / 2 return ( - mahalanobis - n_dim - cola.logdet(sigma_q) + cola.logdet(sigma_p) + trace + mahalanobis + - n_dim + - cola.logdet(sigma_q, Cholesky(), Cholesky()) + + cola.logdet(sigma_p, Cholesky(), Cholesky()) + + trace ) / 2.0 diff --git a/gpjax/gps.py b/gpjax/gps.py index b37aa449..638f408c 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -27,6 +27,7 @@ Union, ) import cola +from cola.linalg.decompositions.decompositions import Cholesky from cola.ops import Dense import jax.numpy as jnp from jax.random import ( @@ -540,7 +541,7 @@ def predict( # Σ⁻¹ Kxt if mask is not None: Kxt = jnp.where(mask * jnp.ones((1, n_train), dtype=bool), 0.0, Kxt) - Sigma_inv_Kxt = cola.solve(Sigma, Kxt) + Sigma_inv_Kxt = cola.solve(Sigma, Kxt, Cholesky()) # μt + Ktx (Kxx + Io²)⁻¹ (y - μx) mean = mean_t.flatten() + Sigma_inv_Kxt.T @ (y - mx).flatten() @@ -618,7 +619,9 @@ def sample_approx( y = train_data.y - self.prior.mean_function(train_data.X) # account for mean Phi = fourier_feature_fn(train_data.X) canonical_weights = cola.solve( - Sigma, y + eps - jnp.inner(Phi, fourier_weights) + Sigma, + y + eps - jnp.inner(Phi, fourier_weights), + Cholesky(), ) # [N, B] def sample_fn(test_inputs: Float[Array, "n D"]) -> Float[Array, "n B"]: @@ -707,7 +710,7 @@ def predict( mean_t = mean_function(t) # Lx⁻¹ Kxt - Lx_inv_Kxt = cola.solve(Lx, Ktx.T) + Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky()) # Whitened function values, wx, corresponding to the inputs, x wx = self.latent diff --git a/gpjax/objectives.py b/gpjax/objectives.py index b050b09a..88777214 100644 --- a/gpjax/objectives.py +++ b/gpjax/objectives.py @@ -37,6 +37,8 @@ bound="gpjax.variational_families.AbstractVariationalFamily", # noqa: F821 ) +from cola.linalg.decompositions.decompositions import Cholesky + @dataclass class AbstractObjective(Module): @@ -384,7 +386,7 @@ def step( # # with A and B defined as above. - A = cola.solve(Lz, Kzx) / jnp.sqrt(noise) + A = cola.solve(Lz, Kzx, Cholesky()) / jnp.sqrt(noise) # AAᵀ AAT = jnp.matmul(A, A.T) diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index cd7f0c79..d66fbf19 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -21,6 +21,7 @@ Union, ) import cola +from cola.linalg.decompositions.decompositions import Cholesky import jax.numpy as jnp import jax.scipy as jsp from jaxtyping import Float @@ -202,10 +203,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: mut = mean_function(t) # Lz⁻¹ Kzt - Lz_inv_Kzt = cola.solve(Lz, Kzt) + Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky()) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt) + Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt, Cholesky()) # Ktz Kzz⁻¹ sqrt Ktz_Kzz_inv_sqrt = jnp.matmul(Kzz_inv_Kzt.T, sqrt) @@ -305,7 +306,7 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: mut = mean_function(t) # Lz⁻¹ Kzt - Lz_inv_Kzt = cola.solve(Lz, Kzt) + Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky()) # Ktz Lz⁻ᵀ sqrt Ktz_Lz_invT_sqrt = jnp.matmul(Lz_inv_Kzt.T, sqrt) @@ -459,10 +460,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: mut = mean_function(test_inputs) # Lz⁻¹ Kzt - Lz_inv_Kzt = cola.solve(Lz, Kzt) + Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky()) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt) + Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt, Cholesky()) # Ktz Kzz⁻¹ L Ktz_Kzz_inv_L = jnp.matmul(Kzz_inv_Kzt.T, sqrt) @@ -608,10 +609,10 @@ def predict(self, test_inputs: Float[Array, "N D"]) -> GaussianDistribution: mut = mean_function(t) # Lz⁻¹ Kzt - Lz_inv_Kzt = cola.solve(Lz, Kzt) + Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky()) # Kzz⁻¹ Kzt - Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt) + Kzz_inv_Kzt = cola.solve(Lz.T, Lz_inv_Kzt, Cholesky()) # Ktz Kzz⁻¹ sqrt Ktz_Kzz_inv_sqrt = Kzz_inv_Kzt.T @ sqrt @@ -683,7 +684,7 @@ def predict( Lz = lower_cholesky(Kzz) # Lz⁻¹ Kzx - Lz_inv_Kzx = cola.solve(Lz, Kzx) + Lz_inv_Kzx = cola.solve(Lz, Kzx, Cholesky()) # A = Lz⁻¹ Kzt / o A = Lz_inv_Kzx / self.posterior.likelihood.obs_stddev @@ -701,14 +702,14 @@ def predict( Lz_inv_Kzx_diff = jsp.linalg.cho_solve((L, True), jnp.matmul(Lz_inv_Kzx, diff)) # Kzz⁻¹ Kzx (y - μx) - Kzz_inv_Kzx_diff = cola.solve(Lz.T, Lz_inv_Kzx_diff) + Kzz_inv_Kzx_diff = cola.solve(Lz.T, Lz_inv_Kzx_diff, Cholesky()) Ktt = kernel.gram(t) Kzt = kernel.cross_covariance(z, t) mut = mean_function(t) # Lz⁻¹ Kzt - Lz_inv_Kzt = cola.solve(Lz, Kzt) + Lz_inv_Kzt = cola.solve(Lz, Kzt, Cholesky()) # L⁻¹ Lz⁻¹ Kzt L_inv_Lz_inv_Kzt = jsp.linalg.solve_triangular(L, Lz_inv_Kzt, lower=True)