diff --git a/tests/test_kernels/test_nonstationary.py b/tests/test_kernels/test_nonstationary.py
index 3606f117..fc4de81f 100644
--- a/tests/test_kernels/test_nonstationary.py
+++ b/tests/test_kernels/test_nonstationary.py
@@ -13,15 +13,20 @@
 # limitations under the License.
 # ==============================================================================
 
-from itertools import permutations
+from itertools import permutations, product
+from dataclasses import is_dataclass
 
+import jax
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
 import pytest
+import tensorflow_probability.substrates.jax.bijectors as tfb
 from jax.config import config
+from typing import List
 
 from gpjax.kernels.base import AbstractKernel
+from gpjax.kernels.computations import DenseKernelComputation
 from gpjax.kernels.nonstationary import Linear, Polynomial
 from gpjax.linops import LinearOperator, identity
 
@@ -31,135 +36,132 @@
 _jitter = 1e-6
 
 
-@pytest.mark.parametrize(
-    "kernel",
-    [
-        Linear(),
-        Polynomial(),
-    ],
-)
-@pytest.mark.parametrize("dim", [1, 2, 5])
-@pytest.mark.parametrize("n", [1, 2, 10])
-def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None:
-    # Gram constructor static method:
-    kernel.gram
-
-    # Inputs x:
-    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
-
-    # Test gram matrix:
-    Kxx = kernel.gram(x)
-    assert isinstance(Kxx, LinearOperator)
-    assert Kxx.shape == (n, n)
-
-
-@pytest.mark.parametrize(
-    "kernel",
-    [
-        Linear(),
-        Polynomial(),
-    ],
-)
-@pytest.mark.parametrize("num_a", [1, 2, 5])
-@pytest.mark.parametrize("num_b", [1, 2, 5])
-@pytest.mark.parametrize("dim", [1, 2, 5])
-def test_cross_covariance(
-    kernel: AbstractKernel, num_a: int, num_b: int, dim: int
-) -> None:
-    # Inputs a, b:
-    a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim)
-    b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim)
-
-    # Test cross covariance, Kab:
-    Kab = kernel.cross_covariance(a, b)
-    assert isinstance(Kab, jnp.ndarray)
-    assert Kab.shape == (num_a, num_b)
-
-
-@pytest.mark.parametrize("kern", [Linear, Polynomial])
-@pytest.mark.parametrize("dim", [1, 2, 5])
-@pytest.mark.parametrize("shift", [0.0, 0.5, 2.0])
-@pytest.mark.parametrize("sigma", [0.1, 0.2, 0.5])
-@pytest.mark.parametrize("n", [1, 2, 5])
-def test_pos_def(
-    kern: AbstractKernel, dim: int, shift: float, sigma: float, n: int
-) -> None:
-    kern = kern(active_dims=list(range(dim)))
-    # Gram constructor static method:
-    kern.gram
-
-    # Create inputs x:
-    x = jr.uniform(_initialise_key, (n, dim))
-
-    if isinstance(kern, Polynomial):
-        kern = kern.replace(shift=shift, variance=sigma)
-    else:
-        kern = kern.replace(variance=sigma)
-
-    # Test gram matrix eigenvalues are positive:
-    Kxx = kern.gram(x)
-    Kxx += identity(n) * _jitter
-    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
-    assert (eigen_values > 0.0).all()
-
-
-@pytest.mark.parametrize("degree", [1, 2, 3])
-@pytest.mark.parametrize("dim", [1, 2, 5])
-@pytest.mark.parametrize("variance", [0.1, 1.0, 2.0])
-@pytest.mark.parametrize("shift", [1e-6, 0.1, 1.0])
-@pytest.mark.parametrize("n", [1, 2, 5])
-def test_polynomial(
-    degree: int, dim: int, variance: float, shift: float, n: int
-) -> None:
-    # Define inputs
-    x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
-
-    # Define kernel
-    kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)])
-
-    # # Check name
-    # assert kern.name == f"Polynomial Degree: {degree}"
-
-    # Initialise parameters
-    kern = kern.replace(shift=kern.shift * shift, variance=kern.variance * variance)
-
-    # Compute gram matrix
-    Kxx = kern.gram(x)
-
-    # Check shapes
-    assert Kxx.shape[0] == x.shape[0]
-    assert Kxx.shape[0] == Kxx.shape[1]
-
-    # Test positive definiteness
-    Kxx += identity(n) * _jitter
-    eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
-    assert (eigen_values > 0).all()
-
-
-@pytest.mark.parametrize(
-    "kernel",
-    [Linear, Polynomial],
-)
-def test_active_dim(kernel: AbstractKernel) -> None:
-    dim_list = [0, 1, 2, 3]
-    perm_length = 2
-    dim_pairs = list(permutations(dim_list, r=perm_length))
-    n_dims = len(dim_list)
-
-    # Generate random inputs
-    x = jr.normal(_initialise_key, shape=(20, n_dims))
-
-    for dp in dim_pairs:
-        # Take slice of x
-        slice = x[..., dp]
-
-        # Define kernels
-        ad_kern = kernel(active_dims=dp)
-        manual_kern = kernel(active_dims=[i for i in range(perm_length)])
-
-        # Compute gram matrices
-        ad_Kxx = ad_kern.gram(x)
-        manual_Kxx = manual_kern.gram(slice)
-
-        # Test gram matrices are equal
-        assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense())
+class BaseTestKernel:
+    """A base class that contains all tests applied on non-stationary kernels."""
+
+    kernel: AbstractKernel
+    default_compute_engine: type
+    static_fields: List[str]
+
+    def pytest_generate_tests(self, metafunc):
+        """This is called automatically by pytest"""
+
+        # function for pretty test name
+        id_func = lambda x: "-".join([f"{k}={v}" for k, v in x.items()])
+
+        # get arguments for the test function
+        funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None)
+        if funcarglist is None:
+            return
+        else:
+            # equivalent of pytest.mark.parametrize applied on the metafunction
+            metafunc.parametrize("fields", funcarglist, ids=id_func)
+
+    @pytest.mark.parametrize("dim", [None, 1, 3], ids=lambda x: f"dim={x}")
+    def test_initialization(self, fields: dict, dim: int) -> None:
+
+        # Check that kernel is a dataclass
+        assert is_dataclass(self.kernel)
+
+        # Input fields as JAX arrays
+        fields = {k: jnp.array([v]) for k, v in fields.items()}
+
+        # Test number of dimensions
+        if dim is None:
+            kernel: AbstractKernel = self.kernel(**fields)
+            assert kernel.ndims == 1
+        else:
+            kernel: AbstractKernel = self.kernel(
+                active_dims=[i for i in range(dim)], **fields
+            )
+            assert kernel.ndims == dim
+
+        # Check default compute engine
+        assert kernel.compute_engine == self.default_compute_engine
+
+        # Check properties
+        for field, value in fields.items():
+            assert getattr(kernel, field) == value
+
+        # Test that pytree returns param_field objects (and not static_field)
+        leaves = jtu.tree_leaves(kernel)
+        assert len(leaves) == len(set(fields) - set(self.static_fields))
+
+        # Test dtype of params
+        for v in leaves:
+            assert v.dtype == jnp.float64
+
+        # Check meta leaves
+        meta = kernel._pytree__meta
+        assert not any(f in meta.keys() for f in self.static_fields)
+        assert list(meta.keys()) == sorted(set(fields) - set(self.static_fields))
+
+        for field in meta:
+
+            # Bijectors
+            if field in ["variance", "shift"]:
+                assert isinstance(meta[field]["bijector"], tfb.Softplus)
+
+            # Trainability state
+            assert meta[field]["trainable"] == True
+
+        # Test kernel call
+        x = jnp.linspace(0.0, 1.0, 10 * kernel.ndims).reshape(10, kernel.ndims)
+        jax.vmap(kernel)(x, x)
+
+    @pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
+    @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
+    def test_gram(self, dim: int, n: int) -> None:
+
+        # Initialise kernel
+        kernel: AbstractKernel = self.kernel()
+
+        # Gram constructor static method
+        kernel.gram
+
+        # Inputs
+        x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
+
+        # Test gram matrix
+        Kxx = kernel.gram(x)
+        assert isinstance(Kxx, LinearOperator)
+        assert Kxx.shape == (n, n)
+        assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)
+
+    @pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
+    @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
+    @pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}")
+    def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
+
+        # Initialise kernel
+        kernel: AbstractKernel = self.kernel()
+
+        # Inputs
+        a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
+        b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
+
+        # Test cross-covariance
+        Kab = kernel.cross_covariance(a, b)
+        assert isinstance(Kab, jnp.ndarray)
+        assert Kab.shape == (n_a, n_b)
+
+
+prod = lambda inp: [dict(zip(inp.keys(), values)) for values in product(*inp.values())]
+
+
+class TestLinear(BaseTestKernel):
+    kernel = Linear
+    fields = prod({"variance": [0.1, 1.0, 2.0]})
+    params = {"test_initialization": fields}
+    static_fields = []
+    default_compute_engine = DenseKernelComputation
+
+
+class TestPolynomial(BaseTestKernel):
+    kernel = Polynomial
+    fields = prod(
+        {"variance": [0.1, 1.0, 2.0], "degree": [1, 2, 3], "shift": [1e-6, 0.1, 1.0]}
+    )
+    static_fields = ["degree"]
+    params = {"test_initialization": fields}
+    default_compute_engine = DenseKernelComputation
diff --git a/tests/test_kernels/test_stationary.py b/tests/test_kernels/test_stationary.py
index edcf7162..e69059b0 100644
--- a/tests/test_kernels/test_stationary.py
+++ b/tests/test_kernels/test_stationary.py
@@ -15,22 +15,40 @@
 from itertools import product
