Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for kernels and likelihoods #225

Merged
merged 5 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 135 additions & 133 deletions tests/test_kernels/test_nonstationary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@
# limitations under the License.
# ==============================================================================

from itertools import permutations
from itertools import permutations, product
from dataclasses import is_dataclass

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import pytest
import tensorflow_probability.substrates.jax.bijectors as tfb
from jax.config import config
from typing import List

from gpjax.kernels.base import AbstractKernel
from gpjax.kernels.computations import DenseKernelComputation
from gpjax.kernels.nonstationary import Linear, Polynomial
from gpjax.linops import LinearOperator, identity

Expand All @@ -31,135 +36,132 @@
_jitter = 1e-6


@pytest.mark.parametrize(
"kernel",
[
Linear(),
Polynomial(),
],
)
@pytest.mark.parametrize("dim", [1, 2, 5])
@pytest.mark.parametrize("n", [1, 2, 10])
def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None:
# Gram constructor static method:
kernel.gram

# Inputs x:
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test gram matrix:
Kxx = kernel.gram(x)
assert isinstance(Kxx, LinearOperator)
assert Kxx.shape == (n, n)


@pytest.mark.parametrize(
"kernel",
[
Linear(),
Polynomial(),
],
)
@pytest.mark.parametrize("num_a", [1, 2, 5])
@pytest.mark.parametrize("num_b", [1, 2, 5])
@pytest.mark.parametrize("dim", [1, 2, 5])
def test_cross_covariance(
kernel: AbstractKernel, num_a: int, num_b: int, dim: int
) -> None:
# Inputs a, b:
a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim)
b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim)

# Test cross covariance, Kab:
Kab = kernel.cross_covariance(a, b)
assert isinstance(Kab, jnp.ndarray)
assert Kab.shape == (num_a, num_b)


@pytest.mark.parametrize("kern", [Linear, Polynomial])
@pytest.mark.parametrize("dim", [1, 2, 5])
@pytest.mark.parametrize("shift", [0.0, 0.5, 2.0])
@pytest.mark.parametrize("sigma", [0.1, 0.2, 0.5])
@pytest.mark.parametrize("n", [1, 2, 5])
def test_pos_def(
kern: AbstractKernel, dim: int, shift: float, sigma: float, n: int
) -> None:
kern = kern(active_dims=list(range(dim)))
# Gram constructor static method:
kern.gram

# Create inputs x:
x = jr.uniform(_initialise_key, (n, dim))

if isinstance(kern, Polynomial):
kern = kern.replace(shift=shift, variance=sigma)
else:
kern = kern.replace(variance=sigma)

# Test gram matrix eigenvalues are positive:
Kxx = kern.gram(x)
Kxx += identity(n) * _jitter
eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
assert (eigen_values > 0.0).all()


@pytest.mark.parametrize("degree", [1, 2, 3])
@pytest.mark.parametrize("dim", [1, 2, 5])
@pytest.mark.parametrize("variance", [0.1, 1.0, 2.0])
@pytest.mark.parametrize("shift", [1e-6, 0.1, 1.0])
@pytest.mark.parametrize("n", [1, 2, 5])
def test_polynomial(
degree: int, dim: int, variance: float, shift: float, n: int
) -> None:
# Define inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Define kernel
kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)])

# # Check name
# assert kern.name == f"Polynomial Degree: {degree}"

# Initialise parameters
kern = kern.replace(shift=kern.shift * shift, variance=kern.variance * variance)

# Compute gram matrix
Kxx = kern.gram(x)

# Check shapes
assert Kxx.shape[0] == x.shape[0]
assert Kxx.shape[0] == Kxx.shape[1]

# Test positive definiteness
Kxx += identity(n) * _jitter
eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense())
assert (eigen_values > 0).all()


@pytest.mark.parametrize(
"kernel",
[Linear, Polynomial],
)
def test_active_dim(kernel: AbstractKernel) -> None:
dim_list = [0, 1, 2, 3]
perm_length = 2
dim_pairs = list(permutations(dim_list, r=perm_length))
n_dims = len(dim_list)

# Generate random inputs
x = jr.normal(_initialise_key, shape=(20, n_dims))

for dp in dim_pairs:
# Take slice of x
slice = x[..., dp]

# Define kernels
ad_kern = kernel(active_dims=dp)
manual_kern = kernel(active_dims=[i for i in range(perm_length)])

# Compute gram matrices
ad_Kxx = ad_kern.gram(x)
manual_Kxx = manual_kern.gram(slice)

