Merge pull request #358 from Thomas-Christie/zero-mean-fix
Fix bug in zero mean function and add test
Thomas-Christie authored Aug 25, 2023
2 parents c5eb47b + fe8bde4 commit 48706eb
Showing 3 changed files with 70 additions and 2 deletions.
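
In brief: before this change, Zero was merely an alias, Zero = partial(Constant, constant=jnp.array([0.0])), so the supposedly fixed zero constant remained an ordinary trainable parameter that gradient-based fitting could move away from zero. As the diffs below show, the commit replaces the alias with a dedicated Zero dataclass whose constant is a static field excluded from __init__, adds regression tests for the new behaviour, and also exports the Constant kernel from gpjax.kernels.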
7 changes: 6 additions & 1 deletion gpjax/kernels/__init__.py
@@ -17,6 +17,7 @@
 from gpjax.kernels.approximations import RFF
 from gpjax.kernels.base import (
     AbstractKernel,
+    Constant,
     ProductKernel,
     SumKernel,
 )
@@ -27,7 +28,10 @@
     DiagonalKernelComputation,
     EigenKernelComputation,
 )
-from gpjax.kernels.non_euclidean import GraphKernel, CatKernel
+from gpjax.kernels.non_euclidean import (
+    CatKernel,
+    GraphKernel,
+)
 from gpjax.kernels.nonstationary import (
     ArcCosine,
     Linear,
@@ -47,6 +51,7 @@
 __all__ = [
     "AbstractKernel",
     "ArcCosine",
+    "Constant",
     "RBF",
     "GraphKernel",
     "CatKernel",
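
The diff above also makes Constant importable directly from gpjax.kernels, which the new test below relies on (gpx.kernels.Constant). A minimal sketch of the newly exported name in use, with the constructor argument taken from the test in this commit:

import jax.numpy as jnp

from gpjax.kernels import Constant

# A constant kernel with value zero contributes no covariance structure; the
# new test uses this to stop the kernel from absorbing the data's non-zero mean.
kernel = Constant(constant=jnp.array(0.0))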
12 changes: 11 additions & 1 deletion gpjax/mean_functions.py
@@ -150,6 +150,17 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
         return jnp.ones((x.shape[0], 1)) * self.constant
 
 
+@dataclasses.dataclass
+class Zero(Constant):
+    r"""Zero mean function.
+
+    The zero mean function. This function returns a zero scalar value for all
+    inputs. Unlike the Constant mean function, the constant scalar zero is fixed, and
+    cannot be treated as a model hyperparameter and learned during training.
+    """
+    constant: Float[Array, "1"] = static_field(jnp.array([0.0]), init=False)
+
+
 @dataclasses.dataclass
 class CombinationMeanFunction(AbstractMeanFunction):
     r"""A base class for products or sums of AbstractMeanFunctions."""
@@ -199,4 +210,3 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
 ProductMeanFunction = partial(
     CombinationMeanFunction, operator=partial(jnp.sum, axis=0)
 )
-Zero = partial(Constant, constant=jnp.array([0.0]))
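
This is the substance of the fix. The deleted alias at the bottom of the file, Zero = partial(Constant, constant=jnp.array([0.0])), produced Constant instances whose constant was an ordinary parameter leaf, so optimisers could train the "zero" mean away from zero. The replacement dataclass pins the constant with static_field(..., init=False), keeping it out of the pytree leaves that gpx.fit updates. A minimal sketch of the resulting behaviour, assuming the GPJax APIs shown in this diff (in particular that static fields are not pytree leaves, which the new tests check):

import jax
import jax.numpy as jnp

from gpjax.mean_functions import Constant, Zero

x = jnp.ones((3, 2))

# Zero inherits Constant.__call__, so it evaluates to a column of zeros.
assert (Zero()(x) == jnp.zeros((3, 1))).all()

# The pinned constant is static, so Zero exposes no trainable leaves, whereas
# a Constant mean function exposes its constant as a single trainable leaf.
assert len(jax.tree_util.tree_leaves(Zero())) == 0
assert len(jax.tree_util.tree_leaves(Constant(constant=jnp.array([0.0])))) == 1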
53 changes: 53 additions & 0 deletions tests/test_mean_functions.py
@@ -1,13 +1,24 @@
 # Enable Float64 for more stable matrix inversions.
 from jax import config
 
 config.update("jax_enable_x64", True)
 
 
+import jax
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import (
     Array,
     Float,
 )
+import optax as ox
 import pytest
 
+import gpjax as gpx
 from gpjax.mean_functions import (
     AbstractMeanFunction,
     Constant,
+    Zero,
 )
@@ -40,3 +51,45 @@ def test_constant(constant: Float[Array, " Q"]) -> None:
     assert (
         mf(jnp.array([[1.0, 2.0], [3.0, 4.0]])) == jnp.array([constant, constant])
     ).all()
+
+
+def test_zero_mean_remains_zero() -> None:
+    key = jr.PRNGKey(123)
+
+    x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1))
+    y = jnp.full((20, 1), 50, dtype=jnp.float64)  # Dataset with non-zero mean
+    D = gpx.Dataset(X=x, y=y)
+
+    kernel = gpx.kernels.Constant(constant=jnp.array(0.0))
+    kernel = kernel.replace_trainable(
+        constant=False
+    )  # Prevent kernel from modelling non-zero mean
+    meanf = Zero()
+    prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+    likelihood = gpx.Gaussian(num_datapoints=D.n, obs_noise=jnp.array(1e-6))
+    likelihood = likelihood.replace_trainable(obs_noise=False)
+    posterior = prior * likelihood
+
+    negative_mll = gpx.objectives.ConjugateMLL(negative=True)
+    opt_posterior, _ = gpx.fit(
+        model=posterior,
+        objective=negative_mll,
+        train_data=D,
+        optim=ox.adam(learning_rate=0.5),
+        num_iters=1000,
+        safe=True,
+        key=key,
+    )
+
+    assert opt_posterior.prior.mean_function.constant == 0.0
+
+
+def test_zero_mean_pytree_no_leaves():
+    zero_mean = Zero()
+    leaves = jax.tree_util.tree_leaves(zero_mean)
+    assert len(leaves) == 0
+
+
+def test_initialising_zero_mean_with_constant_raises_error():
+    with pytest.raises(TypeError):
+        Zero(constant=jnp.array([1.0]))
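
The last test leans on standard dataclass semantics: a field declared with init=False is excluded from the generated __init__, so passing constant= to Zero raises a TypeError. A self-contained sketch of that mechanism using plain dataclasses (illustrative only; not the GPJax classes):

import dataclasses


@dataclasses.dataclass
class Pinned:
    # init=False excludes the field from the generated __init__, mirroring
    # how Zero pins its constant at zero.
    value: float = dataclasses.field(default=0.0, init=False)


p = Pinned()  # fine; p.value == 0.0
try:
    Pinned(value=1.0)
except TypeError as err:
    print(err)  # ... got an unexpected keyword argument 'value'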