+from dataclasses import is_dataclass
 
+import jax
 import jax.numpy as jnp
 import jax.tree_util as jtu
-import pytest
 import tensorflow_probability.substrates.jax.bijectors as tfb
 import tensorflow_probability.substrates.jax.distributions as tfd
+
+import pytest
+import tensorflow_probability.substrates.jax.distributions as tfd
 from jax.config import config
+from gpjax.linops import LinearOperator
 
 from gpjax.kernels.base import AbstractKernel
-from gpjax.kernels.computations import (ConstantDiagonalKernelComputation,
-                                        DenseKernelComputation)
-from gpjax.kernels.stationary import (RBF, Matern12, Matern32, Matern52,
-                                      Periodic, PoweredExponential,
-                                      RationalQuadratic, White)
+from gpjax.kernels.computations import (
+    ConstantDiagonalKernelComputation,
+    DenseKernelComputation,
+)
+from gpjax.kernels.stationary import (
+    RBF,
+    Matern12,
+    Matern32,
+    Matern52,
+    Periodic,
+    PoweredExponential,
+    RationalQuadratic,
+    White,
+)
+from gpjax.kernels.computations import (
+    DenseKernelComputation,
+    ConstantDiagonalKernelComputation,
+)
+
 from gpjax.kernels.stationary.utils import build_student_t_distribution