# Test gram matrices are equal
assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense())
class BaseTestKernel:
"""A base class that contains all tests applied on non-stationary kernels."""

kernel: AbstractKernel
default_compute_engine: type
static_fields: List[str]

def pytest_generate_tests(self, metafunc):
"""This is called automatically by pytest"""

# function for pretty test name
id_func = lambda x: "-".join([f"{k}={v}" for k, v in x.items()])

# get arguments for the test function
funcarglist = metafunc.cls.params.get(metafunc.function.__name__, None)
if funcarglist is None:
return
else:
# equivalent of pytest.mark.parametrize applied on the metafunction
metafunc.parametrize("fields", funcarglist, ids=id_func)

@pytest.mark.parametrize("dim", [None, 1, 3], ids=lambda x: f"dim={x}")
def test_initialization(self, fields: dict, dim: int) -> None:

# Check that kernel is a dataclass
assert is_dataclass(self.kernel)

# Input fields as JAX arrays
fields = {k: jnp.array([v]) for k, v in fields.items()}

# Test number of dimensions
if dim is None:
kernel: AbstractKernel = self.kernel(**fields)
assert kernel.ndims == 1
else:
kernel: AbstractKernel = self.kernel(
active_dims=[i for i in range(dim)], **fields
)
assert kernel.ndims == dim

# Check default compute engine
assert kernel.compute_engine == self.default_compute_engine

# Check properties
for field, value in fields.items():
assert getattr(kernel, field) == value

# Test that pytree returns param_field objects (and not static_field)
leaves = jtu.tree_leaves(kernel)
assert len(leaves) == len(set(fields) - set(self.static_fields))

# Test dtype of params
for v in leaves:
assert v.dtype == jnp.float64

# Check meta leaves
meta = kernel._pytree__meta
assert not any(f in meta.keys() for f in self.static_fields)
assert list(meta.keys()) == sorted(set(fields) - set(self.static_fields))

for field in meta:

# Bijectors
if field in ["variance", "shift"]:
assert isinstance(meta[field]["bijector"], tfb.Softplus)

# Trainability state
assert meta[field]["trainable"] == True

# Test kernel call
x = jnp.linspace(0.0, 1.0, 10 * kernel.ndims).reshape(10, kernel.ndims)
jax.vmap(kernel)(x, x)

@pytest.mark.parametrize("n", [1, 2, 5], ids=lambda x: f"n={x}")
@pytest.mark.parametrize("dim", [1, 3], ids=lambda x: f"dim={x}")
def test_gram(self, dim: int, n: int) -> None:

# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Gram constructor static method
kernel.gram

# Inputs
x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim)

# Test gram matrix
Kxx = kernel.gram(x)
assert isinstance(Kxx, LinearOperator)
assert Kxx.shape == (n, n)
assert jnp.all(jnp.linalg.eigvalsh(Kxx.to_dense() + jnp.eye(n) * 1e-6) > 0.0)

@pytest.mark.parametrize("n_a", [1, 2, 5], ids=lambda x: f"n_a={x}")
@pytest.mark.parametrize("n_b", [1, 2, 5], ids=lambda x: f"n_b={x}")
@pytest.mark.parametrize("dim", [1, 2, 5], ids=lambda x: f"dim={x}")
def test_cross_covariance(self, n_a: int, n_b: int, dim: int) -> None:

# Initialise kernel
kernel: AbstractKernel = self.kernel()

# Inputs
a = jnp.linspace(-1.0, 1.0, n_a * dim).reshape(n_a, dim)
b = jnp.linspace(3.0, 4.0, n_b * dim).reshape(n_b, dim)

# Test cross-covariance
Kab = kernel.cross_covariance(a, b)
assert isinstance(Kab, jnp.ndarray)
assert Kab.shape == (n_a, n_b)


prod = lambda inp: [dict(zip(inp.keys(), values)) for values in product(*inp.values())]


class TestLinear(BaseTestKernel):
kernel = Linear
fields = prod({"variance": [0.1, 1.0, 2.0]})
params = {"test_initialization": fields}
static_fields = []
default_compute_engine = DenseKernelComputation


class TestPolynomial(BaseTestKernel):
kernel = Polynomial
fields = prod(
{"variance": [0.1, 1.0, 2.0], "degree": [1, 2, 3], "shift": [1e-6, 0.1, 1.0]}
)
static_fields = ["degree"]
params = {"test_initialization": fields}
default_compute_engine = DenseKernelComputation
Loading