undo dumb move
frazane committed Nov 30, 2023
1 parent 0a8ffac commit 2ef250b
Showing 3 changed files with 19 additions and 22 deletions.
gpjax/gps.py: 34 changes (16 additions & 18 deletions)
@@ -28,9 +28,7 @@
     Optional,
     Union,
 )
-from cola.linalg.inverse.inv import solve
-from cola.annotations import PSD
-from cola.ops.operators import I_like
+import cola
 from cola.linalg.decompositions.decompositions import Cholesky
 import jax.numpy as jnp
 from jax.random import (
@@ -277,8 +275,8 @@ def predict(self, test_inputs: Num[Array, "N D"]) -> GaussianDistribution:
         x = test_inputs
         mx = self.mean_function(x)
         Kxx = self.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
-        Kxx = PSD(Kxx)
+        Kxx += cola.ops.I_like(Kxx) * self.jitter
+        Kxx = cola.PSD(Kxx)
 
         return GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Kxx)
 
@@ -524,24 +522,24 @@ def predict(
 
         # Precompute Gram matrix, Kxx, at training inputs, x
         Kxx = self.prior.kernel.gram(x)
-        Kxx += I_like(Kxx) * self.jitter
+        Kxx += cola.ops.I_like(Kxx) * self.jitter
 
         # Σ = Kxx + Io²
-        Sigma = Kxx + I_like(Kxx) * obs_noise
-        Sigma = PSD(Sigma)
+        Sigma = Kxx + cola.ops.I_like(Kxx) * obs_noise
+        Sigma = cola.PSD(Sigma)
 
         mean_t = self.prior.mean_function(t)
         Ktt = self.prior.kernel.gram(t)
         Kxt = self.prior.kernel.cross_covariance(x, t)
-        Sigma_inv_Kxt = solve(Sigma, Kxt)
+        Sigma_inv_Kxt = cola.solve(Sigma, Kxt)
 
         # μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
         mean = mean_t + jnp.matmul(Sigma_inv_Kxt.T, y - mx)
 
         # Ktt - Ktx (Kxx + Io²)⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Kxt.T, Sigma_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance += cola.ops.I_like(covariance) * self.prior.jitter
+        covariance = cola.PSD(covariance)
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
@@ -603,11 +601,11 @@ def sample_approx(
         # v = Σ⁻¹ (y + ε - ɸ⍵) for Σ = Kxx + Io² and ε ∼ N(0, o²)
         obs_var = self.likelihood.obs_stddev**2
         Kxx = self.prior.kernel.gram(train_data.X)  # [N, N]
-        Sigma = Kxx + I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
+        Sigma = Kxx + cola.ops.I_like(Kxx) * (obs_var + self.jitter)  # [N, N]
         eps = jnp.sqrt(obs_var) * normal(key, [train_data.n, num_samples])  # [N, B]
         y = train_data.y - self.prior.mean_function(train_data.X)  # account for mean
         Phi = fourier_feature_fn(train_data.X)
-        canonical_weights = solve(
+        canonical_weights = cola.solve(
             Sigma,
             y + eps - jnp.inner(Phi, fourier_weights),
             Cholesky(),
@@ -686,8 +684,8 @@ def predict(
 
         # Precompute lower triangular of Gram matrix, Lx, at training inputs, x
         Kxx = kernel.gram(x)
-        Kxx += I_like(Kxx) * self.prior.jitter
-        Kxx = PSD(Kxx)
+        Kxx += cola.ops.I_like(Kxx) * self.prior.jitter
+        Kxx = cola.PSD(Kxx)
         Lx = lower_cholesky(Kxx)
 
         # Unpack test inputs
@@ -699,7 +697,7 @@
         mean_t = mean_function(t)
 
         # Lx⁻¹ Kxt
-        Lx_inv_Kxt = solve(Lx, Ktx.T, Cholesky())
+        Lx_inv_Kxt = cola.solve(Lx, Ktx.T, Cholesky())
 
         # Whitened function values, wx, corresponding to the inputs, x
         wx = self.latent
@@ -709,8 +707,8 @@
 
         # Ktt - Ktx Kxx⁻¹ Kxt, TODO: Take advantage of covariance structure to compute Schur complement more efficiently.
         covariance = Ktt - jnp.matmul(Lx_inv_Kxt.T, Lx_inv_Kxt)
-        covariance += I_like(covariance) * self.prior.jitter
-        covariance = PSD(covariance)
+        covariance += cola.ops.I_like(covariance) * self.prior.jitter
+        covariance = cola.PSD(covariance)
 
         return GaussianDistribution(jnp.atleast_1d(mean.squeeze()), covariance)
 
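The net effect of the gps.py changes is import style only: `solve`, `PSD`, and `I_like` are now reached through the top-level `cola` namespace instead of deep submodule paths. A minimal sketch of the resulting calling convention, using a hand-built 2x2 matrix as a stand-in for the Gram matrix gpjax obtains from `kernel.gram(x)` (the jitter value and right-hand side here are illustrative, not from the commit):

import cola
from cola.linalg.decompositions.decompositions import Cholesky
import jax.numpy as jnp

A = jnp.array([[2.0, 0.5], [0.5, 1.0]])  # stand-in for a kernel Gram matrix
Kxx = cola.ops.Dense(A)                  # wrap as a CoLA linear operator
Kxx += cola.ops.I_like(Kxx) * 1e-6       # add jitter, mirroring gps.py
Sigma = cola.PSD(Kxx)                    # annotate as positive semi-definite
b = jnp.array([[1.0], [0.0]])            # illustrative right-hand side
v = cola.solve(Sigma, b, Cholesky())     # Σ⁻¹ b via a Cholesky decomposition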
gpjax/kernels/base.py: 3 changes (1 addition & 2 deletions)
@@ -29,7 +29,6 @@
     Num,
 )
 import tensorflow_probability.substrates.jax.distributions as tfd
-from cola.ops.operators import LinearOperator
 
 from gpjax.base import (
     Module,
@@ -61,7 +60,7 @@ def ndims(self):
     def cross_covariance(self, x: Num[Array, "N D"], y: Num[Array, "M D"]):
         return self.compute_engine.cross_covariance(self, x, y)
 
-    def gram(self, x: Num[Array, "N D"]) -> LinearOperator:
+    def gram(self, x: Num[Array, "N D"]):
         return self.compute_engine.gram(self, x)
 
     def slice_input(self, x: Float[Array, "... D"]) -> Float[Array, "... Q"]:
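Dropping the `-> LinearOperator` annotation on `Kernel.gram` only removes the need for the deep `cola.ops.operators` import in this module; the compute engines (see the next file) still build and return CoLA operators at runtime. A hedged sketch, assuming gpjax's `RBF` kernel with default parameters and CoLA's `to_dense()` method:

import jax.numpy as jnp
from gpjax.kernels import RBF

x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)  # [N, D] inputs
Kxx = RBF().gram(x)                           # still a CoLA linear operator
K_dense = Kxx.to_dense()                      # materialise as an [N, N] array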
gpjax/kernels/computations/base.py: 4 changes (2 additions & 2 deletions)
@@ -17,8 +17,8 @@
 from dataclasses import dataclass
 import typing as tp
 
-from cola.annotations import PSD
-from cola.ops.operators import (
+from cola import PSD
+from cola.ops import (
     Dense,
     Diagonal,
     LinearOperator,
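The shallower paths (`cola`, `cola.ops`) expose the same names as the deep-path imports they replace here. A small sketch of the imported objects in use, with illustrative values:

import jax.numpy as jnp
from cola import PSD
from cola.ops import Dense, Diagonal

D = Diagonal(jnp.array([1.0, 2.0, 3.0]))  # diagonal operator from its diagonal
A = PSD(Dense(jnp.eye(3)))                # dense operator annotated as PSD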