-from gpjax.linops import LinearOperator
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
@@ -45,24 +63,28 @@ class BaseTestKernel:
     def pytest_generate_tests(self, metafunc):
         """This is called automatically by pytest"""
+
+        # function for pretty test name
         id_func = lambda x: "-".join([f"{k}={v}" for k, v in x.items()])
-        funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None)
 
+        # get arguments for the test function
+        funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None)
         if funcarglist is None:
             return
         else:
-            argnames = sorted(funcarglist[0])
-            metafunc.parametrize(
-                argnames,
-                [[funcargs[name] for name in argnames] for funcargs in funcarglist],
-                ids=id_func,
-            )
+            # equivalent of pytest.mark.parametrize applied on the metafunction
+            metafunc.parametrize("fields", funcarglist, ids=id_func)
 
     @pytest.mark.parametrize("dim", [None, 1, 3], ids=lambda x: f"dim={x}")
     def test_initialization(self, fields: dict, dim: int) -> None:
+
+        # Check that kernel is a dataclass
+        assert is_dataclass(self.kernel)
+
+        # Input fields as JAX arrays
         fields = {k: jnp.array([v]) for k, v in fields.items()}
 
-        # number of dimensions
+        # Test number of dimensions
         if dim is None:
             kernel: AbstractKernel = self.kernel(**fields)
             assert kernel.ndims == 1
