Skip to content

Commit

Permalink
masked log prob and posterior prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
frazane committed Aug 2, 2023
1 parent a569625 commit 8f01ded
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 4 deletions.
17 changes: 15 additions & 2 deletions gpjax/gaussian_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from jax import vmap
import jax.numpy as jnp
import jax.random as jr
from jaxtyping import Float
from jaxtyping import (
Bool,
Float,
)
import tensorflow_probability.substrates.jax as tfp

from gpjax.linops import (
Expand Down Expand Up @@ -149,7 +152,9 @@ def entropy(self) -> ScalarFloat:
self.event_shape[0] * (1.0 + jnp.log(2.0 * jnp.pi)) + self.scale.log_det()
)

def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat:
def log_prob(
self, y: Float[Array, " N"], mask: Optional[Bool[Array, " N"]] = None
) -> ScalarFloat:
r"""Calculates the log pdf of the multivariate Gaussian.
Args:
Expand All @@ -162,6 +167,14 @@ def log_prob(self, y: Float[Array, " N"]) -> ScalarFloat:
mu = self.loc
sigma = self.scale
n = mu.shape[-1]
if mask is not None:
mask_ = jnp.squeeze(mask)
y = jnp.where(mask_, 0.0, y)
mu = jnp.where(mask_, 0.0, mu)
sigma_masked = jnp.where(mask + mask.T, 0.0, sigma.matrix)
sigma = sigma.replace(
matrix=jnp.where(jnp.diag(mask_), 1 / (2 * jnp.pi), sigma_masked)
)

# diff, y - µ
diff = y - mu
Expand Down
14 changes: 13 additions & 1 deletion gpjax/gps.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ def predict(
distribution as a `GaussianDistribution`.
"""
# Unpack training data
x, y, n = train_data.X, train_data.y, train_data.n
x, y, n, mask = train_data.X, train_data.y, train_data.n, train_data.mask

# Unpack test inputs
t, n_test = test_inputs, test_inputs.shape[0]
Expand All @@ -491,11 +491,23 @@ def predict(
# Σ = Kxx + Io²
Sigma = Kxx + identity(n) * obs_noise

if mask is not None:
y = jnp.where(mask, 0.0, y)
mx = jnp.where(mask, 0.0, mx)
Sigma_masked = jnp.where(mask + mask.T, 0.0, Sigma.matrix)
Sigma = Sigma.replace(
matrix=jnp.where(
jnp.diag(jnp.squeeze(mask)), 1 / (2 * jnp.pi), Sigma_masked
)
)

mean_t = self.prior.mean_function(t)
Ktt = self.prior.kernel.gram(t)
Kxt = self.prior.kernel.cross_covariance(x, t)

# Σ⁻¹ Kxt
if mask is not None:
Kxt = jnp.where(mask * jnp.ones((1, n_test), dtype=bool), 0.0, Kxt)
Sigma_inv_Kxt = Sigma.solve(Kxt)

# μt + Ktx (Kxx + Io²)⁻¹ (y - μx)
Expand Down
4 changes: 3 additions & 1 deletion gpjax/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ def step(
# p(y | x, θ), where θ are the model hyperparameters:
mll = GaussianDistribution(jnp.atleast_1d(mx.squeeze()), Sigma)

return self.constant * (mll.log_prob(jnp.atleast_1d(y.squeeze())).squeeze())
return self.constant * (
mll.log_prob(jnp.atleast_1d(y.squeeze()), mask=train_data.mask).squeeze()
)


class LogPosteriorDensity(AbstractObjective):
Expand Down

0 comments on commit 8f01ded

Please sign in to comment.