diff --git a/gpjax/kernels/__init__.py b/gpjax/kernels/__init__.py
index 178aeba1..3e01404e 100644
--- a/gpjax/kernels/__init__.py
+++ b/gpjax/kernels/__init__.py
@@ -17,6 +17,7 @@
 from gpjax.kernels.approximations import RFF
 from gpjax.kernels.base import (
     AbstractKernel,
+    Constant,
     ProductKernel,
     SumKernel,
 )
@@ -27,7 +28,10 @@
     DiagonalKernelComputation,
     EigenKernelComputation,
 )
-from gpjax.kernels.non_euclidean import GraphKernel, CatKernel
+from gpjax.kernels.non_euclidean import (
+    CatKernel,
+    GraphKernel,
+)
 from gpjax.kernels.nonstationary import (
     ArcCosine,
     Linear,
@@ -47,6 +51,7 @@
 __all__ = [
     "AbstractKernel",
     "ArcCosine",
+    "Constant",
     "RBF",
     "GraphKernel",
     "CatKernel",
diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py
index 13e6fd7b..61d72519 100644
--- a/gpjax/mean_functions.py
+++ b/gpjax/mean_functions.py
@@ -150,6 +150,17 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
         return jnp.ones((x.shape[0], 1)) * self.constant
 
 
+@dataclasses.dataclass
+class Zero(Constant):
+    r"""Zero mean function.
+
+    Returns a zero scalar value for all inputs. Unlike the Constant mean
+    function, the constant is fixed at zero and cannot be treated as a model
+    hyperparameter to be learned during training.
+    """
+    constant: Float[Array, "1"] = static_field(jnp.array([0.0]), init=False)
+
+
 @dataclasses.dataclass
 class CombinationMeanFunction(AbstractMeanFunction):
     r"""A base class for products or sums of AbstractMeanFunctions."""
@@ -199,4 +210,3 @@ def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
 ProductMeanFunction = partial(
     CombinationMeanFunction, operator=partial(jnp.sum, axis=0)
 )
-Zero = partial(Constant, constant=jnp.array([0.0]))
diff --git a/tests/test_mean_functions.py b/tests/test_mean_functions.py
index b4740b15..d4a660c3 100644
--- a/tests/test_mean_functions.py
+++ b/tests/test_mean_functions.py
@@ -1,13 +1,24 @@
+# Enable Float64 for more stable matrix inversions.
+from jax import config
+
+config.update("jax_enable_x64", True)
+
+
+import jax
 import jax.numpy as jnp
+import jax.random as jr
 from jaxtyping import (
     Array,
     Float,
 )
+import optax as ox
 import pytest
 
+import gpjax as gpx
 from gpjax.mean_functions import (
     AbstractMeanFunction,
     Constant,
+    Zero,
 )
 
 
@@ -40,3 +51,45 @@ def test_constant(constant: Float[Array, " Q"]) -> None:
     assert (
         mf(jnp.array([[1.0, 2.0], [3.0, 4.0]])) == jnp.array([constant, constant])
     ).all()
+
+
+def test_zero_mean_remains_zero() -> None:
+    key = jr.PRNGKey(123)
+
+    x = jr.uniform(key=key, minval=0, maxval=1, shape=(20, 1))
+    y = jnp.full((20, 1), 50, dtype=jnp.float64)  # Dataset with non-zero mean
+    D = gpx.Dataset(X=x, y=y)
+
+    kernel = gpx.kernels.Constant(constant=jnp.array(0.0))
+    kernel = kernel.replace_trainable(
+        constant=False
+    )  # Prevent kernel from modelling non-zero mean
+    meanf = Zero()
+    prior = gpx.Prior(mean_function=meanf, kernel=kernel)
+    likelihood = gpx.Gaussian(num_datapoints=D.n, obs_noise=jnp.array(1e-6))
+    likelihood = likelihood.replace_trainable(obs_noise=False)
+    posterior = prior * likelihood
+
+    negative_mll = gpx.objectives.ConjugateMLL(negative=True)
+    opt_posterior, _ = gpx.fit(
+        model=posterior,
+        objective=negative_mll,
+        train_data=D,
+        optim=ox.adam(learning_rate=0.5),
+        num_iters=1000,
+        safe=True,
+        key=key,
+    )
+
+    assert opt_posterior.prior.mean_function.constant == 0.0
+
+
+def test_zero_mean_pytree_no_leaves():
+    zero_mean = Zero()
+    leaves = jax.tree_util.tree_leaves(zero_mean)
+    assert len(leaves) == 0
+
+
+def test_initialising_zero_mean_with_constant_raises_error():
+    with pytest.raises(TypeError):
+        Zero(constant=jnp.array([1.0]))
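
For reference, a minimal usage sketch (not part of the patch) of how the new Zero mean function slots into a model. It reuses only API calls that already appear in the tests above, plus the RBF kernel re-exported from gpjax.kernels; treat it as an illustration under those assumptions, not canonical usage.

    import jax
    import jax.numpy as jnp
    import jax.random as jr

    import gpjax as gpx
    from gpjax.mean_functions import Zero

    key = jr.PRNGKey(42)
    x = jr.uniform(key=key, minval=0, maxval=1, shape=(10, 1))
    y = jnp.sin(x)
    D = gpx.Dataset(X=x, y=y)

    # Because Zero's constant is declared with static_field(..., init=False),
    # the mean function contributes no leaves to the model pytree, so the
    # prior mean stays pinned at zero throughout optimisation.
    assert len(jax.tree_util.tree_leaves(Zero())) == 0

    prior = gpx.Prior(mean_function=Zero(), kernel=gpx.kernels.RBF())
    posterior = prior * gpx.Gaussian(num_datapoints=D.n)

This is also why replacing the old partial-based alias with a proper subclass matters: the partial still exposed constant as a trainable leaf, whereas the new class removes it from the pytree entirely and rejects constant as a constructor argument.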