@@ -72,37 +94,53 @@ def test_initialization(self, fields: dict, dim: int) -> None:
             )
             assert kernel.ndims == dim
 
-        # compute engine
+        # Check default compute engine
         assert kernel.compute_engine == self.default_compute_engine
 
-        # properties
+        # Check properties
         for field, value in fields.items():
             assert getattr(kernel, field) == value
 
-        # pytree
+        # Check pytree structure
         leaves = jtu.tree_leaves(kernel)
         assert len(leaves) == len(fields)
 
+        # Test dtype of params
+        for v in leaves:
+            assert v.dtype == jnp.float64
+
         # meta
-        meta_leaves = kernel._pytree__meta
-        assert meta_leaves.keys() == fields.keys()
+        meta = kernel._pytree__meta
+        assert meta.keys() == fields.keys()
 
         for field in fields:
+
+            # Bijectors
            if field in ["variance", "lengthscale", "period", "alpha"]:
-                assert isinstance(meta_leaves[field]["bijector"], tfb.Softplus)
+                assert isinstance(meta[field]["bijector"], tfb.Softplus)
             if field in ["power"]:
-                assert isinstance(meta_leaves[field]["bijector"], tfb.Identity)
-            assert meta_leaves[field]["trainable"] == True
+                assert isinstance(meta[field]["bijector"], tfb.Identity)
+
+            # Trainability state
+            assert meta[field]["trainable"] == True
 
-        # call
+        # Test kernel call
         x = jnp.linspace(0.0, 1.0, 10 * kernel.ndims).reshape(10, kernel.ndims)
-        kernel(x, x)
+        jax.vmap(kernel)(x, x)
 
-    @pytest.mark.parametrize("n", [1, 5], ids=lambda x: f"n={x}")
+    @pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
     @pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
     def test_gram(self, dim: int, n: int) -> None:
+
+        # Initialise kernel
         kernel: AbstractKernel = self.kernel()
+
+        # Gram constructor static method
         kernel.gram
+
+        # Inputs
         x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)
+
+        # Test gram matrix
         Kxx = kernel.gram(x)
         assert isinstance(Kxx, LinearOperator)
         assert Kxx.shape == (n, n)
@@ -112,28 +150,37 @@ def test_gram(self, dim: int, n: int) -> None:
     @pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
     @pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}")
     def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:
+
+        # Initialise kernel
         kernel: AbstractKernel = self.kernel()
+
+        # Inputs
         a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
         b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)
+
+        # Test cross-covariance
         Kab = kernel.cross_covariance(a, b)
         assert isinstance(Kab, jnp.ndarray)
         assert Kab.shape == (n_a, n_b)
 
     def test_spectral_density(self):
+
+        # Initialise kernel
         kernel: AbstractKernel = self.kernel()
 
         if self.kernel not in [RBF, Matern12, Matern32, Matern52]:
+            # Check that spectral_density property is None
             assert not kernel.spectral_density
         else:
+
+            # Check that spectral_density property is correct
             sdensity = kernel.spectral_density
             assert sdensity.name == self.spectral_density_name
             assert sdensity.loc == jnp.array(0.0)
             assert sdensity.scale == jnp.array(1.0)
 
 
-prod = lambda inp: [
-    {"fields": dict(zip(inp.keys(), values))} for values in product(*inp.values())
-]
+prod = lambda inp: [dict(zip(inp.keys(), values)) for values in product(*inp.values())]
 
 
 class TestRBF(BaseTestKernel):
diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py
index 5a4648ce..0abf0f6f 100644
--- a/tests/test_likelihoods.py
+++ b/tests/test_likelihoods.py
@@ -13,125 +13,196 @@
 # limitations under the License.
 # ==============================================================================
 
-from typing import Callable
+from typing import Callable, List
+from itertools import product
+from dataclasses import is_dataclass
 
+import jax.tree_util as jtu
 import jax.numpy as jnp
 import jax.random as jr
-import jax.tree_util as jtu
+import tensorflow_probability.substrates.jax.bijectors as tfb
+import tensorflow_probability.substrates.jax.distributions as tfd
 import numpy as np
 import pytest
-import tensorflow_probability.substrates.jax as tfp
 from jax.config import config
 from jax.random import KeyArray
 from jaxtyping import Array, Float
 
-from gpjax.likelihoods import (AbstractLikelihood, Bernoulli, Gaussian,
-                               inv_probit)
+from gpjax.likelihoods import (
+    AbstractLikelihood,
+    Bernoulli,
+    Gaussian,
+    inv_probit,
+)
 
-tfd = tfp.distributions
 
 # Enable Float64 for more stable matrix inversions.
 config.update("jax_enable_x64", True)
+_initialise_key = jr.PRNGKey(123)
 
 
-def test_abstract_likelihood():
-    # Test that abstract likelihoods cannot be instantiated.
-    with pytest.raises(TypeError):
-        AbstractLikelihood(num_datapoints=123)
+class BaseTestLikelihood:
+    """A base class that contains all tests applied on likelihoods."""
 
-    # Create a dummy likelihood class with abstract methods implemented.
-    class DummyLikelihood(AbstractLikelihood):
-        def predict(self, dist: tfd.Distribution) -> tfd.Distribution:
-            return tfd.Normal(0.0, 1.0)
+    likelihood: AbstractLikelihood
+    static_fields: List[str] = ["num_datapoints"]
 
-        def link_function(self, f: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
-            return tfd.MultivariateNormalDiag(loc=f)
+    def pytest_generate_tests(self, metafunc):
+        """This is called automatically by pytest"""
+
+        # function for pretty test name
+        id_func = lambda x: "-".join([f"{k}={v}" for k, v in x.items()])
+
+        # get arguments for the test function
+        funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None)
+
+        if funcarglist is None:
+            return
+        else:
+            # equivalent of pytest.mark.parametrize applied on the metafunction
+            metafunc.parametrize("fields", funcarglist, ids=id_func)
+
+    @pytest.mark.parametrize("n", [1, 2, 10], ids=lambda x: f"n={x}")
+    def test_initialisation(self, fields: dict, n: int) -> None:
+
+        # Check that likelihood is a dataclass
+        assert is_dataclass(self.likelihood)
+
+        # Input fields as JAX arrays
+        fields = {k: jnp.array([v]) for k, v in fields.items()}
+
+        # Initialise
+        likelihood: AbstractLikelihood = self.likelihood(num_datapoints=n, **fields)
+
+        # Check properties
+        for field, value in fields.items():
+            assert getattr(likelihood, field) == value
 
-    # Test that the dummy likelihood can be instantiated.
-    dummy_likelihood = DummyLikelihood(num_datapoints=123)
-    assert isinstance(dummy_likelihood, AbstractLikelihood)
+        # Test that pytree returns param_field objects (and not static_field)
+        leaves = jtu.tree_leaves(likelihood)
+        assert len(leaves) == len(set(fields) - set(self.static_fields))
+        # Test dtype of params
+        for v in leaves:
+            assert v.dtype == jnp.float64
 
-@pytest.mark.parametrize("n", [1, 10])
-@pytest.mark.parametrize("noise", [0.1, 0.5, 1.0])
-def test_gaussian_init(n: int, noise: float) -> None:
-    likelihood = Gaussian(num_datapoints=n, obs_noise=jnp.array([noise]))
+        # Check meta leaves
+        meta = likelihood._pytree__meta
+        assert not any(f in meta.keys() for f in self.static_fields)
+        assert list(meta.keys()) == sorted(set(fields) - set(self.static_fields))
 
-    assert likelihood.obs_noise == jnp.array([noise])
-    assert likelihood.num_datapoints == n
-    assert jtu.tree_leaves(likelihood) == [jnp.array([noise])]
+        for field in meta:
+            # Bijectors
+            if field in ["obs_noise"]:
+                assert isinstance(meta[field]["bijector"], tfb.Softplus)
 
-@pytest.mark.parametrize("n", [1, 10])
-def test_beroulli_init(n: int) -> None:
-    likelihood = Bernoulli(num_datapoints=n)
-    assert likelihood.num_datapoints == n
-    assert jtu.tree_leaves(likelihood) == []
+            # Trainability state
+            assert meta[field]["trainable"] == True
+    @pytest.mark.parametrize("n", [1, 2, 10], ids=lambda x: f"n={x}")
+    def test_link_functions(self, n: int):
 
-@pytest.mark.parametrize("lik", [Gaussian, Bernoulli])
-@pytest.mark.parametrize("n", [1, 10])
-def test_link_fns(lik: AbstractLikelihood, n: int) -> None:
-    # Create function values.
-    f = jnp.linspace(-3.0, 3.0).reshape(-1, 1)
+        # Initialize likelihood with defaults
+        likelihood: AbstractLikelihood = self.likelihood(num_datapoints=n)
 
-    # Initialise likelihood.
-    likelihood = lik(num_datapoints=n)
+        # Create input values
+        x = jnp.linspace(-3.0, 3.0).reshape(-1, 1)
 
-    # Test likelihood link function.
-    assert isinstance(likelihood.link_function, Callable)
-    assert isinstance(likelihood.link_function(f), tfd.Distribution)
+        # Test likelihood link function.
+        assert isinstance(likelihood.link_function, Callable)
+        assert isinstance(likelihood.link_function(x), tfd.Distribution)
+    @pytest.mark.parametrize("n", [1, 2, 10], ids=lambda x: f"n={x}")
+    def test_call(self, fields: dict, n: int):
 
-@pytest.mark.parametrize("noise", [0.1, 0.5, 1.0])
-@pytest.mark.parametrize("n", [1, 2, 10])
-def test_call_gaussian(noise: float, n: int) -> None:
-    key = jr.PRNGKey(123)
+        # Input fields as JAX arrays
+        fields = {k: jnp.array([v]) for k, v in fields.items()}
 
-    # Initialise likelihood and parameters.
-    likelihood = Gaussian(num_datapoints=n, obs_noise=jnp.array([noise]))
+        # Initialise
+        likelihood: AbstractLikelihood = self.likelihood(num_datapoints=n, **fields)
 
-    # Construct latent function distribution.
-    latent_mean = jr.uniform(key, shape=(n,))
-    latent_sqrt = jr.uniform(key, shape=(n, n))
-    latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T)
-    latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov)
+        # Construct latent function distribution.
+        k1, k2 = jr.split(_initialise_key)
+        latent_mean = jr.uniform(k1, shape=(n,))
+        latent_sqrt = jr.uniform(k2, shape=(n, n))
+        latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T)
+        latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov)
 
-    # Test call method.
-    pred_dist = likelihood(latent_dist)
+        # Perform checks specific to the given likelihood
+        self._test_call_check(likelihood, latent_mean, latent_cov, latent_dist)
 
-    # Check that the distribution is a MultivariateNormalFullCovariance.
-    assert isinstance(pred_dist, tfd.MultivariateNormalFullCovariance)
+    @staticmethod
+    def _test_call_check(likelihood, latent_mean, latent_cov, latent_dist):
+        """Specific to each likelihood."""
+        raise NotImplementedError
 
-    # Check predictive mean and variance.
-    assert (pred_dist.mean() == latent_mean).all()
-    noise_matrix = jnp.eye(n) * noise
-    assert np.allclose(
-        pred_dist.scale_tril, jnp.linalg.cholesky(latent_cov + noise_matrix)
-    )
+prod = lambda inp: [dict(zip(inp.keys(), values)) for values in product(*inp.values())]
 
-@pytest.mark.parametrize("n", [1, 2, 10])
-def test_call_bernoulli(n: int) -> None:
-    key = jr.PRNGKey(123)
+class TestGaussian(BaseTestLikelihood):
+    likelihood = Gaussian
+    fields = prod({"obs_noise": [0.1, 0.5, 1.0]})
+    params = {"test_initialisation": fields, "test_call": fields}
+    static_fields = ["num_datapoints"]
 
-    # Initialise likelihood and parameters.
-    likelihood = Bernoulli(num_datapoints=n)
+    @staticmethod
+    def _test_call_check(likelihood: Gaussian, latent_mean, latent_cov, latent_dist):
 
-    # Construct latent function distribution.
-    latent_mean = jr.uniform(key, shape=(n,))
-    latent_sqrt = jr.uniform(key, shape=(n, n))
-    latent_cov = jnp.matmul(latent_sqrt, latent_sqrt.T)
-    latent_dist = tfd.MultivariateNormalFullCovariance(latent_mean, latent_cov)
+        # Test call method.
+        pred_dist = likelihood(latent_dist)
 
-    # Test call method.
-    pred_dist = likelihood(latent_dist)
+        # Check that the distribution is a MultivariateNormalFullCovariance.
+        assert isinstance(pred_dist, tfd.MultivariateNormalFullCovariance)
 
-    # Check that the distribution is a Bernoulli.
-    assert isinstance(pred_dist, tfd.Bernoulli)
+        # Check predictive mean and variance.
+        assert (pred_dist.mean() == latent_mean).all()
+        noise_matrix = jnp.eye(likelihood.num_datapoints) * likelihood.obs_noise
+        assert np.allclose(
+            pred_dist.scale_tril, jnp.linalg.cholesky(latent_cov + noise_matrix)
+        )
 
-    # Check predictive mean and variance.
-    p = inv_probit(latent_mean / jnp.sqrt(1.0 + jnp.diagonal(latent_cov)))
-    assert (pred_dist.mean() == p).all()
-    assert (pred_dist.variance() == p * (1.0 - p)).all()
+class TestBernoulli(BaseTestLikelihood):
+    likelihood = Bernoulli
+    fields = prod({})
+    params = {"test_initialisation": fields, "test_call": fields}
+    static_fields = ["num_datapoints"]
+
+    @staticmethod
+    def _test_call_check(
+        likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist
+    ):
+
+        # Test call method.
+        pred_dist = likelihood(latent_dist)
+
+        # Check that the distribution is a Bernoulli.
+        assert isinstance(pred_dist, tfd.Bernoulli)
+
+        # Check predictive mean and variance.
+
+        p = inv_probit(latent_mean / jnp.sqrt(1.0 + jnp.diagonal(latent_cov)))
+        assert (pred_dist.mean() == p).all()
+        assert (pred_dist.variance() == p * (1.0 - p)).all()
+
+
+class TestAbstract(BaseTestLikelihood):
+    class DummyLikelihood(AbstractLikelihood):
+        def predict(self, dist: tfd.Distribution) -> tfd.Distribution:
+            return tfd.Normal(0.0, 1.0)
+
+        def link_function(self, f: Float[Array, "N 1"]) -> Float[Array, "N 1"]:
+            return tfd.MultivariateNormalDiag(loc=f)
+
+    likelihood = DummyLikelihood
+    fields = prod({})
+    params = {"test_initialisation": fields, "test_call": fields}
+    static_fields = ["num_datapoints"]
+
+    @staticmethod
+    def _test_call_check(
+        likelihood: AbstractLikelihood, latent_mean, latent_cov, latent_dist
+    ):
+        pred_dist = likelihood(latent_dist)
+        assert isinstance(pred_dist, tfd.Normal)