From 4fdb356f2932beccfa81e327fbf1acb6cf932650 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 11:51:05 +0000 Subject: [PATCH 01/15] Remove types.py and test_types.py --- gpjax/types.py | 79 --------------------------------------------- tests/test_types.py | 71 ---------------------------------------- 2 files changed, 150 deletions(-) delete mode 100644 gpjax/types.py delete mode 100644 tests/test_types.py diff --git a/gpjax/types.py b/gpjax/types.py deleted file mode 100644 index b3fedb97..00000000 --- a/gpjax/types.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -from chex import dataclass -from jaxtyping import Array, Float -import deprecation - -NoneType = type(None) - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxUtils for a Dataset object", -) -@dataclass -class Dataset: - """GPJax Dataset class.""" - - X: Float[Array, "N D"] - y: Float[Array, "N 1"] = None - - def __repr__(self) -> str: - return ( - f"- Number of datapoints: {self.X.shape[0]}\n- Dimension:" - f" {self.X.shape[1]}" - ) - - def __add__(self, other: "Dataset") -> "Dataset": - """Combines two datasets into one. The right-hand dataset is stacked beneath left.""" - x = jnp.concatenate((self.X, other.X)) - y = jnp.concatenate((self.y, other.y)) - - return Dataset(X=x, y=y) - - @property - def n(self) -> int: - """The number of observations in the dataset.""" - return self.X.shape[0] - - @property - def in_dim(self) -> int: - """The dimension of the input data.""" - return self.X.shape[1] - - @property - def out_dim(self) -> int: - """The dimension of the output data.""" - return self.y.shape[1] - - -def verify_dataset(ds: Dataset) -> None: - """Apply a series of checks to the dataset to ensure that downstream operations are safe.""" - assert ds.X.ndim == 2, ( - "2-dimensional training inputs are required. Current dimension:" - f" {ds.X.ndim}." - ) - if ds.y is not None: - assert ds.y.ndim == 2, ( - "2-dimensional training outputs are required. Current dimension:" - f" {ds.y.ndim}." - ) - assert ds.X.shape[0] == ds.y.shape[0], ( - "Number of inputs must equal the number of outputs. \nCurrent" - f" counts:\n- X: {ds.X.shape[0]}\n- y: {ds.y.shape[0]}" - ) diff --git a/tests/test_types.py b/tests/test_types.py deleted file mode 100644 index d4b11a94..00000000 --- a/tests/test_types.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -import pytest -from jax.config import config - -from gpjax.types import Dataset, NoneType, verify_dataset - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) - - -def test_nonetype(): - assert isinstance(None, NoneType) - - -@pytest.mark.parametrize("n", [1, 10]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -@pytest.mark.parametrize("n2", [1, 10]) -def test_dataset(n, outd, ind, n2): - x = jnp.ones((n, ind)) - y = jnp.ones((n, outd)) - d = Dataset(X=x, y=y) - verify_dataset(d) - assert d.n == n - assert d.in_dim == ind - assert d.out_dim == outd - - # test combine datasets - x2 = 2 * jnp.ones((n2, ind)) - y2 = 2 * jnp.ones((n2, outd)) - d2 = Dataset(X=x2, y=y2) - - d_combined = d + d2 - assert d_combined.n == n + n2 - assert d_combined.in_dim == ind - assert d_combined.out_dim == outd - assert (d_combined.y[:n] == 1.0).all() - assert (d_combined.y[n:] == 2.0).all() - assert (d_combined.X[:n] == 1.0).all() - assert (d_combined.X[n:] == 2.0).all() - - -@pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) -def test_dataset_assertions(nx, ny): - x = jnp.ones((nx, 1)) - y = jnp.ones((ny, 1)) - with pytest.raises(AssertionError): - ds = Dataset(X=x, y=y) - verify_dataset(ds) - - -def test_y_none(): - x = jnp.ones((10, 1)) - d = Dataset(X=x) - verify_dataset(d) - assert d.y is None From 1054df0e423ff16d099c642e418c095d7c4f1696 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 12:16:18 +0000 Subject: [PATCH 02/15] Revert "Remove types.py and test_types.py" This reverts commit 4fdb356f2932beccfa81e327fbf1acb6cf932650. --- gpjax/types.py | 79 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_types.py | 71 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 150 insertions(+) create mode 100644 gpjax/types.py create mode 100644 tests/test_types.py diff --git a/gpjax/types.py b/gpjax/types.py new file mode 100644 index 00000000..b3fedb97 --- /dev/null +++ b/gpjax/types.py @@ -0,0 +1,79 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import jax.numpy as jnp +from chex import dataclass +from jaxtyping import Array, Float +import deprecation + +NoneType = type(None) + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxUtils for a Dataset object", +) +@dataclass +class Dataset: + """GPJax Dataset class.""" + + X: Float[Array, "N D"] + y: Float[Array, "N 1"] = None + + def __repr__(self) -> str: + return ( + f"- Number of datapoints: {self.X.shape[0]}\n- Dimension:" + f" {self.X.shape[1]}" + ) + + def __add__(self, other: "Dataset") -> "Dataset": + """Combines two datasets into one. The right-hand dataset is stacked beneath left.""" + x = jnp.concatenate((self.X, other.X)) + y = jnp.concatenate((self.y, other.y)) + + return Dataset(X=x, y=y) + + @property + def n(self) -> int: + """The number of observations in the dataset.""" + return self.X.shape[0] + + @property + def in_dim(self) -> int: + """The dimension of the input data.""" + return self.X.shape[1] + + @property + def out_dim(self) -> int: + """The dimension of the output data.""" + return self.y.shape[1] + + +def verify_dataset(ds: Dataset) -> None: + """Apply a series of checks to the dataset to ensure that downstream operations are safe.""" + assert ds.X.ndim == 2, ( + "2-dimensional training inputs are required. Current dimension:" + f" {ds.X.ndim}." + ) + if ds.y is not None: + assert ds.y.ndim == 2, ( + "2-dimensional training outputs are required. Current dimension:" + f" {ds.y.ndim}." + ) + assert ds.X.shape[0] == ds.y.shape[0], ( + "Number of inputs must equal the number of outputs. \nCurrent" + f" counts:\n- X: {ds.X.shape[0]}\n- y: {ds.y.shape[0]}" + ) diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 00000000..d4b11a94 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,71 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import jax.numpy as jnp +import pytest +from jax.config import config + +from gpjax.types import Dataset, NoneType, verify_dataset + +# Enable Float64 for more stable matrix inversions. 
+config.update("jax_enable_x64", True) + + +def test_nonetype(): + assert isinstance(None, NoneType) + + +@pytest.mark.parametrize("n", [1, 10]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +@pytest.mark.parametrize("n2", [1, 10]) +def test_dataset(n, outd, ind, n2): + x = jnp.ones((n, ind)) + y = jnp.ones((n, outd)) + d = Dataset(X=x, y=y) + verify_dataset(d) + assert d.n == n + assert d.in_dim == ind + assert d.out_dim == outd + + # test combine datasets + x2 = 2 * jnp.ones((n2, ind)) + y2 = 2 * jnp.ones((n2, outd)) + d2 = Dataset(X=x2, y=y2) + + d_combined = d + d2 + assert d_combined.n == n + n2 + assert d_combined.in_dim == ind + assert d_combined.out_dim == outd + assert (d_combined.y[:n] == 1.0).all() + assert (d_combined.y[n:] == 2.0).all() + assert (d_combined.X[:n] == 1.0).all() + assert (d_combined.X[n:] == 2.0).all() + + +@pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) +def test_dataset_assertions(nx, ny): + x = jnp.ones((nx, 1)) + y = jnp.ones((ny, 1)) + with pytest.raises(AssertionError): + ds = Dataset(X=x, y=y) + verify_dataset(ds) + + +def test_y_none(): + x = jnp.ones((10, 1)) + d = Dataset(X=x) + verify_dataset(d) + assert d.y is None From b54f4865b24b7b08dae3a8332c7118a3b758a60f Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 14:19:45 +0000 Subject: [PATCH 03/15] Remove chex from codebase. * Remove chex. * Remove NoneType * Ensure everything is depreciated in kernels.py, and types.py * Removed tests for kernels and types. * Replace `n_iters` with `num_iters` in training abstractions. This is more concistent e.g., `num_datapoints` in likelihoods. --- README.md | 2 +- examples/barycentres.pct.py | 2 +- examples/classification.pct.py | 2 +- examples/collapsed_vi.pct.py | 2 +- examples/graph_kernels.pct.py | 2 +- examples/haiku.pct.py | 2 +- examples/kernels.pct.py | 2 +- examples/natgrads.pct.py | 2 +- examples/regression.pct.py | 2 +- examples/uncollapsed_vi.pct.py | 4 +- examples/yacht.pct.py | 2 +- gpjax/__init__.py | 2 +- gpjax/abstractions.py | 101 ++- gpjax/gps.py | 137 ++-- gpjax/kernels.py | 1123 +-------------------------- gpjax/likelihoods.py | 42 +- gpjax/mean_functions.py | 57 +- gpjax/parameters.py | 17 +- gpjax/test_variational_inference.py | 159 ++++ gpjax/types.py | 65 +- gpjax/variational_families.py | 127 ++- gpjax/variational_inference.py | 53 +- tests/test_abstractions.py | 26 +- tests/test_kernels.py | 599 -------------- tests/test_types.py | 71 -- tests/test_variational_inference.py | 2 - 26 files changed, 571 insertions(+), 2034 deletions(-) create mode 100644 gpjax/test_variational_inference.py delete mode 100644 tests/test_kernels.py delete mode 100644 tests/test_types.py diff --git a/README.md b/README.md index 0e1f60ec..ac5def1c 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ parameter_state = gpx.initialise(posterior, key=key) Finally, we run an optimisation loop using the Adam optimiser via the `fit` callable. ```python -inference_state = gpx.fit(mll, parameter_state, opt, n_iters=500) +inference_state = gpx.fit(mll, parameter_state, opt, num_iters=500) ``` ## 3. 
Making predictions diff --git a/examples/barycentres.pct.py b/examples/barycentres.pct.py index c4f3dbde..035efb70 100644 --- a/examples/barycentres.pct.py +++ b/examples/barycentres.pct.py @@ -115,7 +115,7 @@ def fit_gp(x: jax.Array, y: jax.Array) -> dx.MultivariateNormalTri: objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/classification.pct.py b/examples/classification.pct.py index c0996daa..dddb32d2 100644 --- a/examples/classification.pct.py +++ b/examples/classification.pct.py @@ -91,7 +91,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) map_estimate, training_history = inference_state.unpack() diff --git a/examples/collapsed_vi.pct.py b/examples/collapsed_vi.pct.py index 244c4cdf..725ade7a 100644 --- a/examples/collapsed_vi.pct.py +++ b/examples/collapsed_vi.pct.py @@ -109,7 +109,7 @@ objective=negative_elbo, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=2000, + num_iters=2000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/graph_kernels.pct.py b/examples/graph_kernels.pct.py index dc7b1a2d..6cb52c65 100644 --- a/examples/graph_kernels.pct.py +++ b/examples/graph_kernels.pct.py @@ -137,7 +137,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/haiku.pct.py b/examples/haiku.pct.py index 166ef065..e98978f5 100644 --- a/examples/haiku.pct.py +++ b/examples/haiku.pct.py @@ -185,7 +185,7 @@ def forward(x): objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=2500, + num_iters=2500, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index a2f4d748..5f9b73c6 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -312,7 +312,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, ) learned_params, training_history = inference_state.unpack() diff --git a/examples/natgrads.pct.py b/examples/natgrads.pct.py index dae45793..c72bc565 100644 --- a/examples/natgrads.pct.py +++ b/examples/natgrads.pct.py @@ -95,7 +95,7 @@ natural_svgp, parameter_state=parameter_state, train_data=D, - n_iters=5000, + num_iters=5000, batch_size=256, key=jr.PRNGKey(42), moment_optim=ox.sgd(0.01), diff --git a/examples/regression.pct.py b/examples/regression.pct.py index f440e137..708591a9 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -185,7 +185,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=500, + num_iters=500, ) # %% [markdown] diff --git a/examples/uncollapsed_vi.pct.py b/examples/uncollapsed_vi.pct.py index 930a50a7..9e03c650 100644 --- a/examples/uncollapsed_vi.pct.py +++ b/examples/uncollapsed_vi.pct.py @@ -171,7 +171,7 @@ parameter_state=parameter_state, train_data=D, optax_optim=optimiser, - n_iters=3000, + num_iters=3000, key=jr.PRNGKey(42), batch_size=128, ) @@ -225,7 +225,7 @@ parameter_state=parameter_state, train_data=D, optax_optim=optimiser, - n_iters=3000, + num_iters=3000, key=jr.PRNGKey(42), batch_size=128, ) diff --git a/examples/yacht.pct.py 
b/examples/yacht.pct.py index a2fa97a3..d557bfae 100644 --- a/examples/yacht.pct.py +++ b/examples/yacht.pct.py @@ -136,7 +136,7 @@ objective=negative_mll, parameter_state=parameter_state, optax_optim=optimiser, - n_iters=1000, + num_iters=1000, log_rate=50, ) diff --git a/gpjax/__init__.py b/gpjax/__init__.py index 331da223..d51754f5 100644 --- a/gpjax/__init__.py +++ b/gpjax/__init__.py @@ -28,7 +28,6 @@ from .likelihoods import Bernoulli, Gaussian from .mean_functions import Constant, Zero from .parameters import constrain, copy_dict_structure, initialise, unconstrain -from .types import Dataset from .variational_families import ( CollapsedVariationalGaussian, ExpectationVariationalGaussian, @@ -36,6 +35,7 @@ VariationalGaussian, WhitenedVariationalGaussian, ) +from .types import Dataset from .variational_inference import CollapsedVI, StochasticVI from . import _version diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index d9e45f77..a99be71f 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -19,7 +19,8 @@ import jax.numpy as jnp import jax.random as jr import optax as ox -from chex import dataclass, PRNGKey as PRNGKeyType + +from jax.random import KeyArray from jax import lax from jax.experimental import host_callback from jaxtyping import Array, Float @@ -27,22 +28,40 @@ from .natural_gradients import natural_gradients from .parameters import ParameterState, constrain, trainable_params, unconstrain -from jaxutils import Dataset +from jaxutils import Dataset, PyTree from .variational_inference import StochasticVI -@dataclass(frozen=True) -class InferenceState: - """Imutable dataclass for storing optimised parameters and training history.""" +class InferenceState(PyTree): + """Imutable class for storing optimised parameters and training history.""" + + def __init__(self, params: Dict, history: Float[Array, "num_iters"]): + self._params = params + self._history = history + + @property + def params(self) -> Dict: + """Parameters. + + Returns: + Dict: Parameters. + """ + return self._params + + @property + def history(self) -> Float[Array, "num_iters"]: + """Training history. - params: Dict - history: Float[Array, "n_iters"] + Returns: + Float[Array, "num_iters"]: Training history. + """ + return self._history - def unpack(self) -> Tuple[Dict, Float[Array, "n_iters"]]: + def unpack(self) -> Tuple[Dict, Float[Array, "num_iters"]]: """Unpack parameters and training history into a tuple. Returns: - Tuple[Dict, Float[Array, "n_iters"]]: Tuple of parameters and training history. + Tuple[Dict, Float[Array, "num_iters"]]: Tuple of parameters and training history. """ return self.params, self.history @@ -51,7 +70,7 @@ def fit( objective: Callable, parameter_state: ParameterState, optax_optim: ox.GradientTransformation, - n_iters: Optional[int] = 100, + num_iters: Optional[int] = 100, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, ) -> InferenceState: @@ -62,9 +81,9 @@ def fit( objective (Callable): The objective function that we are optimising with respect to. parameter_state (ParameterState): The initial parameter state. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. - n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. - log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. - verbose (bool, optional): Whether to print the training loading bar. Defaults to True. 
+ num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. @@ -85,7 +104,7 @@ def loss(params: Dict) -> Float[Array, "1"]: opt_state = optax_optim.init(params) # Iteration loop numbers to scan over - iter_nums = jnp.arange(n_iters) + iter_nums = jnp.arange(num_iters) # Optimisation step def step(carry, iter_num: int): @@ -98,7 +117,7 @@ def step(carry, iter_num: int): # Display progress bar if verbose is True if verbose: - step = progress_bar_scan(n_iters, log_rate)(step) + step = progress_bar_scan(num_iters, log_rate)(step) # Run the optimisation loop (params, _), history = jax.lax.scan(step, (params, opt_state), iter_nums) @@ -114,9 +133,9 @@ def fit_batches( parameter_state: ParameterState, train_data: Dataset, optax_optim: ox.GradientTransformation, - key: PRNGKeyType, + key: KeyArray, batch_size: int, - n_iters: Optional[int] = 100, + num_iters: Optional[int] = 100, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, ) -> InferenceState: @@ -129,11 +148,11 @@ def fit_batches( parameter_state (ParameterState): The parameters for which we would like to minimise our objective function with. train_data (Dataset): The training dataset. optax_optim (GradientTransformation): The Optax optimiser that is to be used for learning a parameter set. - key (PRNGKeyType): The PRNG key for the mini-batch sampling. - batch_size(int): The batch_size. - n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. - log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. - verbose (bool, optional): Whether to print the training loading bar. Defaults to True. + key (KeyArray): The PRNG key for the mini-batch sampling. + batch_size (int): The batch_size. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: InferenceState: An InferenceState object comprising the optimised parameters and training history respectively. @@ -154,8 +173,8 @@ def loss(params: Dict, batch: Dataset) -> Float[Array, "1"]: opt_state = optax_optim.init(params) # Mini-batch random keys and iteration loop numbers to scan over - keys = jr.split(key, n_iters) - iter_nums = jnp.arange(n_iters) + keys = jr.split(key, num_iters) + iter_nums = jnp.arange(num_iters) # Optimisation step def step(carry, iter_num__and__key): @@ -173,7 +192,7 @@ def step(carry, iter_num__and__key): # Display progress bar if verbose is True if verbose: - step = progress_bar_scan(n_iters, log_rate)(step) + step = progress_bar_scan(num_iters, log_rate)(step) # Run the optimisation loop (params, _), history = jax.lax.scan(step, (params, opt_state), (iter_nums, keys)) @@ -184,7 +203,7 @@ def step(carry, iter_num__and__key): return InferenceState(params=params, history=history) -def get_batch(train_data: Dataset, batch_size: int, key: PRNGKeyType) -> Dataset: +def get_batch(train_data: Dataset, batch_size: int, key: KeyArray) -> Dataset: """Batch the data into mini-batches. Sampling is done with replacement. 
Args: @@ -208,9 +227,9 @@ def fit_natgrads( train_data: Dataset, moment_optim: ox.GradientTransformation, hyper_optim: ox.GradientTransformation, - key: PRNGKeyType, + key: KeyArray, batch_size: int, - n_iters: Optional[int] = 100, + num_iters: Optional[int] = 100, log_rate: Optional[int] = 10, verbose: Optional[bool] = True, ) -> Dict: @@ -226,11 +245,11 @@ def fit_natgrads( train_data (Dataset): The training dataset. moment_optim (GradientTransformation): The Optax optimiser for the natural gradient updates on the moments. hyper_optim (GradientTransformation): The Optax optimiser for gradient updates on the hyperparameters. - key (PRNGKeyType): The PRNG key for the mini-batch sampling. + key (KeyArray): The PRNG key for the mini-batch sampling. batch_size(int): The batch_size. - n_iters (int, optional): The number of optimisation steps to run. Defaults to 100. - log_rate (int, optional): How frequently the objective function's value should be printed. Defaults to 10. - verbose (bool, optional): Whether to print the training loading bar. Defaults to True. + num_iters (Optional[int]): The number of optimisation steps to run. Defaults to 100. + log_rate (Optional[int]): How frequently the objective function's value should be printed. Defaults to 10. + verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: InferenceState: A dataclass comprising optimised parameters and training history. @@ -251,8 +270,8 @@ def fit_natgrads( ) # Mini-batch random keys and iteration loop numbers to scan over - keys = jax.random.split(key, n_iters) - iter_nums = jnp.arange(n_iters) + keys = jax.random.split(key, num_iters) + iter_nums = jnp.arange(num_iters) # Optimisation step def step(carry, iter_num__and__key): @@ -276,7 +295,7 @@ def step(carry, iter_num__and__key): # Display progress bar if verbose is True if verbose: - step = progress_bar_scan(n_iters, log_rate)(step) + step = progress_bar_scan(num_iters, log_rate)(step) # Run the optimisation loop (params, _, _), history = jax.lax.scan( @@ -289,15 +308,15 @@ def step(carry, iter_num__and__key): return InferenceState(params=params, history=history) -def progress_bar_scan(n_iters: int, log_rate: int) -> Callable: +def progress_bar_scan(num_iters: int, log_rate: int) -> Callable: """Progress bar for Jax.lax scans (adapted from https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/).""" tqdm_bars = {} - remainder = n_iters % log_rate + remainder = num_iters % log_rate def _define_tqdm(args: Any, transform: Any) -> None: """Define a tqdm progress bar.""" - tqdm_bars[0] = tqdm(range(n_iters)) + tqdm_bars[0] = tqdm(range(num_iters)) def _update_tqdm(args: Any, transform: Any) -> None: """Update the tqdm progress bar with the latest objective value.""" @@ -329,10 +348,10 @@ def _update_progress_bar(loss_val: Float[Array, "1"], iter_num: int) -> None: # Conditions for iteration number is_first: bool = iter_num == 0 is_multiple: bool = (iter_num % log_rate == 0) & ( - iter_num != n_iters - remainder + iter_num != num_iters - remainder ) - is_remainder: bool = iter_num == n_iters - remainder - is_last: bool = iter_num == n_iters - 1 + is_remainder: bool = iter_num == num_iters - remainder + is_last: bool = iter_num == num_iters - 1 # Define progress bar, if first iteration _callback(is_first, _define_tqdm, None) diff --git a/gpjax/gps.py b/gpjax/gps.py index 82f0d660..a3b0fd7c 100644 --- a/gpjax/gps.py +++ b/gpjax/gps.py @@ -18,23 +18,23 @@ import distrax as dx import jax.numpy as jnp -from chex import dataclass, 
PRNGKey as PRNGKeyType from jaxtyping import Array, Float +from jax.random import KeyArray from jaxlinop import identity from jaxkern.kernels import AbstractKernel +from jaxutils import PyTree from .config import get_global_config from .kernels import AbstractKernel -from .likelihoods import AbstractLikelihood, Conjugate, Gaussian, NonConjugate +from .likelihoods import AbstractLikelihood, Conjugate, NonConjugate from .mean_functions import AbstractMeanFunction, Zero from jaxutils import Dataset from .utils import concat_dictionaries from .gaussian_distribution import GaussianDistribution -@dataclass -class AbstractPrior: +class AbstractPrior(PyTree): """Abstract Gaussian process prior. All Gaussian processes priors should inherit from this class. @@ -79,7 +79,7 @@ def predict(self, *args: Any, **kwargs: Any) -> dx.Distribution: raise NotImplementedError @abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """An initialisation method for the GP's parameters. This method should be implemented for all classes that inherit the ``AbstractPrior`` class. Whilst not always necessary, the method accepts a PRNG key to allow @@ -87,7 +87,7 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: through the ``initialise`` function given in GPJax. Args: - key (PRNGKeyType): The PRNG key. + key (KeyArray): The PRNG key. Returns: Dict: The initialised parameter set. @@ -98,7 +98,8 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ####################### # GP Priors ####################### -@dataclass + + class Prior(AbstractPrior): """A Gaussian process prior object. The GP is parameterised by a `mean `_ @@ -120,17 +121,25 @@ class Prior(AbstractPrior): >>> >>> kernel = gpx.kernels.RBF() >>> prior = gpx.Prior(kernel = kernel) - - Attributes: - kernel (AbstractKernel): The kernel function used to parameterise the prior. - mean_function (MeanFunction): The mean function used to parameterise the - prior. Defaults to zero. - name (str): The name of the GP prior. Defaults to "GP prior". """ - kernel: AbstractKernel - mean_function: Optional[AbstractMeanFunction] = Zero() - name: Optional[str] = "GP prior" + def __init__( + self, + kernel: AbstractKernel, + mean_function: Optional[AbstractMeanFunction] = Zero(), + name: Optional[str] = "GP prior", + ) -> None: + """Initialise the GP prior. + + Args: + kernel (AbstractKernel): The kernel function used to parameterise the prior. + mean_function (Optional[MeanFunction]): The mean function used to parameterise the + prior. Defaults to zero. + name (Optional[str]): The name of the GP prior. Defaults to "GP prior". + """ + self.kernel = kernel + self.mean_function = mean_function + self.name = name def __mul__(self, other: AbstractLikelihood): """The product of a prior and likelihood is proportional to the @@ -230,11 +239,11 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the GP prior's parameter set. Args: - key (PRNGKeyType): The PRNG key. + key (KeyArray): The PRNG key. Returns: Dict: The initialised parameter set. @@ -248,7 +257,6 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ####################### # GP Posteriors ####################### -@dataclass class AbstractPosterior(AbstractPrior): """The base GP posterior object conditioned on an observed dataset. 
All posterior objects should inherit from this class. @@ -257,17 +265,24 @@ class AbstractPosterior(AbstractPrior): `_. Since dataclasses take over ``__init__``, the ``__post_init__`` method can be used to initialise the GP's parameters. - - Attributes: - prior (Prior): The prior distribution of the GP. - likelihood (AbstractLikelihood): The likelihood distribution of the - observed dataset. - name (str): The name of the GP posterior. Defaults to "GP posterior". """ - prior: Prior - likelihood: AbstractLikelihood - name: Optional[str] = "GP posterior" + def __init__( + self, + prior: AbstractPrior, + likelihood: AbstractLikelihood, + name: Optional[str] = "GP posterior", + ) -> None: + """Initialise the GP posterior object. + + Args: + prior (Prior): The prior distribution of the GP. + likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. + name (Optional[str]): The name of the GP posterior. Defaults to "GP posterior". + """ + self.prior = prior + self.likelihood = likelihood + self.name = name @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: @@ -285,11 +300,11 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: """ raise NotImplementedError - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a GP posterior. Args: - key (PRNGKeyType): The PRNG key. + key (KeyArray): The PRNG key. Returns: Dict: The initialised parameter set. @@ -300,7 +315,6 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: ) -@dataclass class ConjugatePosterior(AbstractPosterior): """A Gaussian process posterior distribution when the constituent likelihood function is a Gaussian distribution. In such cases, the latent function values @@ -332,16 +346,24 @@ class ConjugatePosterior(AbstractPosterior): >>> likelihood = gpx.likelihoods.Gaussian() >>> >>> posterior = prior * likelihood - - Attributes: - prior (Prior): The prior distribution of the GP. - likelihood (Gaussian): The Gaussian likelihood distribution of the observed dataset. - name (str): The name of the GP posterior. Defaults to "Conjugate posterior". """ - prior: Prior - likelihood: Gaussian - name: Optional[str] = "Conjugate posterior" + def __init__( + self, + prior: AbstractPrior, + likelihood: AbstractLikelihood, + name: Optional[str] = "GP posterior", + ) -> None: + """Initialise the conjugate GP posterior object. + + Args: + prior (Prior): The prior distribution of the GP. + likelihood (AbstractLikelihood): The likelihood distribution of the observed dataset. + name (Optional[str]): The name of the GP posterior. Defaults to "GP posterior". + """ + self.prior = prior + self.likelihood = likelihood + self.name = name def predict( self, @@ -501,7 +523,7 @@ def marginal_log_likelihood( Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - negative (bool, optional): Whether or not the returned function + negative (Optional[bool]): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. @@ -560,7 +582,6 @@ def mll( return mll -@dataclass class NonConjugatePosterior(AbstractPosterior): """ A Gaussian process posterior object for models where the likelihood is @@ -571,24 +592,30 @@ class NonConjugatePosterior(AbstractPosterior): hyperparameters and the latent function. 
Markov chain Monte Carlo, variational inference, or Laplace approximations can then be used to sample from, or optimise an approximation to, the posterior distribution. - - Attributes: - prior (AbstractPrior): The Gaussian process prior distribution. - likelihood (AbstractLikelihood): The likelihood function that - represents the data. - name (str): The name of the posterior object. Defaults to - "Non-conjugate posterior". """ - prior: AbstractPrior - likelihood: AbstractLikelihood - name: Optional[str] = "Non-conjugate posterior" + def __init__( + self, + prior: AbstractPrior, + likelihood: AbstractLikelihood, + name: Optional[str] = "GP posterior", + ) -> None: + """Initialise a non-conjugate Gaussian process posterior object. + + Args: + prior (AbstractPrior): The Gaussian process prior distribution. + likelihood (AbstractLikelihood): The likelihood function that represents the data. + name (Optional[str]): The name of the posterior object. Defaults to "GP posterior". + """ + self.prior = prior + self.likelihood = likelihood + self.name = name - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a non-conjugate GP posterior. Args: - key (PRNGKeyType): A PRNG key used to initialise the parameters. + key (KeyArray): A PRNG key used to initialise the parameters. Returns: Dict: A dictionary containing the default parameter set. @@ -620,7 +647,7 @@ def predict( and output data used for training dataset. Returns: - tp.Callable[[Array], dx.Distribution]: A function that accepts an + Callable[[Array], dx.Distribution]: A function that accepts an input array and returns the predictive distribution as a ``dx.Distribution``. """ @@ -697,7 +724,7 @@ def marginal_log_likelihood( Args: train_data (Dataset): The training dataset used to compute the marginal log-likelihood. - negative (bool, optional): Whether or not the returned function + negative (Optional[bool]): Whether or not the returned function should be negative. For optimisation, the negative is useful as minimisation of the negative marginal log-likelihood is equivalent to maximisation of the marginal log-likelihood. Defaults to False. diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 7800d1b1..2956e501 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -13,1099 +13,42 @@ # limitations under the License. 
# ============================================================================== -import abc -from typing import Callable, Dict, List, Optional, Sequence - -from jaxlinop import ( - LinearOperator, - DenseLinearOperator, - DiagonalLinearOperator, - ConstantDiagonalLinearOperator, -) - -import jax.numpy as jnp -from jax import vmap -import jax -from jaxtyping import Array, Float - -from chex import PRNGKey as PRNGKeyType -from jaxutils import PyTree +import jaxkern import deprecation -class AbstractKernelComputation(PyTree): - """Abstract class for kernel computations.""" - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - self._kernel_fn = kernel_fn - - @property - def kernel_fn( - self, - ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]: - return self._kernel_fn - - @kernel_fn.setter - def kernel_fn( - self, - kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array], - ) -> None: - self._kernel_fn = kernel_fn - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> LinearOperator: - - """Compute Gram covariance operator of the kernel function. - - Args: - kernel (AbstractKernel): The kernel function to be evaluated. - params (Dict): The parameters of the kernel function. - inputs (Float[Array, "N N"]): The inputs to the kernel function. - - Returns: - LinearOperator: Gram covariance operator of the kernel function. - """ - - matrix = self.cross_covariance(params, inputs, inputs) - - return DenseLinearOperator(matrix=matrix) - - @abc.abstractmethod - def cross_covariance( - self, - params: Dict, - x: Float[Array, "N D"], - y: Float[Array, "M D"], - ) -> Float[Array, "N M"]: - """For a given kernel, compute the NxM gram matrix on an a pair - of input matrices with shape NxD and MxD. - - Args: - kernel (AbstractKernel): The kernel for which the cross-covariance - matrix should be computed for. - params (Dict): The kernel's parameter set. - x (Float[Array,"N D"]): The first input matrix. - y (Float[Array,"M D"]): The second input matrix. - - Returns: - Float[Array, "N M"]: The computed square Gram matrix. - """ - raise NotImplementedError - - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - LinearOperator: The computed diagonal variance entries. - """ - diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the DenseKernelComputation", -) -class DenseKernelComputation(AbstractKernelComputation): - """Dense kernel computation class. Operations with the kernel assume - a dense gram matrix structure. - """ - - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - """For a given kernel, compute the NxM covariance matrix on a pair of input - matrices of shape NxD and MxD. 
- - Args: - kernel (AbstractKernel): The kernel for which the Gram - matrix should be computed for. - params (Dict): The kernel's parameter set. - x (Float[Array,"N D"]): The input matrix. - y (Float[Array,"M D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) - return cross_cov - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the DiagonalKernelComputation", -) -class DiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - raise ValueError("Cross covariance not defined for diagonal kernels.") - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the ConstantDiagonalKernelComputation", -) -class ConstantDiagonalKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - - def gram( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> ConstantDiagonalLinearOperator: - """For a kernel with diagonal structure, compute the NxN gram matrix on - an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the Gram matrix - should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - CovarianceOperator: The computed square Gram matrix. - """ - - value = self.kernel_fn(params, inputs[0], inputs[0]) - - return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) - - def diagonal( - self, - params: Dict, - inputs: Float[Array, "N D"], - ) -> DiagonalLinearOperator: - """For a given kernel, compute the elementwise diagonal of the - NxN gram matrix on an input matrix of shape NxD. - - Args: - kernel (AbstractKernel): The kernel for which the variance - vector should be computed for. - params (Dict): The kernel's parameter set. - inputs (Float[Array, "N D"]): The input matrix. - - Returns: - LinearOperator: The computed diagonal variance entries. 
- """ - - diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) - - return DiagonalLinearOperator(diag=diag) - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - raise ValueError("Cross covariance not defined for constant diagonal kernels.") - - -########################################## -# Abtract classes -########################################## -class AbstractKernel(PyTree): - """ - Base kernel class""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "AbstractKernel", - ) -> None: - self.compute_engine = compute_engine - self.active_dims = active_dims - self.stationary = stationary - self.spectral = spectral - self.name = name - self.ndims = 1 if not self.active_dims else len(self.active_dims) - compute_engine = self.compute_engine(kernel_fn=self.__call__) - self.gram = compute_engine.gram - self.cross_covariance = compute_engine.cross_covariance - - @abc.abstractmethod - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs. - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - raise NotImplementedError - - def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: - """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. - - Args: - x (Float[Array, "N D"]): The matrix or vector that is to be sliced. - Returns: - Float[Array, "N Q"]: A sliced form of the input matrix. - """ - return x[..., self.active_dims] - - def __add__(self, other: "AbstractKernel") -> "AbstractKernel": - """Add two kernels together. - Args: - other (AbstractKernel): The kernel to be added to the current kernel. - - Returns: - AbstractKernel: A new kernel that is the sum of the two kernels. - """ - return SumKernel(kernel_set=[self, other]) - - def __mul__(self, other: "AbstractKernel") -> "AbstractKernel": - """Multiply two kernels together. - - Args: - other (AbstractKernel): The kernel to be multiplied with the current kernel. - - Returns: - AbstractKernel: A new kernel that is the product of the two kernels. - """ - return ProductKernel(kernel_set=[self, other]) - - @property - def ard(self): - """Boolean property as to whether the kernel is isotropic or of - automatic relevance determination form. - - Returns: - bool: True if the kernel is an ARD kernel. - """ - return True if self.ndims > 1 else False - - @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: - """A template dictionary of the kernel's parameter set. - - Args: - key (PRNGKeyType): A PRNG key to be used for initialising - the kernel's parameters. - - Returns: - Dict: A dictionary of the kernel's parameters. 
- """ - raise NotImplementedError - - -class CombinationKernel(AbstractKernel): - """A base class for products or sums of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "AbstractKernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - self.kernel_set = kernel_set - name: Optional[str] = "Combination kernel" - self.combination_fn: Optional[Callable] = None - - if not all(isinstance(k, AbstractKernel) for k in self.kernel_set): - raise TypeError("can only combine Kernel instances") # pragma: no cover - - self._set_kernels(self.kernel_set) - - def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: - """Combine multiple kernels. Based on GPFlow's Combination kernel.""" - # add kernels to a list, flattening out instances of this class therein - kernels_list: List[AbstractKernel] = [] - for k in kernels: - if isinstance(k, self.__class__): - kernels_list.extend(k.kernel_set) - else: - kernels_list.append(k) - - self.kernel_set = kernels_list - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - """A template dictionary of the kernel's parameter set.""" - return [kernel._initialise_params(key) for kernel in self.kernel_set] - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate combination kernel on a pair of inputs. - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - return self.combination_fn( - jnp.stack([k(p, x, y) for k, p in zip(self.kernel_set, params)]) - ) - - -class SumKernel(CombinationKernel): - """A kernel that is the sum of a set of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Sum kernel", - ) -> None: - super().__init__( - kernel_set, compute_engine, active_dims, stationary, spectral, name - ) - self.combination_fn: Optional[Callable] = jnp.sum - - -class ProductKernel(CombinationKernel): - """A kernel that is the product of a set of kernels.""" - - def __init__( - self, - kernel_set: List[AbstractKernel], - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Product kernel", - ) -> None: - super().__init__( - kernel_set, compute_engine, active_dims, stationary, spectral, name - ) - self.combination_fn: Optional[Callable] = jnp.prod - - -########################################## -# Euclidean kernels -########################################## -@deprecation.deprecated( - deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the RBF kernel" -) -class RBF(AbstractKernel): - """The Radial Basis Function (RBF) kernel.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Radial basis function kernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( \\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - params = { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params) - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the Matern12 kernel", -) -class Matern12(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 0.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matérn 1/2 kernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the Matern32 kernel", -) -class Matern32(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 1.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matern 3/2", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(3.0) * tau) - * jnp.exp(-jnp.sqrt(3.0) * tau) - ) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the Matern52 kernel", -) -class Matern52(AbstractKernel): - """The Matérn kernel with smoothness parameter fixed at 2.5.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Matern 5/2", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with - lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - tau = euclidean_distance(x, y) - K = ( - params["variance"] - * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) - * jnp.exp(-jnp.sqrt(5.0) * tau) - ) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - } - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the PoweredExponential kernel", -) -class PoweredExponential(AbstractKernel): - """The powered exponential family of kernels. - - Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". - - """ - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Powered exponential", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell`, :math:`\sigma` and power :math:`\kappa`. - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. 
- - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"]) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "power": jnp.array([1.0]), - } - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the Linear kernel", -) -class Linear(AbstractKernel): - """The linear kernel.""" - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Linear", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\sigma` - - .. math:: - k(x, y) = \\sigma^2 x^{T}y - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) - y = self.slice_input(y) - K = params["variance"] * jnp.matmul(x.T, y) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return {"variance": jnp.array([1.0])} - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the Polynomial kernel", -) -class Polynomial(AbstractKernel): - """The Polynomial kernel with variable degree.""" - - def __init__( - self, - degree: int = 1, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Polynomial", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - self.degree = degree - self.name = f"Polynomial Degree: {self.degree}" - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\sigma^2` through - - .. math:: - k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. 
- """ - x = self.slice_input(x).squeeze() - y = self.slice_input(y).squeeze() - K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "shift": jnp.array([1.0]), - "variance": jnp.array([1.0] * self.ndims), - } - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the White kernel", -) -class White(AbstractKernel): - def __init__( - self, - compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "White", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __post_init__(self) -> None: - super(White, self).__post_init__() - - def __call__( - self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] - ) -> Float[Array, "1"]: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\sigma` - - .. math:: - k(x, y) = \\sigma^2 \delta(x-y) - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. - y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. - - Returns: - Float[Array, "1"]: The value of :math:`k(x, y)`. - """ - K = jnp.all(jnp.equal(x, y)) * params["variance"] - return K.squeeze() - - def _initialise_params(self, key: Float[Array, "1 D"]) -> Dict: - """Initialise the kernel parameters. - - Args: - key (Float[Array, "1 D"]): The key to initialise the parameters with. - - Returns: - Dict: The initialised parameters. - """ - return {"variance": jnp.array([1.0])} - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the RationalQuadratic kernel", -) -class RationalQuadratic(AbstractKernel): - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Rational Quadratic", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. 
- Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) / params["lengthscale"] - y = self.slice_input(y) / params["lengthscale"] - K = params["variance"] * ( - 1 + 0.5 * squared_distance(x, y) / params["alpha"] - ) ** (-params["alpha"]) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "alpha": jnp.array([1.0]), - } - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the Periodic kernel", -) -class Periodic(AbstractKernel): - """The periodic kernel. - - Key reference is MacKay 1998 - "Introduction to Gaussian processes". - """ - - def __init__( - self, - compute_engine: AbstractKernelComputation = DenseKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Periodic", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - - def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: - """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` - - .. math:: - k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) - - Args: - x (jax.Array): The left hand argument of the kernel function's call. - y (jax.Array): The right hand argument of the kernel function's call - params (dict): Parameter set for which the kernel should be evaluated on. - Returns: - Array: The value of :math:`k(x, y)` - """ - x = self.slice_input(x) - y = self.slice_input(y) - sine_squared = ( - jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"] - ) ** 2 - K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) - return K.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "period": jnp.array([1.0] * self.ndims), - } - - -########################################## -# Graph kernels -########################################## -class EigenKernelComputation(AbstractKernelComputation): - def __init__( - self, - kernel_fn: Callable[ - [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array - ] = None, - ) -> None: - super().__init__(kernel_fn) - self._eigenvalues = None - self._eigenvectors = None - self._num_verticies = None - - # Define an eigenvalue setter and getter property - @property - def eigensystem(self) -> Float[Array, "N"]: - return self._eigenvalues, self._eigenvectors, self._num_verticies - - @eigensystem.setter - def eigensystem( - self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] - ) -> None: - self._eigenvalues = eigenvalues - self._eigenvectors = eigenvectors - - @property - def num_vertex(self) -> int: - return self._num_verticies - - @num_vertex.setter - def num_vertex(self, num_vertex: int) -> None: - self._num_verticies = num_vertex - - def _compute_S(self, params): - evals, evecs = self.eigensystem - S = jnp.power( - evals - + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], - -params["smoothness"], - ) - S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) - S = jnp.multiply(S, params["variance"]) - return S - - def cross_covariance( - self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] - ) -> Float[Array, "N M"]: - S = self._compute_S(params=params) - matrix = self.kernel_fn(params, 
x, y, S=S) - return matrix - - -@deprecation.deprecated( - deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the GraphKernel" -) -class GraphKernel(AbstractKernel): - def __init__( - self, - laplacian: Float[Array, "N N"], - compute_engine: EigenKernelComputation = EigenKernelComputation, - active_dims: Optional[List[int]] = None, - stationary: Optional[bool] = False, - spectral: Optional[bool] = False, - name: Optional[str] = "Graph kernel", - ) -> None: - super().__init__(compute_engine, active_dims, stationary, spectral, name) - self.laplacian = laplacian - evals, self.evecs = jnp.linalg.eigh(self.laplacian) - self.evals = evals.reshape(-1, 1) - self.compute_engine.eigensystem = self.evals, self.evecs - self.compute_engine.num_vertex = self.laplacian.shape[0] - - def __call__( - self, - params: Dict, - x: Float[Array, "1 D"], - y: Float[Array, "1 D"], - **kwargs, - ) -> Float[Array, "1"]: - """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. - - Args: - params (Dict): Parameter set for which the kernel should be evaluated on. - x (Float[Array, "1 D"]): Index of the ith vertex. - y (Float[Array, "1 D"]): Index of the jth vertex. - - Returns: - Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. - """ - S = kwargs["S"] - Kxx = (jax_gather_nd(self.evecs, x) * S[None, :]) @ jnp.transpose( - jax_gather_nd(self.evecs, y) - ) # shape (n,n) - return Kxx.squeeze() - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return { - "lengthscale": jnp.array([1.0] * self.ndims), - "variance": jnp.array([1.0]), - "smoothness": jnp.array([1.0]), - } - - @property - def num_vertex(self) -> int: - return self.compute_engine.num_vertex - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the squared_distance function", -) -def squared_distance( - x: Float[Array, "1 D"], y: Float[Array, "1 D"] -) -> Float[Array, "1"]: - """Compute the squared distance between a pair of inputs. - - Args: - x (Float[Array, "1 D"]): First input. - y (Float[Array, "1 D"]): Second input. - - Returns: - Float[Array, "1"]: The squared distance between the inputs. - """ - - return jnp.sum((x - y) ** 2) - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the euclidean_distance function", -) -def euclidean_distance( - x: Float[Array, "1 D"], y: Float[Array, "1 D"] -) -> Float[Array, "1"]: - """Compute the euclidean distance between a pair of inputs. - - Args: - x (Float[Array, "1 D"]): First input. - y (Float[Array, "1 D"]): Second input. - - Returns: - Float[Array, "1"]: The euclidean distance between the inputs. 
- """ - - return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) - - -@deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the jax_gather_nd function", -) -def jax_gather_nd(params, indices): - tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) - return params[tuple_indices] +def deprecate(cls): + return deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the " + cls.__name__, + )(cls) + + +AbstractKernelComputation = deprecate(jaxkern.kernels.AbstractKernelComputation) +DiagonalKernelComputation = deprecate(jaxkern.kernels.DiagonalKernelComputation) +ConstantDiagonalKernelComputation = deprecate( + jaxkern.kernels.ConstantDiagonalKernelComputation +) +AbstractKernel = deprecate(jaxkern.kernels.AbstractKernel) +CombinationKernel = deprecate(jaxkern.kernels.CombinationKernel) +SumKernel = deprecate(jaxkern.kernels.SumKernel) +ProductKernel = deprecate(jaxkern.kernels.ProductKernel) +RBF = deprecate(jaxkern.kernels.RBF) +Matern12 = deprecate(jaxkern.kernels.Matern12) +Matern32 = deprecate(jaxkern.kernels.Matern32) +Matern52 = deprecate(jaxkern.kernels.Matern52) +Linear = deprecate(jaxkern.kernels.Linear) +Periodic = deprecate(jaxkern.kernels.Periodic) +White = deprecate(jaxkern.kernels.White) +PoweredExponential = deprecate(jaxkern.kernels.PoweredExponential) +RationalQuadratic = deprecate(jaxkern.kernels.RationalQuadratic) +Polynomial = deprecate(jaxkern.kernels.Polynomial) +EigenKernelComputation = deprecate(jaxkern.kernels.EigenKernelComputation) +GraphKernel = deprecate(jaxkern.kernels.GraphKernel) +squared_distance = deprecate(jaxkern.kernels.squared_distance) +euclidean_distance = deprecate(jaxkern.kernels.euclidean_distance) +jax_gather_nd = deprecate(jaxkern.kernels.jax_gather_nd) __all__ = [ diff --git a/gpjax/likelihoods.py b/gpjax/likelihoods.py index 6b39c1d2..898099f6 100644 --- a/gpjax/likelihoods.py +++ b/gpjax/likelihoods.py @@ -16,22 +16,28 @@ import abc from typing import Any, Callable, Dict, Optional from jaxlinop.utils import to_dense +from jaxutils import PyTree import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -from chex import dataclass from jaxtyping import Array, Float from jax.random import KeyArray -@dataclass -class AbstractLikelihood: +class AbstractLikelihood(PyTree): """Abstract base class for likelihoods.""" - num_datapoints: int # The number of datapoints that the likelihood factorises over. - name: Optional[str] = "Likelihood" + def __init__(self, num_datapoints: int, name: Optional[str] = None): + """Initialise the likelihood. + + Args: + num_datapoints (int): The number of datapoints that the likelihood factorises over. + name (Optional[str]): The name of the likelihood. Defaults to None. + """ + self.num_datapoints = num_datapoints + self.name = name def __call__(self, *args: Any, **kwargs: Any) -> dx.Distribution: """Evaluate the likelihood function at a given predictive distribution. @@ -81,22 +87,28 @@ def link_function(self) -> Callable: raise NotImplementedError -@dataclass class Conjugate: """An abstract class for conjugate likelihoods with respect to a Gaussain process prior.""" -@dataclass class NonConjugate: """An abstract class for non-conjugate likelihoods with respect to a Gaussain process prior.""" -# TODO: revamp this will covariance operators. -@dataclass +# TODO: revamp this with covariance operators. 
+ + class Gaussian(AbstractLikelihood, Conjugate): """Gaussian likelihood object.""" - name: Optional[str] = "Gaussian" + def __init__(self, num_datapoints: int, name: Optional[str] = "Gaussian"): + """Initialise the Gaussian likelihood. + + Args: + num_datapoints (int): The number of datapoints that the likelihood factorises over. + name (Optional[str]): The name of the likelihood. Defaults to "Gaussian". + """ + super().__init__(num_datapoints, name) def _initialise_params(self, key: KeyArray) -> Dict: """Return the variance parameter of the likelihood function. @@ -157,9 +169,15 @@ def predict(self, params: Dict, dist: dx.MultivariateNormalTri) -> dx.Distributi return dx.MultivariateNormalFullCovariance(dist.mean(), noisy_cov) -@dataclass class Bernoulli(AbstractLikelihood, NonConjugate): - name: Optional[str] = "Bernoulli" + def __init__(self, num_datapoints: int, name: Optional[str] = "Bernoulli"): + """Initialise the Bernoulli likelihood. + + Args: + num_datapoints (int): The number of datapoints that the likelihood factorises over. + name (Optional[str]): The name of the likelihood. Defaults to "Bernoulli". + """ + super().__init__(num_datapoints, name) def _initialise_params(self, key: KeyArray) -> Dict: """Initialise the parameter set of a Bernoulli likelihood. diff --git a/gpjax/mean_functions.py b/gpjax/mean_functions.py index 29ec93b1..65902640 100644 --- a/gpjax/mean_functions.py +++ b/gpjax/mean_functions.py @@ -17,16 +17,25 @@ from typing import Dict, Optional import jax.numpy as jnp -from chex import dataclass, PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxtyping import Array, Float +from jaxutils import PyTree -@dataclass(repr=False) -class AbstractMeanFunction: +class AbstractMeanFunction(PyTree): """Abstract mean function that is used to parameterise the Gaussian process.""" - output_dim: Optional[int] = 1 - name: Optional[str] = "Mean function" + def __init__( + self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" + ): + """Initialise the mean function. + + Args: + output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. + name (Optional[str]): The name of the mean function. Defaults to "Mean function". + """ + self.output_dim = output_dim + self.name = name @abc.abstractmethod def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: @@ -42,11 +51,11 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: raise NotImplementedError @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Return the parameters of the mean function. This method is required for all subclasses. Args: - key (PRNGKeyType): The PRNG key to use for initialising the parameters. + key (KeyArray): The PRNG key to use for initialising the parameters. Returns: Dict: The parameters of the mean function. @@ -54,14 +63,21 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: raise NotImplementedError -@dataclass(repr=False) class Zero(AbstractMeanFunction): """ A zero mean function. This function returns zero for all inputs. """ - output_dim: Optional[int] = 1 - name: Optional[str] = "Zero mean function" + def __init__( + self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" + ): + """Initialise the zero-mean function. + + Args: + output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. + name (Optional[str]): The name of the mean function. 
Defaults to "Mean function". + """ + super().__init__(output_dim, name) def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. @@ -76,11 +92,11 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.zeros(shape=out_shape) - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """The parameters of the mean function. For the zero-mean function, this is an empty dictionary. Args: - key (PRNGKeyType): The PRNG key to use for initialising the parameters. + key (KeyArray): The PRNG key to use for initialising the parameters. Returns: Dict: The parameters of the mean function. @@ -88,15 +104,22 @@ def _initialise_params(self, key: PRNGKeyType) -> Dict: return {} -@dataclass(repr=False) class Constant(AbstractMeanFunction): """ A zero mean function. This function returns a repeated scalar value for all inputs. The scalar value itself can be treated as a model hyperparameter and learned during training. """ - output_dim: Optional[int] = 1 - name: Optional[str] = "Constant mean function" + def __init__( + self, output_dim: Optional[int] = 1, name: Optional[str] = "Mean function" + ): + """Initialise the constant-mean function. + + Args: + output_dim (Optional[int]): The output dimension of the mean function. Defaults to 1. + name (Optional[str]): The name of the mean function. Defaults to "Mean function". + """ + super().__init__(output_dim, name) def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: """Evaluate the mean function at the given points. @@ -111,11 +134,11 @@ def __call__(self, params: Dict, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: out_shape = (x.shape[0], self.output_dim) return jnp.ones(shape=out_shape) * params["constant"] - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """The parameters of the mean function. For the constant-mean function, this is a dictionary with a single value. Args: - key (PRNGKeyType): The PRNG key to use for initialising the parameters. + key (KeyArray): The PRNG key to use for initialising the parameters. Returns: Dict: The parameters of the mean function. diff --git a/gpjax/parameters.py b/gpjax/parameters.py index 85b01f19..f409905d 100644 --- a/gpjax/parameters.py +++ b/gpjax/parameters.py @@ -22,8 +22,9 @@ import jax import jax.numpy as jnp import jax.random as jr -from chex import dataclass, PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxtyping import Array, Float +from jaxutils import PyTree from .config import Identity, get_global_config from .utils import merge_dictionaries @@ -32,17 +33,17 @@ ################################ # Base operations ################################ -@dataclass -class ParameterState: +class ParameterState(PyTree): """ The state of the model. This includes the parameter set, which parameters are to be trained and bijectors that allow parameters to be constrained and unconstrained. """ - params: Dict - trainables: Dict - bijectors: Dict + def __init__(self, params: Dict, trainables: Dict, bijectors: Dict) -> None: + self.params = params + self.trainables = trainables + self.bijectors = bijectors def unpack(self): """Unpack the state into a tuple of parameters, trainables and bijectors. 
@@ -53,7 +54,7 @@ def unpack(self): return self.params, self.trainables, self.bijectors -def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: +def initialise(model, key: KeyArray = None, **kwargs) -> ParameterState: """ Initialise the stateful parameters of any GPJax object. This function also returns the trainability status of each parameter and set of bijectors that @@ -61,7 +62,7 @@ def initialise(model, key: PRNGKeyType = None, **kwargs) -> ParameterState: Args: model: The GPJax object that is to be initialised. - key (PRNGKeyType, optional): The random key that is to be used for + key (KeyArray, optional): The random key that is to be used for initialisation. Defaults to None. Returns: diff --git a/gpjax/test_variational_inference.py b/gpjax/test_variational_inference.py new file mode 100644 index 00000000..1e7eb9eb --- /dev/null +++ b/gpjax/test_variational_inference.py @@ -0,0 +1,159 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import typing as tp + +import jax +import jax.numpy as jnp +import jax.random as jr +import pytest +from jax.config import config + +import gpjax as gpx +from gpjax.variational_families import ( + CollapsedVariationalGaussian, + ExpectationVariationalGaussian, + NaturalVariationalGaussian, + VariationalGaussian, + WhitenedVariationalGaussian, +) + +# Enable Float64 for more stable matrix inversions. 
+config.update("jax_enable_x64", True) + + +def test_abstract_variational_inference(): + prior = gpx.Prior(kernel=gpx.RBF()) + lik = gpx.Gaussian(num_datapoints=20) + post = prior * lik + n_inducing_points = 10 + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) + vartiational_family = gpx.VariationalGaussian( + prior=prior, inducing_inputs=inducing_inputs + ) + + with pytest.raises(TypeError): + gpx.variational_inference.AbstractVariationalInference( + posterior=post, vartiational_family=vartiational_family + ) + + +def get_data_and_gp(n_datapoints, point_dim): + x = jnp.linspace(-5.0, 5.0, n_datapoints).reshape(-1, 1) + y = jnp.sin(x) + jr.normal(key=jr.PRNGKey(123), shape=x.shape) * 0.1 + x = jnp.hstack([x] * point_dim) + D = gpx.Dataset(X=x, y=y) + + p = gpx.Prior(kernel=gpx.RBF()) + lik = gpx.Gaussian(num_datapoints=n_datapoints) + post = p * lik + return D, post, p + + +@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) +@pytest.mark.parametrize("jit_fns", [False, True]) +@pytest.mark.parametrize("point_dim", [1, 2, 3]) +@pytest.mark.parametrize( + "variational_family", + [ + VariationalGaussian, + WhitenedVariationalGaussian, + NaturalVariationalGaussian, + ExpectationVariationalGaussian, + ], +) +def test_stochastic_vi( + n_datapoints, n_inducing_points, jit_fns, point_dim, variational_family +): + D, post, prior = get_data_and_gp(n_datapoints, point_dim) + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) + inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) + + q = variational_family(prior=prior, inducing_inputs=inducing_inputs) + + svgp = gpx.StochasticVI(posterior=post, variational_family=q) + assert svgp.posterior.prior == post.prior + assert svgp.posterior.likelihood == post.likelihood + + params, _, _ = gpx.initialise(svgp, jr.PRNGKey(123)).unpack() + + assert svgp.prior == post.prior + assert svgp.likelihood == post.likelihood + + if jit_fns: + elbo_fn = jax.jit(svgp.elbo(D)) + else: + elbo_fn = svgp.elbo(D) + assert isinstance(elbo_fn, tp.Callable) + elbo_value = elbo_fn(params, D) + assert isinstance(elbo_value, jnp.ndarray) + + # Test gradients + grads = jax.grad(elbo_fn, argnums=0)(params, D) + assert isinstance(grads, tp.Dict) + assert len(grads) == len(params) + + +@pytest.mark.parametrize("n_datapoints, n_inducing_points", [(10, 2), (100, 10)]) +@pytest.mark.parametrize("jit_fns", [False, True]) +@pytest.mark.parametrize("point_dim", [1, 2]) +def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): + D, post, prior = get_data_and_gp(n_datapoints, point_dim) + likelihood = gpx.Gaussian(num_datapoints=n_datapoints) + + inducing_inputs = jnp.linspace(-5.0, 5.0, n_inducing_points).reshape(-1, 1) + inducing_inputs = jnp.hstack([inducing_inputs] * point_dim) + + q = CollapsedVariationalGaussian( + prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs + ) + + sgpr = gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) + assert sgpr.posterior.prior == post.prior + assert sgpr.posterior.likelihood == post.likelihood + + params, _, _ = gpx.initialise(sgpr, jr.PRNGKey(123)).unpack() + + assert sgpr.prior == post.prior + assert sgpr.likelihood == post.likelihood + + if jit_fns: + elbo_fn = jax.jit(sgpr.elbo(D)) + else: + elbo_fn = sgpr.elbo(D) + assert isinstance(elbo_fn, tp.Callable) + elbo_value = elbo_fn(params) + assert isinstance(elbo_value, jnp.ndarray) + + # Test gradients + grads = jax.grad(elbo_fn)(params) + assert 
isinstance(grads, tp.Dict) + assert len(grads) == len(params) + + # We should raise an error for non-Collapsed variational families: + with pytest.raises(TypeError): + q = gpx.variational_families.VariationalGaussian( + prior=prior, inducing_inputs=inducing_inputs + ) + gpx.variational_inference.CollapsedVI(posterior=post, variational_family=q) + + # We should raise an error for non-Gaussian likelihoods: + with pytest.raises(TypeError): + q = gpx.variational_families.CollapsedVariationalGaussian( + prior=prior, likelihood=likelihood, inducing_inputs=inducing_inputs + ) + gpx.variational_inference.CollapsedVI( + posterior=prior * gpx.Bernoulli(num_datapoints=D.n), variational_family=q + ) diff --git a/gpjax/types.py b/gpjax/types.py index b3fedb97..d1e8b110 100644 --- a/gpjax/types.py +++ b/gpjax/types.py @@ -13,67 +13,20 @@ # limitations under the License. # ============================================================================== -import jax.numpy as jnp -from chex import dataclass -from jaxtyping import Array, Float +import jaxutils import deprecation -NoneType = type(None) - - -@deprecation.deprecated( +Dataset = deprecation.deprecated( deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxUtils for a Dataset object", -) -@dataclass -class Dataset: - """GPJax Dataset class.""" - - X: Float[Array, "N D"] - y: Float[Array, "N 1"] = None - - def __repr__(self) -> str: - return ( - f"- Number of datapoints: {self.X.shape[0]}\n- Dimension:" - f" {self.X.shape[1]}" - ) - - def __add__(self, other: "Dataset") -> "Dataset": - """Combines two datasets into one. The right-hand dataset is stacked beneath left.""" - x = jnp.concatenate((self.X, other.X)) - y = jnp.concatenate((self.y, other.y)) - - return Dataset(X=x, y=y) +)(jaxutils.Dataset) - @property - def n(self) -> int: - """The number of observations in the dataset.""" - return self.X.shape[0] - - @property - def in_dim(self) -> int: - """The dimension of the input data.""" - return self.X.shape[1] - - @property - def out_dim(self) -> int: - """The dimension of the output data.""" - return self.y.shape[1] +verify_dataset = deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxUtils for a Dataset object", +)(jaxutils.verify_dataset) -def verify_dataset(ds: Dataset) -> None: - """Apply a series of checks to the dataset to ensure that downstream operations are safe.""" - assert ds.X.ndim == 2, ( - "2-dimensional training inputs are required. Current dimension:" - f" {ds.X.ndim}." - ) - if ds.y is not None: - assert ds.y.ndim == 2, ( - "2-dimensional training outputs are required. Current dimension:" - f" {ds.y.ndim}." - ) - assert ds.X.shape[0] == ds.y.shape[0], ( - "Number of inputs must equal the number of outputs. 
\nCurrent" - f" counts:\n- X: {ds.X.shape[0]}\n- y: {ds.y.shape[0]}" - ) +__all__ = ["Dataset" "verify_dataset"] diff --git a/gpjax/variational_families.py b/gpjax/variational_families.py index e9c6ecd0..ac7e6ccc 100644 --- a/gpjax/variational_families.py +++ b/gpjax/variational_families.py @@ -19,11 +19,11 @@ import distrax as dx import jax.numpy as jnp import jax.scipy as jsp -from chex import dataclass, PRNGKey as PRNGKeyType +from jax.random import KeyArray from jaxtyping import Array, Float from jaxlinop import identity -from jaxutils import Dataset +from jaxutils import PyTree, Dataset import jaxlinop as jlo from .config import get_global_config @@ -33,8 +33,7 @@ from .gaussian_distribution import GaussianDistribution -@dataclass -class AbstractVariationalFamily: +class AbstractVariationalFamily(PyTree): """ Abstract base class used to represent families of distributions that can be used within variational inference. @@ -55,13 +54,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution: return self.predict(*args, **kwargs) @abc.abstractmethod - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """ The parameters of the distribution. For example, the multivariate Gaussian would return a mean vector and covariance matrix. Args: - key (PRNGKeyType): The PRNG key used to initialise the parameters. + key (KeyArray): The PRNG key used to initialise the parameters. Returns: Dict: The parameters of the distribution. @@ -84,20 +83,27 @@ def predict(self, *args: Any, **kwargs: Any) -> GaussianDistribution: raise NotImplementedError -@dataclass class AbstractVariationalGaussian(AbstractVariationalFamily): """The variational Gaussian family of probability distributions.""" - prior: Prior - inducing_inputs: Float[Array, "N D"] - name: str = "Gaussian" - - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Variational Gaussian", + ) -> None: + """ + Args: + prior (Prior): The prior distribution. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. Defaults to "Gaussian". + """ + self.prior = prior + self.inducing_inputs = inducing_inputs self.num_inducing = self.inducing_inputs.shape[0] + self.name = name -@dataclass class VariationalGaussian(AbstractVariationalGaussian): """The variational Gaussian family of probability distributions. @@ -108,14 +114,14 @@ class VariationalGaussian(AbstractVariationalGaussian): :math:`\\mu` and sqrt with S = sqrt sqrtᵀ. """ - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """ Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution. Args: - key (PRNGKeyType): The PRNG key used to initialise the parameters. + key (KeyArray): The PRNG key used to initialise the parameters. Returns: Dict: The parameters of the distribution. @@ -250,7 +256,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn -@dataclass class WhitenedVariationalGaussian(VariationalGaussian): """ The whitened variational Gaussian family of probability distributions. 
@@ -262,7 +267,21 @@ class WhitenedVariationalGaussian(VariationalGaussian): """ - name: str = "Whitened variational Gaussian" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Whitened variational Gaussian", + ) -> None: + """Initialise the whitened variational Gaussian family. + + Args: + prior (Prior): The GP prior. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. + """ + + super().__init__(prior, inducing_inputs, name) def prior_kl(self, params: Dict) -> Float[Array, "1"]: """Compute the KL-divergence between our variational approximation and @@ -355,7 +374,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn -@dataclass class NaturalVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions. @@ -363,12 +381,25 @@ class NaturalVariationalGaussian(AbstractVariationalGaussian): and the distribution over the inducing inputs is q(u) = N(μ, S). Expressing the variational distribution, in the form of the exponential family, q(u) = exp(θᵀ T(u) - a(θ)), gives rise to the natural paramerisation θ = (θ₁, θ₂) = (S⁻¹μ, -S⁻¹/2), to perform model inference, where T(u) = [u, uuᵀ] are the sufficient statistics. - """ - name: str = "Natural Gaussian" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Natural variational Gaussian", + ) -> None: + """Initialise the natural variational Gaussian family. - def _initialise_params(self, key: PRNGKeyType) -> Dict: + Args: + prior (Prior): The GP prior. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. + """ + + super().__init__(prior, inducing_inputs, name) + + def _initialise_params(self, key: KeyArray) -> Dict: """Return the natural vector and matrix, inducing inputs, and hyperparameters that parameterise the natural Gaussian distribution.""" m = self.num_inducing @@ -527,7 +558,6 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> dx.MultivariateNormalTri: return predict_fn -@dataclass class ExpectationVariationalGaussian(AbstractVariationalGaussian): """The natural variational Gaussian family of probability distributions. @@ -538,9 +568,23 @@ class ExpectationVariationalGaussian(AbstractVariationalGaussian): η = (η₁, η₁) = (μ, S + uuᵀ) to perform model inference over. """ - name: str = "Expectation Gaussian" + def __init__( + self, + prior: Prior, + inducing_inputs: Float[Array, "N D"], + name: Optional[str] = "Expectation variational Gaussian", + ) -> None: + """Initialise the expectation variational Gaussian family. + + Args: + prior (Prior): The GP prior. + inducing_inputs (Float[Array, "N D"]): The inducing inputs. + name (Optional[str]): The name of the variational family. + """ + + super().__init__(prior, inducing_inputs, name) - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Return the expectation vector and matrix, inducing inputs, and hyperparameters that parameterise the expectation Gaussian distribution.""" self.num_inducing = self.inducing_inputs.shape[0] @@ -691,25 +735,36 @@ def predict_fn(test_inputs: Float[Array, "N D"]) -> GaussianDistribution: return predict_fn -@dataclass class CollapsedVariationalGaussian(AbstractVariationalFamily): """Collapsed variational Gaussian family of probability distributions. 
The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" - prior: Prior - likelihood: AbstractLikelihood - inducing_inputs: Float[Array, "M D"] - name: str = "Collapsed variational Gaussian" - diag: Optional[bool] = False + def __init__( + self, + prior: Prior, + likelihood: AbstractLikelihood, + inducing_inputs: Float[Array, "M D"], + name: str = "Collapsed variational Gaussian", + ): + """Initialise the collapsed variational Gaussian family of probability distributions. - def __post_init__(self): - """Initialise the variational Gaussian distribution.""" - self.num_inducing = self.inducing_inputs.shape[0] + Args: + prior (Prior): The prior distribution that we are approximating. + likelihood (AbstractLikelihood): The likelihood function that we are using to model the data. + inducing_inputs (Float[Array, "M D"]): The inducing inputs that are to be used to parameterise the variational Gaussian distribution. + name (str, optional): The name of the variational family. Defaults to "Collapsed variational Gaussian". + """ - if not isinstance(self.likelihood, Gaussian): + if not isinstance(likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") - def _initialise_params(self, key: PRNGKeyType) -> Dict: + self.prior = prior + self.likelihood = likelihood + self.inducing_inputs = inducing_inputs + self.num_inducing = self.inducing_inputs.shape[0] + self.name = name + + def _initialise_params(self, key: KeyArray) -> Dict: """Return the variational mean vector, variational root covariance matrix, and inducing input vector that parameterise the variational Gaussian distribution.""" return concat_dictionaries( self.prior._initialise_params(key), diff --git a/gpjax/variational_inference.py b/gpjax/variational_inference.py index 34f493d1..6745656e 100644 --- a/gpjax/variational_inference.py +++ b/gpjax/variational_inference.py @@ -18,11 +18,12 @@ import jax.numpy as jnp import jax.scipy as jsp -from chex import dataclass, PRNGKey as PRNGKeyType from jax import vmap from jaxtyping import Array, Float from jaxlinop import identity +from jax.random import KeyArray +from jaxutils import PyTree from .config import get_global_config from .gps import AbstractPosterior @@ -36,18 +37,26 @@ ) -@dataclass -class AbstractVariationalInference: +class AbstractVariationalInference(PyTree): """A base class for inference and training of variational families against an extact posterior""" - posterior: AbstractPosterior - variational_family: AbstractVariationalFamily + def __init__( + self, + posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + ) -> None: + """Initialise the variational inference module. - def __post_init__(self): + Args: + posterior (AbstractPosterior): The exact posterior distribution. + variational_family (AbstractVariationalFamily): The variational family to be trained. + """ + self.posterior = posterior self.prior = self.posterior.prior self.likelihood = self.posterior.likelihood + self.variational_family = variational_family - def _initialise_params(self, key: PRNGKeyType) -> Dict: + def _initialise_params(self, key: KeyArray) -> Dict: """Construct the parameter set used within the variational scheme adopted.""" hyperparams = concat_dictionaries( {"likelihood": self.posterior.likelihood._initialise_params(key)}, @@ -71,15 +80,9 @@ def elbo( raise NotImplementedError -@dataclass class StochasticVI(AbstractVariationalInference): """Stochastic Variational inference training module. 
The key reference is Hensman et. al., (2013) - Gaussian processes for big data.""" - def __post_init__(self): - self.prior = self.posterior.prior - self.likelihood = self.posterior.likelihood - self.num_inducing = self.variational_family.num_inducing - def elbo( self, train_data: Dataset, negative: bool = False ) -> Callable[[Float[Array, "N D"]], Float[Array, "1"]]: @@ -144,22 +147,30 @@ def q_moments(x): return expectation -@dataclass class CollapsedVI(AbstractVariationalInference): """Collapsed variational inference for a sparse Gaussian process regression model. The key reference is Titsias, (2009) - Variational Learning of Inducing Variables in Sparse Gaussian Processes.""" - def __post_init__(self): - self.prior = self.posterior.prior - self.likelihood = self.posterior.likelihood - self.num_inducing = self.variational_family.num_inducing + def __init__( + self, + posterior: AbstractPosterior, + variational_family: AbstractVariationalFamily, + ) -> None: + """Initialise the variational inference module. + + Args: + posterior (AbstractPosterior): The exact posterior distribution. + variational_family (AbstractVariationalFamily): The variational family to be trained. + """ - if not isinstance(self.likelihood, Gaussian): + if not isinstance(posterior.likelihood, Gaussian): raise TypeError("Likelihood must be Gaussian.") - if not isinstance(self.variational_family, CollapsedVariationalGaussian): + if not isinstance(variational_family, CollapsedVariationalGaussian): raise TypeError("Variational family must be CollapsedVariationalGaussian.") + super().__init__(posterior, variational_family) + def elbo( self, train_data: Dataset, negative: bool = False ) -> Callable[[Dict], Float[Array, "1"]]: @@ -180,7 +191,7 @@ def elbo( mean_function = self.prior.mean_function kernel = self.prior.kernel - m = self.num_inducing + m = self.variational_family.num_inducing jitter = get_global_config()["jitter"] # Constant for whether or not to negate the elbo for optimisation purposes diff --git a/tests/test_abstractions.py b/tests/test_abstractions.py index c9ba5c38..4b7fce30 100644 --- a/tests/test_abstractions.py +++ b/tests/test_abstractions.py @@ -28,10 +28,10 @@ config.update("jax_enable_x64", True) -@pytest.mark.parametrize("n_iters", [1, 5]) +@pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("n", [1, 20]) @pytest.mark.parametrize("verbose", [True, False]) -def test_fit(n_iters, n, verbose): +def test_fit(num_iters, n, verbose): key = jr.PRNGKey(123) x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(n, 1)), axis=0) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 @@ -41,13 +41,13 @@ def test_fit(n_iters, n, verbose): mll = p.marginal_log_likelihood(D, negative=True) pre_mll_val = mll(parameter_state.params) optimiser = optax.adam(learning_rate=0.1) - inference_state = fit(mll, parameter_state, optimiser, n_iters, verbose=verbose) + inference_state = fit(mll, parameter_state, optimiser, num_iters, verbose=verbose) optimised_params, history = inference_state.unpack() assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert mll(optimised_params) < pre_mll_val assert isinstance(history, jnp.ndarray) - assert history.shape[0] == n_iters + assert history.shape[0] == num_iters def test_stop_grads(): @@ -59,18 +59,18 @@ def test_stop_grads(): parameter_state = ParameterState( params=params, trainables=trainables, bijectors=bijectors ) - inference_state = fit(loss_fn, parameter_state, optimiser, n_iters=1) + inference_state = 
fit(loss_fn, parameter_state, optimiser, num_iters=1) learned_params = inference_state.params assert isinstance(inference_state, InferenceState) assert learned_params["y"] == params["y"] assert learned_params["x"] != params["x"] -@pytest.mark.parametrize("n_iters", [1, 5]) +@pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("nb", [1, 20, 50]) @pytest.mark.parametrize("ndata", [50]) @pytest.mark.parametrize("verbose", [True, False]) -def test_batch_fitting(n_iters, nb, ndata, verbose): +def test_batch_fitting(num_iters, nb, ndata, verbose): key = jr.PRNGKey(123) x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 @@ -93,21 +93,21 @@ def test_batch_fitting(n_iters, nb, ndata, verbose): optimiser = optax.adam(learning_rate=0.1) key = jr.PRNGKey(42) inference_state = fit_batches( - objective, parameter_state, D, optimiser, key, nb, n_iters, verbose=verbose + objective, parameter_state, D, optimiser, key, nb, num_iters, verbose=verbose ) optimised_params, history = inference_state.unpack() assert isinstance(inference_state, InferenceState) assert isinstance(optimised_params, dict) assert objective(optimised_params, D) < pre_mll_val assert isinstance(history, jnp.ndarray) - assert history.shape[0] == n_iters + assert history.shape[0] == num_iters -@pytest.mark.parametrize("n_iters", [1, 5]) +@pytest.mark.parametrize("num_iters", [1, 5]) @pytest.mark.parametrize("nb", [1, 20, 50]) @pytest.mark.parametrize("ndata", [50]) @pytest.mark.parametrize("verbose", [True, False]) -def test_natural_gradients(ndata, nb, n_iters, verbose): +def test_natural_gradients(ndata, nb, num_iters, verbose): key = jr.PRNGKey(123) x = jnp.sort(jr.uniform(key=key, minval=-2.0, maxval=2.0, shape=(ndata, 1)), axis=0) y = jnp.sin(x) + jr.normal(key=key, shape=x.shape) * 0.1 @@ -140,7 +140,7 @@ def test_natural_gradients(ndata, nb, n_iters, verbose): hyper_optimiser, key, nb, - n_iters, + num_iters, verbose=verbose, ) optimised_params, history = inference_state.unpack() @@ -148,7 +148,7 @@ def test_natural_gradients(ndata, nb, n_iters, verbose): assert isinstance(optimised_params, dict) assert objective(optimised_params, D) < pre_mll_val assert isinstance(history, jnp.ndarray) - assert history.shape[0] == n_iters + assert history.shape[0] == num_iters @pytest.mark.parametrize("batch_size", [1, 2, 50]) diff --git a/tests/test_kernels.py b/tests/test_kernels.py deleted file mode 100644 index bb7bafe4..00000000 --- a/tests/test_kernels.py +++ /dev/null @@ -1,599 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - - -from itertools import permutations -from typing import Dict, List - -import jax -import jax.numpy as jnp -import jax.random as jr -import networkx as nx -import pytest -from jax.config import config -from jaxtyping import Array, Float -from chex import PRNGKey as PRNGKeyType - -from jaxlinop import ( - LinearOperator, - identity, -) - -from gpjax.kernels import ( - RBF, - Linear, - RationalQuadratic, - CombinationKernel, - GraphKernel, - AbstractKernel, - Matern12, - Matern32, - Matern52, - Polynomial, - PoweredExponential, - ProductKernel, - Periodic, - SumKernel, - euclidean_distance, -) -from gpjax.parameters import initialise - -# Enable Float64 for more stable matrix inversions. -config.update("jax_enable_x64", True) -_initialise_key = jr.PRNGKey(123) -_jitter = 1e-6 - - -def test_abstract_kernel(): - # Test initialising abstract kernel raises TypeError with unimplemented __call__ and _init_params methods: - with pytest.raises(TypeError): - AbstractKernel() - - # Create a dummy kernel class with __call__ and _init_params methods implemented: - class DummyKernel(AbstractKernel): - def __call__( - self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict - ) -> Float[Array, "1"]: - return x * params["test"] * y - - def _initialise_params(self, key: PRNGKeyType) -> Dict: - return {"test": 1.0} - - # Initialise dummy kernel class and test __call__ and _init_params methods: - dummy_kernel = DummyKernel() - assert dummy_kernel._initialise_params(_initialise_key) == {"test": 1.0} - assert dummy_kernel(jnp.array([1.0]), jnp.array([2.0]), {"test": 2.0}) == 4.0 - - -@pytest.mark.parametrize( - "a, b, distance_to_3dp", - [ - ([1.0], [-4.0], 5.0), - ([1.0, -2.0], [-4.0, 3.0], 7.071), - ([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], 2.236), - ], -) -def test_euclidean_distance( - a: List[float], b: List[float], distance_to_3dp: float -) -> None: - - # Convert lists to JAX arrays: - a: Float[Array, "D"] = jnp.array(a) - b: Float[Array, "D"] = jnp.array(b) - - # Test distance is correct to 3dp: - assert jnp.round(euclidean_distance(a, b), 3) == distance_to_3dp - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("n", [1, 2, 10]) -def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: - - # Gram constructor static method: - kernel.gram - - # Inputs x: - x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - - # Default kernel parameters: - params = kernel._initialise_params(_initialise_key) - - # Test gram matrix: - Kxx = kernel.gram(params, x) - assert isinstance(Kxx, LinearOperator) - assert Kxx.shape == (n, n) - - -@pytest.mark.parametrize( - "kernel", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -@pytest.mark.parametrize("num_a", [1, 2, 5]) -@pytest.mark.parametrize("num_b", [1, 2, 5]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_cross_covariance( - kernel: AbstractKernel, num_a: int, num_b: int, dim: int -) -> None: - # Inputs a, b: - a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) - b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) - - # Default kernel parameters: - params = kernel._initialise_params(_initialise_key) - - # Test cross covariance, Kab: - Kab = kernel.cross_covariance(params, a, b) - assert isinstance(Kab, jnp.ndarray) 
- assert Kab.shape == (num_a, num_b) - - -@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -def test_call(kernel: AbstractKernel, dim: int) -> None: - - # Datapoint x and datapoint y: - x = jnp.array([[1.0] * dim]) - y = jnp.array([[0.5] * dim]) - - # Defualt parameters: - params = kernel._initialise_params(_initialise_key) - - # Test calling gives an autocovariance value of no dimension between the inputs: - kxy = kernel(params, x, y) - - assert isinstance(kxy, jax.Array) - assert kxy.shape == () - - -@pytest.mark.parametrize("kern", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def( - kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int -) -> None: - kern = kern(active_dims=list(range(dim))) - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = {"lengthscale": jnp.array([ell]), "variance": jnp.array([sigma])} - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("kern", [Linear, Polynomial]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("shift", [0.0, 0.5, 2.0]) -@pytest.mark.parametrize("sigma", [0.1, 0.2, 0.5]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_lin_poly( - kern: AbstractKernel, dim: int, shift: float, sigma: float, n: int -) -> None: - kern = kern(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = {"variance": jnp.array([sigma]), "shift": jnp.array([shift])} - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: - kern = RationalQuadratic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "alpha": jnp.array([alpha]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_power_exp( - dim: int, ell: float, sigma: float, power: float, n: int -) -> None: - kern = PoweredExponential(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "power": jnp.array([power]), - } - - # Test gram matrix 
eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) -@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_pos_def_periodic( - dim: int, ell: float, sigma: float, period: float, n: int -) -> None: - kern = Periodic(active_dims=list(range(dim))) - # Gram constructor static method: - kern.gram - - # Create inputs x: - x = jr.uniform(_initialise_key, (n, dim)) - params = { - "lengthscale": jnp.array([ell]), - "variance": jnp.array([sigma]), - "period": jnp.array([period]), - } - - # Test gram matrix eigenvalues are positive: - Kxx = kern.gram(params, x) - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0.0).all() - - -@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) -@pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) -def test_initialisation(kernel: AbstractKernel, dim: int) -> None: - - if dim is None: - kern = kernel() - assert kern.ndims == 1 - - else: - kern = kernel(active_dims=[i for i in range(dim)]) - params = kern._initialise_params(_initialise_key) - - assert list(params.keys()) == ["lengthscale", "variance"] - assert all(params["lengthscale"] == jnp.array([1.0] * dim)) - assert params["variance"] == jnp.array([1.0]) - - if dim > 1: - assert kern.ard - else: - assert not kern.ard - - -@pytest.mark.parametrize( - "kernel", - [ - RBF, - Matern12, - Matern32, - Matern52, - Linear, - Polynomial, - RationalQuadratic, - PoweredExponential, - Periodic, - ], -) -def test_dtype(kernel: AbstractKernel) -> None: - parameter_state = initialise(kernel(), _initialise_key) - params, *_ = parameter_state.unpack() - for k, v in params.items(): - assert v.dtype == jnp.float64 - assert isinstance(k, str) - - -@pytest.mark.parametrize("degree", [1, 2, 3]) -@pytest.mark.parametrize("dim", [1, 2, 5]) -@pytest.mark.parametrize("variance", [0.1, 1.0, 2.0]) -@pytest.mark.parametrize("shift", [1e-6, 0.1, 1.0]) -@pytest.mark.parametrize("n", [1, 2, 5]) -def test_polynomial( - degree: int, dim: int, variance: float, shift: float, n: int -) -> None: - - # Define inputs - x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) - - # Define kernel - kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)]) - - # Check name - assert kern.name == f"Polynomial Degree: {degree}" - - # Initialise parameters - params = kern._initialise_params(_initialise_key) - params["shift"] * shift - params["variance"] * variance - - # Check parameter keys - assert list(params.keys()) == ["shift", "variance"] - - # Compute gram matrix - Kxx = kern.gram(params, x) - - # Check shapes - assert Kxx.shape[0] == x.shape[0] - assert Kxx.shape[0] == Kxx.shape[1] - - # Test positive definiteness - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0).all() - - -@pytest.mark.parametrize( - "kernel", - [RBF, Matern12, Matern32, Matern52, Linear, Polynomial, RationalQuadratic], -) -def test_active_dim(kernel: AbstractKernel) -> None: - dim_list = [0, 1, 2, 3] - perm_length = 2 - dim_pairs = list(permutations(dim_list, r=perm_length)) - n_dims = len(dim_list) - - # Generate random inputs - x = jr.normal(_initialise_key, shape=(20, n_dims)) - - for dp in dim_pairs: - # Take slice of x - slice = 
x[..., dp] - - # Define kernels - ad_kern = kernel(active_dims=dp) - manual_kern = kernel(active_dims=[i for i in range(perm_length)]) - - # Get initial parameters - ad_params = ad_kern._initialise_params(_initialise_key) - manual_params = manual_kern._initialise_params(_initialise_key) - - # Compute gram matrices - ad_Kxx = ad_kern.gram(ad_params, x) - manual_Kxx = manual_kern.gram(manual_params, slice) - - # Test gram matrices are equal - assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) - - -@pytest.mark.parametrize("combination_type", [SumKernel, ProductKernel]) -@pytest.mark.parametrize( - "kernel", - [RBF, RationalQuadratic, Linear, Matern12, Matern32, Matern52, Polynomial], -) -@pytest.mark.parametrize("n_kerns", [2, 3, 4]) -def test_combination_kernel( - combination_type: CombinationKernel, kernel: AbstractKernel, n_kerns: int -) -> None: - - # Create inputs - n = 20 - x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - - # Create list of kernels - kernel_set = [kernel() for _ in range(n_kerns)] - - # Create combination kernel - combination_kernel = combination_type(kernel_set=kernel_set) - - # Initialise default parameters - params = combination_kernel._initialise_params(_initialise_key) - - # Check params are a list of dictionaries - assert len(params) == n_kerns - - for p in params: - assert isinstance(p, dict) - - # Check combination kernel set - assert len(combination_kernel.kernel_set) == n_kerns - assert isinstance(combination_kernel.kernel_set, list) - assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) - - # Compute gram matrix - Kxx = combination_kernel.gram(params, x) - - # Check shapes - assert Kxx.shape[0] == Kxx.shape[1] - assert Kxx.shape[1] == n - - # Check positive definiteness - Kxx += identity(n) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert (eigen_values > 0).all() - - -@pytest.mark.parametrize( - "k1", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] -) -@pytest.mark.parametrize( - "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] -) -def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: - # Create inputs - n = 10 - x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - - # Create sum kernel - sum_kernel = SumKernel(kernel_set=[k1, k2]) - - # Initialise default parameters - params = sum_kernel._initialise_params(_initialise_key) - - # Compute gram matrix - Kxx = sum_kernel.gram(params, x) - - # NOW we do the same thing manually and check they are equal: - # Initialise default parameters - k1_params = k1._initialise_params(_initialise_key) - k2_params = k2._initialise_params(_initialise_key) - - # Compute gram matrix - Kxx_k1 = k1.gram(k1_params, x) - Kxx_k2 = k2.gram(k2_params, x) - - # Check manual and automatic gram matrices are equal - assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() + Kxx_k2.to_dense()) - - -@pytest.mark.parametrize( - "k1", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Polynomial(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -@pytest.mark.parametrize( - "k2", - [ - RBF(), - Matern12(), - Matern32(), - Matern52(), - Polynomial(), - Linear(), - Polynomial(), - RationalQuadratic(), - ], -) -def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: - - # Create inputs - n = 10 - x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) - - # Create product kernel - prod_kernel = ProductKernel(kernel_set=[k1, k2]) - - # Initialise default parameters - params = prod_kernel._initialise_params(_initialise_key) - - # Compute 
gram matrix - Kxx = prod_kernel.gram(params, x) - - # NOW we do the same thing manually and check they are equal: - - # Initialise default parameters - k1_params = k1._initialise_params(_initialise_key) - k2_params = k2._initialise_params(_initialise_key) - - # Compute gram matrix - Kxx_k1 = k1.gram(k1_params, x) - Kxx_k2 = k2.gram(k2_params, x) - - # Check manual and automatic gram matrices are equal - assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() * Kxx_k2.to_dense()) - - -def test_graph_kernel(): - # Create a random graph, G, and verice labels, x, - n_verticies = 20 - n_edges = 40 - G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) - x = jnp.arange(n_verticies).reshape(-1, 1) - - # Compute graph laplacian - L = nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 - - # Create graph kernel - kern = GraphKernel(laplacian=L) - assert kern.num_vertex == n_verticies - assert kern.evals.shape == (n_verticies, 1) - assert kern.evecs.shape == (n_verticies, n_verticies) - - # Unpack kernel computation - kern.gram - - # Initialise default parameters - params = kern._initialise_params(_initialise_key) - assert isinstance(params, dict) - assert list(sorted(list(params.keys()))) == [ - "lengthscale", - "smoothness", - "variance", - ] - - # Compute gram matrix - Kxx = kern.gram(params, x) - assert Kxx.shape == (n_verticies, n_verticies) - - # Check positive definiteness - Kxx += identity(n_verticies) * _jitter - eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) - assert all(eigen_values > 0) - - -@pytest.mark.parametrize( - "kernel", - [RBF, Matern12, Matern32, Matern52, Polynomial, Linear, RationalQuadratic], -) -def test_combination_kernel_type(kernel: AbstractKernel) -> None: - prod_kern = kernel() * kernel() - assert isinstance(prod_kern, ProductKernel) - assert isinstance(prod_kern, CombinationKernel) - - add_kern = kernel() + kernel() - assert isinstance(add_kern, SumKernel) - assert isinstance(add_kern, CombinationKernel) diff --git a/tests/test_types.py b/tests/test_types.py deleted file mode 100644 index d4b11a94..00000000 --- a/tests/test_types.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2022 The GPJax Contributors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -import jax.numpy as jnp -import pytest -from jax.config import config - -from gpjax.types import Dataset, NoneType, verify_dataset - -# Enable Float64 for more stable matrix inversions. 
-config.update("jax_enable_x64", True) - - -def test_nonetype(): - assert isinstance(None, NoneType) - - -@pytest.mark.parametrize("n", [1, 10]) -@pytest.mark.parametrize("outd", [1, 2, 10]) -@pytest.mark.parametrize("ind", [1, 2, 10]) -@pytest.mark.parametrize("n2", [1, 10]) -def test_dataset(n, outd, ind, n2): - x = jnp.ones((n, ind)) - y = jnp.ones((n, outd)) - d = Dataset(X=x, y=y) - verify_dataset(d) - assert d.n == n - assert d.in_dim == ind - assert d.out_dim == outd - - # test combine datasets - x2 = 2 * jnp.ones((n2, ind)) - y2 = 2 * jnp.ones((n2, outd)) - d2 = Dataset(X=x2, y=y2) - - d_combined = d + d2 - assert d_combined.n == n + n2 - assert d_combined.in_dim == ind - assert d_combined.out_dim == outd - assert (d_combined.y[:n] == 1.0).all() - assert (d_combined.y[n:] == 2.0).all() - assert (d_combined.X[:n] == 1.0).all() - assert (d_combined.X[n:] == 2.0).all() - - -@pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) -def test_dataset_assertions(nx, ny): - x = jnp.ones((nx, 1)) - y = jnp.ones((ny, 1)) - with pytest.raises(AssertionError): - ds = Dataset(X=x, y=y) - verify_dataset(ds) - - -def test_y_none(): - x = jnp.ones((10, 1)) - d = Dataset(X=x) - verify_dataset(d) - assert d.y is None diff --git a/tests/test_variational_inference.py b/tests/test_variational_inference.py index e310f275..1e7eb9eb 100644 --- a/tests/test_variational_inference.py +++ b/tests/test_variational_inference.py @@ -91,7 +91,6 @@ def test_stochastic_vi( assert svgp.prior == post.prior assert svgp.likelihood == post.likelihood - assert svgp.num_inducing == n_inducing_points if jit_fns: elbo_fn = jax.jit(svgp.elbo(D)) @@ -129,7 +128,6 @@ def test_collapsed_vi(n_datapoints, n_inducing_points, jit_fns, point_dim): assert sgpr.prior == post.prior assert sgpr.likelihood == post.likelihood - assert sgpr.num_inducing == n_inducing_points if jit_fns: elbo_fn = jax.jit(sgpr.elbo(D)) From 396e590716108a7de44c734d6f233012f0cb637b Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 14:56:16 +0000 Subject: [PATCH 04/15] Remove chex from docs. --- docs/README.md | 8 +--- examples/README.md | 85 ++++++++++++++++++++++++++++++++++++++ examples/kernels.pct.py | 23 +---------- examples/regression.pct.py | 4 +- gpjax/abstractions.py | 2 +- setup.py | 1 - 6 files changed, 90 insertions(+), 33 deletions(-) create mode 100644 examples/README.md diff --git a/docs/README.md b/docs/README.md index 986a3132..c06b6047 100644 --- a/docs/README.md +++ b/docs/README.md @@ -41,7 +41,6 @@ description and a code example. The docstring is concluded with a description of the objects attributes with corresponding types. ```python -@dataclass class Prior(AbstractPrior): """A Gaussian process prior object. The GP is parameterised by a `mean `_ @@ -78,9 +77,4 @@ class Prior(AbstractPrior): ### Documentation syntax A helpful cheatsheet for writing restructured text can be found -[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting -`dataclass` objects. - -* Class attributes should be specified using the `Attributes:` tag. -* Method argument should be specified using the `Args:` tags. -* All attributes and arguments should have types. +[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). 
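With chex dropped from the documentation, the example classes in the docs are written as ordinary Python classes with an explicit `__init__` rather than chex dataclasses. A minimal sketch of that pattern follows, assuming the `__call__`/`_initialise_params` interface exercised by the kernel tests later in this series; the `ScaledLinear` name, its `scale` attribute, and the kernel formula are hypothetical and purely illustrative:

```python
from typing import Dict

import jax.numpy as jnp
import jax.random as jr

from gpjax.kernels import AbstractKernel


class ScaledLinear(AbstractKernel):
    """A toy kernel defined with a plain __init__ instead of a chex dataclass."""

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale  # an ordinary Python attribute; no @dataclass machinery required

    def __call__(self, params: Dict, x, y):
        # k(x, y) = scale * variance * <x, y>  (illustrative formula only)
        return (self.scale * params["variance"] * jnp.dot(x, y)).squeeze()

    def _initialise_params(self, key) -> Dict:
        return {"variance": jnp.array([1.0])}


# Usage sketch: initialise parameters and evaluate the kernel on a pair of inputs.
kernel = ScaledLinear(scale=2.0)
params = kernel._initialise_params(jr.PRNGKey(0))
value = kernel(params, jnp.array([1.0, 2.0]), jnp.array([0.5, 0.5]))
```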
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..36841b62
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,85 @@
+# Where to find the docs
+
+The GPJax documentation can be found here:
+https://gpjax.readthedocs.io/en/latest/
+
+# How to build the docs
+
+1. Install the requirements using `pip install -r docs/requirements.txt`
+2. Make sure `pandoc` is installed
+3. Run the make script `make html`
+
+The corresponding HTML files can then be found in `docs/_build/html/`.
+
+# How to write code documentation
+
+Our documentation is written in ReStructuredText for Sphinx. This is a
+meta-language that is compiled into online documentation. For more details see
+[Sphinx's documentation](https://www.sphinx-doc.org/en/master/usage/restructuredtext/index.html).
+As a result, our docstrings adhere to a specific syntax that has to be kept in
+mind. Below we provide some guidelines.
+
+## How much information to put in a docstring
+
+A docstring should be informative. If in doubt, then it is best to add more
+information to a docstring than less. Many users will skim documentation, so
+please ensure the opening sentence or two of a docstring contains the core
+information. Adding examples and mathematical descriptions to documentation is
+highly desirable.
+
+We are making an active effort within GPJax to improve our documentation. If you
+spot any areas where there is missing information within the existing
+documentation, then please either raise an issue or
+[create a pull request](https://gpjax.readthedocs.io/en/latest/contributing.html).
+
+## An example docstring
+
+An example docstring that adheres to the principles of GPJax is given below.
+The docstring contains a simple, snappy introduction with links to auxiliary
+components. More detail is then provided in the form of a mathematical
+description and a code example. The docstring is concluded with a description
+of the object's attributes with corresponding types.
+
+```python
+class Prior(AbstractPrior):
+    """A Gaussian process prior object. The GP is parameterised by a
+    `mean `_
+    and `kernel `_ function.
+
+    A Gaussian process prior parameterised by a mean function :math:`m(\\cdot)` and a kernel
+    function :math:`k(\\cdot, \\cdot)` is given by
+
+    .. math::
+
+        p(f(\\cdot)) = \mathcal{GP}(m(\\cdot), k(\\cdot, \\cdot)).
+
+    To invoke a ``Prior`` distribution, only a kernel function is required. By default,
+    the mean function will be set to zero. In general, this assumption will be reasonable
+    assuming the data being modelled has been centred.
+
+    Example:
+        >>> import gpjax as gpx
+        >>>
+        >>> kernel = gpx.kernels.RBF()
+        >>> prior = gpx.Prior(kernel = kernel)
+
+    Attributes:
+        kernel (Kernel): The kernel function used to parameterise the prior.
+        mean_function (MeanFunction): The mean function used to parameterise the prior. Defaults to zero.
+        name (str): The name of the GP prior. Defaults to "GP prior".
+    """
+
+    kernel: Kernel
+    mean_function: Optional[AbstractMeanFunction] = Zero()
+    name: Optional[str] = "GP prior"
+```
+
+### Documentation syntax
+
+A helpful cheatsheet for writing restructured text can be found
+[here](https://github.com/ralsina/rst-cheatsheet/blob/master/rst-cheatsheet.rst). In addition to that, we adopt the following convention when documenting
+classes.
+
+* Class attributes should be specified using the `Attributes:` tag.
+* Method arguments should be specified using the `Args:` tag.
+* All attributes and arguments should have types.
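To complement the class-level example in the README above, a sketch of how the `Args:` and `Returns:` tags look on a method docstring is given below; the `predictive_mean` function and its signature are hypothetical and serve only to illustrate the convention:

```python
from typing import Dict

from jaxtyping import Array, Float


def predictive_mean(params: Dict, test_inputs: Float[Array, "N D"]) -> Float[Array, "N 1"]:
    """Compute the predictive mean at a set of test inputs.

    Args:
        params (Dict): The parameter set under which the prediction is made.
        test_inputs (Float[Array, "N D"]): The inputs at which the predictive mean is evaluated.

    Returns:
        Float[Array, "N 1"]: The predictive mean at each test input.
    """
    ...
```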
diff --git a/examples/kernels.pct.py b/examples/kernels.pct.py index 5f9b73c6..c4408c6d 100644 --- a/examples/kernels.pct.py +++ b/examples/kernels.pct.py @@ -228,28 +228,7 @@ def _initialise_params(self, key: jr.PRNGKey) -> dict: # domain is a circle, this is $2\pi$. Next we define the kernel's `__call__` # function which is a direct implementation of Equation (1). Finally, we define # the Kernel's parameter property which contains just one value $\tau$ that we -# initialise to 4 in the kernel's `__post_init__`. -# -# #### Aside on dataclasses -# -# One can see in the above definition of a `Polar` kernel that we decorated the -# class with a `@dataclass` command. Dataclasses are simply regular classs -# objects in Python, however, much of the boilerplate code has been removed. For -# example, without a `@dataclass` decorator, the instantiation of the above -# `Polar` kernel would be done through -# ```python -# class Polar(jk.kernels.AbstractKernel): -# def __init__(self, period: float = 2*jnp.pi): -# super().__init__() -# self.period = period -# ``` -# As objects become increasingly large and complex, the conciseness of a -# dataclass becomes increasingly attractive. To ensure full compatability with -# Jax, it is crucial that the dataclass decorator is imported from Chex, not -# base Python's `dataclass` module. Functionally, the two objects are identical. -# However, unlike regular Python dataclasses, it is possilbe to apply operations -# such as `jit`, `vmap` and `grad` to the dataclasses given by Chex as they are -# registrered PyTrees. +# initialise to 4 in the kernel's `__init__`. # # # ### Custom Parameter Bijection diff --git a/examples/regression.pct.py b/examples/regression.pct.py index 708591a9..2b9233df 100644 --- a/examples/regression.pct.py +++ b/examples/regression.pct.py @@ -137,7 +137,7 @@ # # ## Parameter state # -# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` dataclass is returned that contains four dictionaries: +# So far, all of the objects that we've defined have been stateless. To give our model state, we can use the `initialise` function provided in GPJax. Upon calling this, a `ParameterState` class is returned that contains four dictionaries: # # | Dictionary | Description | # |---|---| @@ -189,7 +189,7 @@ ) # %% [markdown] -# Similar to the `ParameterState` object above, the returned variable from the `fit` function is a dataclass, namely an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluation of our objective function throughout optimisation. +# Similar to the `ParameterState` object above, the returned variable from the `fit` function is a class, namely an `InferenceState` object that contains the parameters' final values and a tracked array of the evaluation of our objective function throughout optimisation. # %% learned_params, training_history = inference_state.unpack() diff --git a/gpjax/abstractions.py b/gpjax/abstractions.py index a99be71f..5a3352e6 100644 --- a/gpjax/abstractions.py +++ b/gpjax/abstractions.py @@ -252,7 +252,7 @@ def fit_natgrads( verbose (Optional[bool]): Whether to print the training loading bar. Defaults to True. Returns: - InferenceState: A dataclass comprising optimised parameters and training history. + InferenceState: A class comprising optimised parameters and training history. 
""" params, trainables, bijectors = parameter_state.unpack() diff --git a/setup.py b/setup.py index 53c5d678..02a28816 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,6 @@ def get_versions(): "optax", "jaxutils", "jaxkern", - "chex", "distrax>=0.1.2", "tqdm>=4.0.0", "ml-collections==0.1.0", From 48ee78081902ee0b048e72f13920fb04297ef4b1 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 16:27:18 +0000 Subject: [PATCH 05/15] Improve tests. --- tests/test_gaussian_distribution.py | 10 ++++++++++ tests/test_gps.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/tests/test_gaussian_distribution.py b/tests/test_gaussian_distribution.py index 8f3c0b63..46a67fb1 100644 --- a/tests/test_gaussian_distribution.py +++ b/tests/test_gaussian_distribution.py @@ -72,6 +72,9 @@ def test_diag_linear_operator(n: int) -> None: distrax_dist = MultivariateNormalDiag(loc=mean, scale_diag=diag) assert approx_equal(dist_diag.mean(), distrax_dist.mean()) + assert approx_equal(dist_diag.mode(), distrax_dist.mode()) + assert approx_equal(dist_diag.median(), distrax_dist.median()) + assert approx_equal(dist_diag.entropy(), distrax_dist.entropy()) assert approx_equal(dist_diag.variance(), distrax_dist.variance()) assert approx_equal(dist_diag.stddev(), distrax_dist.stddev()) assert approx_equal(dist_diag.covariance(), distrax_dist.covariance()) @@ -104,6 +107,9 @@ def test_dense_linear_operator(n: int) -> None: ) assert approx_equal(dist_dense.mean(), distrax_dist.mean()) + assert approx_equal(dist_dense.mode(), distrax_dist.mode()) + assert approx_equal(dist_dense.median(), distrax_dist.median()) + assert approx_equal(dist_dense.entropy(), distrax_dist.entropy()) assert approx_equal(dist_dense.variance(), distrax_dist.variance()) assert approx_equal(dist_dense.stddev(), distrax_dist.stddev()) assert approx_equal(dist_dense.covariance(), distrax_dist.covariance()) @@ -142,3 +148,7 @@ def test_kl_divergence(n: int) -> None: assert approx_equal( dist_a.kl_divergence(dist_b), distrax_dist_a.kl_divergence(distrax_dist_b) ) + + with pytest.raises(ValueError): + incompatible = GaussianDistribution(loc=jnp.ones((2 * n,))) + incompatible.kl_divergence(dist_a) diff --git a/tests/test_gps.py b/tests/test_gps.py index d7a4246d..9bab15d8 100644 --- a/tests/test_gps.py +++ b/tests/test_gps.py @@ -25,6 +25,7 @@ from gpjax import Dataset, initialise from gpjax.gps import ( AbstractPrior, + AbstractPosterior, ConjugatePosterior, NonConjugatePosterior, Prior, @@ -166,6 +167,24 @@ def test_param_construction(num_datapoints, lik): ] +@pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) +def test_abstract_posterior(lik): + pr = Prior(kernel=RBF()) + likelihood = lik(num_datapoints=10) + + with pytest.raises(TypeError): + _ = AbstractPosterior(pr, likelihood) + + class DummyPosterior(AbstractPosterior): + def predict(self): + pass + + dummy_post = DummyPosterior(pr, likelihood) + assert isinstance(dummy_post, AbstractPosterior) + assert dummy_post.likelihood == likelihood + assert dummy_post.prior == pr + + @pytest.mark.parametrize("lik", [Bernoulli, Gaussian]) def test_posterior_construct(lik): pr = Prior(kernel=RBF()) From a74823dd2bdb1cba8ae924bab3b6be9443e7949e Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 16:46:06 +0000 Subject: [PATCH 06/15] Add kernel/type tests back to GPJax. 
--- gpjax/kernels.py | 13 +- tests/test_kernels.py | 596 ++++++++++++++++++++++++++++++++++++++++++ tests/test_types.py | 89 +++++++ 3 files changed, 693 insertions(+), 5 deletions(-) create mode 100644 tests/test_kernels.py create mode 100644 tests/test_types.py diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 2956e501..03b93280 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -17,6 +17,14 @@ import deprecation +# These abstract types will also be removed in v0.6.0 +AbstractKernel = jaxkern.kernels.AbstractKernel +AbstractKernelComputation = jaxkern.kernels.AbstractKernelComputation +CombinationKernel = jaxkern.kernels.CombinationKernel +SumKernel = jaxkern.kernels.SumKernel +ProductKernel = jaxkern.kernels.ProductKernel + +# Import kernels/functions from `JaxKern`` and wrap them in a deprecation. def deprecate(cls): return deprecation.deprecated( deprecated_in="0.5.5", @@ -25,15 +33,10 @@ def deprecate(cls): )(cls) -AbstractKernelComputation = deprecate(jaxkern.kernels.AbstractKernelComputation) DiagonalKernelComputation = deprecate(jaxkern.kernels.DiagonalKernelComputation) ConstantDiagonalKernelComputation = deprecate( jaxkern.kernels.ConstantDiagonalKernelComputation ) -AbstractKernel = deprecate(jaxkern.kernels.AbstractKernel) -CombinationKernel = deprecate(jaxkern.kernels.CombinationKernel) -SumKernel = deprecate(jaxkern.kernels.SumKernel) -ProductKernel = deprecate(jaxkern.kernels.ProductKernel) RBF = deprecate(jaxkern.kernels.RBF) Matern12 = deprecate(jaxkern.kernels.Matern12) Matern32 = deprecate(jaxkern.kernels.Matern32) diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 100644 index 00000000..df4c8aa0 --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,596 @@ +# Copyright 2022 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +from itertools import permutations +from typing import Dict, List + +import jax +import jax.numpy as jnp +import jax.random as jr +import networkx as nx +import pytest +from gpjax.parameters import initialise +from jax.config import config +from jax.random import KeyArray as PRNGKeyType +from jaxlinop import LinearOperator, identity +from jaxtyping import Array, Float + +from gpjax.kernels import ( + RBF, + AbstractKernel, + CombinationKernel, + GraphKernel, + Linear, + Matern12, + Matern32, + Matern52, + Periodic, + Polynomial, + PoweredExponential, + ProductKernel, + RationalQuadratic, + SumKernel, + euclidean_distance, +) + +# Enable Float64 for more stable matrix inversions. 
+config.update("jax_enable_x64", True) +_initialise_key = jr.PRNGKey(123) +_jitter = 1e-6 + + +def test_abstract_kernel(): + # Test initialising abstract kernel raises TypeError with unimplemented __call__ and _init_params methods: + with pytest.raises(TypeError): + AbstractKernel() + + # Create a dummy kernel class with __call__ and _init_params methods implemented: + class DummyKernel(AbstractKernel): + def __call__( + self, x: Float[Array, "1 D"], y: Float[Array, "1 D"], params: Dict + ) -> Float[Array, "1"]: + return x * params["test"] * y + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return {"test": 1.0} + + # Initialise dummy kernel class and test __call__ and _init_params methods: + dummy_kernel = DummyKernel() + assert dummy_kernel._initialise_params(_initialise_key) == {"test": 1.0} + assert dummy_kernel(jnp.array([1.0]), jnp.array([2.0]), {"test": 2.0}) == 4.0 + + +@pytest.mark.parametrize( + "a, b, distance_to_3dp", + [ + ([1.0], [-4.0], 5.0), + ([1.0, -2.0], [-4.0, 3.0], 7.071), + ([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], 2.236), + ], +) +def test_euclidean_distance( + a: List[float], b: List[float], distance_to_3dp: float +) -> None: + + # Convert lists to JAX arrays: + a: Float[Array, "D"] = jnp.array(a) + b: Float[Array, "D"] = jnp.array(b) + + # Test distance is correct to 3dp: + assert jnp.round(euclidean_distance(a, b), 3) == distance_to_3dp + + +@pytest.mark.parametrize( + "kernel", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + Linear(), + Polynomial(), + RationalQuadratic(), + ], +) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("n", [1, 2, 10]) +def test_gram(kernel: AbstractKernel, dim: int, n: int) -> None: + + # Gram constructor static method: + kernel.gram + + # Inputs x: + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + + # Default kernel parameters: + params = kernel._initialise_params(_initialise_key) + + # Test gram matrix: + Kxx = kernel.gram(params, x) + assert isinstance(Kxx, LinearOperator) + assert Kxx.shape == (n, n) + + +@pytest.mark.parametrize( + "kernel", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + Linear(), + Polynomial(), + RationalQuadratic(), + ], +) +@pytest.mark.parametrize("num_a", [1, 2, 5]) +@pytest.mark.parametrize("num_b", [1, 2, 5]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +def test_cross_covariance( + kernel: AbstractKernel, num_a: int, num_b: int, dim: int +) -> None: + # Inputs a, b: + a = jnp.linspace(-1.0, 1.0, num_a * dim).reshape(num_a, dim) + b = jnp.linspace(3.0, 4.0, num_b * dim).reshape(num_b, dim) + + # Default kernel parameters: + params = kernel._initialise_params(_initialise_key) + + # Test cross covariance, Kab: + Kab = kernel.cross_covariance(params, a, b) + assert isinstance(Kab, jnp.ndarray) + assert Kab.shape == (num_a, num_b) + + +@pytest.mark.parametrize("kernel", [RBF(), Matern12(), Matern32(), Matern52()]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +def test_call(kernel: AbstractKernel, dim: int) -> None: + + # Datapoint x and datapoint y: + x = jnp.array([[1.0] * dim]) + y = jnp.array([[0.5] * dim]) + + # Defualt parameters: + params = kernel._initialise_params(_initialise_key) + + # Test calling gives an autocovariance value of no dimension between the inputs: + kxy = kernel(params, x, y) + + assert isinstance(kxy, jax.Array) + assert kxy.shape == () + + +@pytest.mark.parametrize("kern", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 
0.5)]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def( + kern: AbstractKernel, dim: int, ell: float, sigma: float, n: int +) -> None: + kern = kern(active_dims=list(range(dim))) + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = {"lengthscale": jnp.array([ell]), "variance": jnp.array([sigma])} + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("kern", [Linear, Polynomial]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("shift", [0.0, 0.5, 2.0]) +@pytest.mark.parametrize("sigma", [0.1, 0.2, 0.5]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_lin_poly( + kern: AbstractKernel, dim: int, shift: float, sigma: float, n: int +) -> None: + kern = kern(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = {"variance": jnp.array([sigma]), "shift": jnp.array([shift])} + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +@pytest.mark.parametrize("alpha", [0.1, 0.5, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_rq(dim: int, ell: float, sigma: float, alpha: float, n: int) -> None: + kern = RationalQuadratic(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = { + "lengthscale": jnp.array([ell]), + "variance": jnp.array([sigma]), + "alpha": jnp.array([alpha]), + } + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +@pytest.mark.parametrize("power", [0.1, 0.5, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_power_exp( + dim: int, ell: float, sigma: float, power: float, n: int +) -> None: + kern = PoweredExponential(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = { + "lengthscale": jnp.array([ell]), + "variance": jnp.array([sigma]), + "power": jnp.array([power]), + } + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("ell, sigma", [(0.1, 0.2), (0.5, 0.1), (0.1, 0.5), (0.5, 0.5)]) +@pytest.mark.parametrize("period", [0.1, 0.5, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_pos_def_periodic( + dim: int, ell: float, sigma: float, period: float, n: int +) -> None: + kern = Periodic(active_dims=list(range(dim))) + # Gram constructor static method: + kern.gram + + # Create inputs x: + x = jr.uniform(_initialise_key, (n, dim)) + params = { + "lengthscale": jnp.array([ell]), + "variance": jnp.array([sigma]), + "period": 
jnp.array([period]), + } + + # Test gram matrix eigenvalues are positive: + Kxx = kern.gram(params, x) + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0.0).all() + + +@pytest.mark.parametrize("kernel", [RBF, Matern12, Matern32, Matern52]) +@pytest.mark.parametrize("dim", [None, 1, 2, 5, 10]) +def test_initialisation(kernel: AbstractKernel, dim: int) -> None: + + if dim is None: + kern = kernel() + assert kern.ndims == 1 + + else: + kern = kernel(active_dims=[i for i in range(dim)]) + params = kern._initialise_params(_initialise_key) + + assert list(params.keys()) == ["lengthscale", "variance"] + assert all(params["lengthscale"] == jnp.array([1.0] * dim)) + assert params["variance"] == jnp.array([1.0]) + + if dim > 1: + assert kern.ard + else: + assert not kern.ard + + +@pytest.mark.parametrize( + "kernel", + [ + RBF, + Matern12, + Matern32, + Matern52, + Linear, + Polynomial, + RationalQuadratic, + PoweredExponential, + Periodic, + ], +) +def test_dtype(kernel: AbstractKernel) -> None: + parameter_state = initialise(kernel(), _initialise_key) + params, *_ = parameter_state.unpack() + for k, v in params.items(): + assert v.dtype == jnp.float64 + assert isinstance(k, str) + + +@pytest.mark.parametrize("degree", [1, 2, 3]) +@pytest.mark.parametrize("dim", [1, 2, 5]) +@pytest.mark.parametrize("variance", [0.1, 1.0, 2.0]) +@pytest.mark.parametrize("shift", [1e-6, 0.1, 1.0]) +@pytest.mark.parametrize("n", [1, 2, 5]) +def test_polynomial( + degree: int, dim: int, variance: float, shift: float, n: int +) -> None: + + # Define inputs + x = jnp.linspace(0.0, 1.0, n * dim).reshape(n, dim) + + # Define kernel + kern = Polynomial(degree=degree, active_dims=[i for i in range(dim)]) + + # Check name + assert kern.name == f"Polynomial Degree: {degree}" + + # Initialise parameters + params = kern._initialise_params(_initialise_key) + params["shift"] * shift + params["variance"] * variance + + # Check parameter keys + assert list(params.keys()) == ["shift", "variance"] + + # Compute gram matrix + Kxx = kern.gram(params, x) + + # Check shapes + assert Kxx.shape[0] == x.shape[0] + assert Kxx.shape[0] == Kxx.shape[1] + + # Test positive definiteness + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0).all() + + +@pytest.mark.parametrize( + "kernel", + [RBF, Matern12, Matern32, Matern52, Linear, Polynomial, RationalQuadratic], +) +def test_active_dim(kernel: AbstractKernel) -> None: + dim_list = [0, 1, 2, 3] + perm_length = 2 + dim_pairs = list(permutations(dim_list, r=perm_length)) + n_dims = len(dim_list) + + # Generate random inputs + x = jr.normal(_initialise_key, shape=(20, n_dims)) + + for dp in dim_pairs: + # Take slice of x + slice = x[..., dp] + + # Define kernels + ad_kern = kernel(active_dims=dp) + manual_kern = kernel(active_dims=[i for i in range(perm_length)]) + + # Get initial parameters + ad_params = ad_kern._initialise_params(_initialise_key) + manual_params = manual_kern._initialise_params(_initialise_key) + + # Compute gram matrices + ad_Kxx = ad_kern.gram(ad_params, x) + manual_Kxx = manual_kern.gram(manual_params, slice) + + # Test gram matrices are equal + assert jnp.all(ad_Kxx.to_dense() == manual_Kxx.to_dense()) + + +@pytest.mark.parametrize("combination_type", [SumKernel, ProductKernel]) +@pytest.mark.parametrize( + "kernel", + [RBF, RationalQuadratic, Linear, Matern12, Matern32, Matern52, Polynomial], +) +@pytest.mark.parametrize("n_kerns", [2, 3, 4]) +def 
test_combination_kernel( + combination_type: CombinationKernel, kernel: AbstractKernel, n_kerns: int +) -> None: + + # Create inputs + n = 20 + x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) + + # Create list of kernels + kernel_set = [kernel() for _ in range(n_kerns)] + + # Create combination kernel + combination_kernel = combination_type(kernel_set=kernel_set) + + # Initialise default parameters + params = combination_kernel._initialise_params(_initialise_key) + + # Check params are a list of dictionaries + assert len(params) == n_kerns + + for p in params: + assert isinstance(p, dict) + + # Check combination kernel set + assert len(combination_kernel.kernel_set) == n_kerns + assert isinstance(combination_kernel.kernel_set, list) + assert isinstance(combination_kernel.kernel_set[0], AbstractKernel) + + # Compute gram matrix + Kxx = combination_kernel.gram(params, x) + + # Check shapes + assert Kxx.shape[0] == Kxx.shape[1] + assert Kxx.shape[1] == n + + # Check positive definiteness + Kxx += identity(n) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert (eigen_values > 0).all() + + +@pytest.mark.parametrize( + "k1", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] +) +@pytest.mark.parametrize( + "k2", [RBF(), Matern12(), Matern32(), Matern52(), Polynomial()] +) +def test_sum_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: + # Create inputs + n = 10 + x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) + + # Create sum kernel + sum_kernel = SumKernel(kernel_set=[k1, k2]) + + # Initialise default parameters + params = sum_kernel._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx = sum_kernel.gram(params, x) + + # NOW we do the same thing manually and check they are equal: + # Initialise default parameters + k1_params = k1._initialise_params(_initialise_key) + k2_params = k2._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx_k1 = k1.gram(k1_params, x) + Kxx_k2 = k2.gram(k2_params, x) + + # Check manual and automatic gram matrices are equal + assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() + Kxx_k2.to_dense()) + + +@pytest.mark.parametrize( + "k1", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + Polynomial(), + Linear(), + Polynomial(), + RationalQuadratic(), + ], +) +@pytest.mark.parametrize( + "k2", + [ + RBF(), + Matern12(), + Matern32(), + Matern52(), + Polynomial(), + Linear(), + Polynomial(), + RationalQuadratic(), + ], +) +def test_prod_kern_value(k1: AbstractKernel, k2: AbstractKernel) -> None: + + # Create inputs + n = 10 + x = jnp.linspace(0.0, 1.0, num=n).reshape(-1, 1) + + # Create product kernel + prod_kernel = ProductKernel(kernel_set=[k1, k2]) + + # Initialise default parameters + params = prod_kernel._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx = prod_kernel.gram(params, x) + + # NOW we do the same thing manually and check they are equal: + + # Initialise default parameters + k1_params = k1._initialise_params(_initialise_key) + k2_params = k2._initialise_params(_initialise_key) + + # Compute gram matrix + Kxx_k1 = k1.gram(k1_params, x) + Kxx_k2 = k2.gram(k2_params, x) + + # Check manual and automatic gram matrices are equal + assert jnp.all(Kxx.to_dense() == Kxx_k1.to_dense() * Kxx_k2.to_dense()) + + +def test_graph_kernel(): + # Create a random graph, G, and verice labels, x, + n_verticies = 20 + n_edges = 40 + G = nx.gnm_random_graph(n_verticies, n_edges, seed=123) + x = jnp.arange(n_verticies).reshape(-1, 1) + + # Compute graph laplacian + L = 
nx.laplacian_matrix(G).toarray() + jnp.eye(n_verticies) * 1e-12 + + # Create graph kernel + kern = GraphKernel(laplacian=L) + assert isinstance(kern, GraphKernel) + assert kern.num_vertex == n_verticies + assert kern.evals.shape == (n_verticies, 1) + assert kern.evecs.shape == (n_verticies, n_verticies) + + # Unpack kernel computation + kern.gram + + # Initialise default parameters + params = kern._initialise_params(_initialise_key) + assert isinstance(params, dict) + assert list(sorted(list(params.keys()))) == [ + "lengthscale", + "smoothness", + "variance", + ] + + # Compute gram matrix + Kxx = kern.gram(params, x) + assert Kxx.shape == (n_verticies, n_verticies) + + # Check positive definiteness + Kxx += identity(n_verticies) * _jitter + eigen_values = jnp.linalg.eigvalsh(Kxx.to_dense()) + assert all(eigen_values > 0) + + +@pytest.mark.parametrize( + "kernel", + [RBF, Matern12, Matern32, Matern52, Polynomial, Linear, RationalQuadratic], +) +def test_combination_kernel_type(kernel: AbstractKernel) -> None: + prod_kern = kernel() * kernel() + assert isinstance(prod_kern, ProductKernel) + assert isinstance(prod_kern, CombinationKernel) + + add_kern = kernel() + kernel() + assert isinstance(add_kern, SumKernel) + assert isinstance(add_kern, CombinationKernel) diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 00000000..f08f66bf --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,89 @@ +# Copyright 2022 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import jax.numpy as jnp +import pytest +from gpjax.types import Dataset, verify_dataset + + +@pytest.mark.parametrize("n", [1, 10]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +@pytest.mark.parametrize("n2", [1, 10]) +def test_dataset(n: int, outd: int, ind: int, n2: int) -> None: + x = jnp.ones((n, ind)) + y = jnp.ones((n, outd)) + d = Dataset(X=x, y=y) + + verify_dataset(d) + assert d.n == n + assert d.in_dim == ind + assert d.out_dim == outd + + assert d.__repr__() == f"- Number of datapoints: {n}\n- Dimension: {ind}" + + # Test combine datasets. + x2 = 2 * jnp.ones((n2, ind)) + y2 = 2 * jnp.ones((n2, outd)) + d2 = Dataset(X=x2, y=y2) + + d_combined = d + d2 + assert d_combined.n == n + n2 + assert d_combined.in_dim == ind + assert d_combined.out_dim == outd + assert (d_combined.y[:n] == 1.0).all() + assert (d_combined.y[n:] == 2.0).all() + assert (d_combined.X[:n] == 1.0).all() + assert (d_combined.X[n:] == 2.0).all() + + # Test supervised and unsupervised. 
+ assert d.is_supervised() is True + dunsup = Dataset(y=y) + assert dunsup.is_unsupervised() is True + + +@pytest.mark.parametrize("nx, ny", [(1, 2), (2, 1), (10, 5), (5, 10)]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +def test_dataset_assertions(nx: int, ny: int, outd: int, ind: int) -> None: + x = jnp.ones((nx, ind)) + y = jnp.ones((ny, outd)) + + with pytest.raises(ValueError): + ds = Dataset(X=x, y=y) + + +@pytest.mark.parametrize("n", [1, 2, 10]) +@pytest.mark.parametrize("outd", [1, 2, 10]) +@pytest.mark.parametrize("ind", [1, 2, 10]) +def test_2d_inputs(n: int, outd: int, ind: int) -> None: + x = jnp.ones((n, ind)) + y = jnp.ones((n,)) + + with pytest.raises(ValueError): + ds = Dataset(X=x, y=y) + + x = jnp.ones((n,)) + y = jnp.ones((n, outd)) + + with pytest.raises(ValueError): + ds = Dataset(X=x, y=y) + + +def test_y_none() -> None: + x = jnp.ones((10, 1)) + d = Dataset(X=x) + verify_dataset(d) + assert d.y is None From f63e7db2025b36ec241ee051cc90a77c49e39d35 Mon Sep 17 00:00:00 2001 From: Daniel Dodd Date: Thu, 5 Jan 2023 17:14:36 +0000 Subject: [PATCH 07/15] Update kernels.py --- gpjax/kernels.py | 1126 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 1090 insertions(+), 36 deletions(-) diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 03b93280..7800d1b1 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -13,45 +13,1099 @@ # limitations under the License. # ============================================================================== -import jaxkern +import abc +from typing import Callable, Dict, List, Optional, Sequence + +from jaxlinop import ( + LinearOperator, + DenseLinearOperator, + DiagonalLinearOperator, + ConstantDiagonalLinearOperator, +) + +import jax.numpy as jnp +from jax import vmap +import jax +from jaxtyping import Array, Float + +from chex import PRNGKey as PRNGKeyType +from jaxutils import PyTree import deprecation -# These abstract types will also be removed in v0.6.0 -AbstractKernel = jaxkern.kernels.AbstractKernel -AbstractKernelComputation = jaxkern.kernels.AbstractKernelComputation -CombinationKernel = jaxkern.kernels.CombinationKernel -SumKernel = jaxkern.kernels.SumKernel -ProductKernel = jaxkern.kernels.ProductKernel - -# Import kernels/functions from `JaxKern`` and wrap them in a deprecation. 
-def deprecate(cls): - return deprecation.deprecated( - deprecated_in="0.5.5", - removed_in="0.6.0", - details="Use JaxKern for the " + cls.__name__, - )(cls) - - -DiagonalKernelComputation = deprecate(jaxkern.kernels.DiagonalKernelComputation) -ConstantDiagonalKernelComputation = deprecate( - jaxkern.kernels.ConstantDiagonalKernelComputation -) -RBF = deprecate(jaxkern.kernels.RBF) -Matern12 = deprecate(jaxkern.kernels.Matern12) -Matern32 = deprecate(jaxkern.kernels.Matern32) -Matern52 = deprecate(jaxkern.kernels.Matern52) -Linear = deprecate(jaxkern.kernels.Linear) -Periodic = deprecate(jaxkern.kernels.Periodic) -White = deprecate(jaxkern.kernels.White) -PoweredExponential = deprecate(jaxkern.kernels.PoweredExponential) -RationalQuadratic = deprecate(jaxkern.kernels.RationalQuadratic) -Polynomial = deprecate(jaxkern.kernels.Polynomial) -EigenKernelComputation = deprecate(jaxkern.kernels.EigenKernelComputation) -GraphKernel = deprecate(jaxkern.kernels.GraphKernel) -squared_distance = deprecate(jaxkern.kernels.squared_distance) -euclidean_distance = deprecate(jaxkern.kernels.euclidean_distance) -jax_gather_nd = deprecate(jaxkern.kernels.jax_gather_nd) +class AbstractKernelComputation(PyTree): + """Abstract class for kernel computations.""" + + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + self._kernel_fn = kernel_fn + + @property + def kernel_fn( + self, + ) -> Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array]: + return self._kernel_fn + + @kernel_fn.setter + def kernel_fn( + self, + kernel_fn: Callable[[Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array], + ) -> None: + self._kernel_fn = kernel_fn + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> LinearOperator: + + """Compute Gram covariance operator of the kernel function. + + Args: + kernel (AbstractKernel): The kernel function to be evaluated. + params (Dict): The parameters of the kernel function. + inputs (Float[Array, "N N"]): The inputs to the kernel function. + + Returns: + LinearOperator: Gram covariance operator of the kernel function. + """ + + matrix = self.cross_covariance(params, inputs, inputs) + + return DenseLinearOperator(matrix=matrix) + + @abc.abstractmethod + def cross_covariance( + self, + params: Dict, + x: Float[Array, "N D"], + y: Float[Array, "M D"], + ) -> Float[Array, "N M"]: + """For a given kernel, compute the NxM gram matrix on an a pair + of input matrices with shape NxD and MxD. + + Args: + kernel (AbstractKernel): The kernel for which the cross-covariance + matrix should be computed for. + params (Dict): The kernel's parameter set. + x (Float[Array,"N D"]): The first input matrix. + y (Float[Array,"M D"]): The second input matrix. + + Returns: + Float[Array, "N M"]: The computed square Gram matrix. + """ + raise NotImplementedError + + def diagonal( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the variance + vector should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + LinearOperator: The computed diagonal variance entries. 
+ """ + diag = vmap(lambda x: self._kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the DenseKernelComputation", +) +class DenseKernelComputation(AbstractKernelComputation): + """Dense kernel computation class. Operations with the kernel assume + a dense gram matrix structure. + """ + + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + """For a given kernel, compute the NxM covariance matrix on a pair of input + matrices of shape NxD and MxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram + matrix should be computed for. + params (Dict): The kernel's parameter set. + x (Float[Array,"N D"]): The input matrix. + y (Float[Array,"M D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. + """ + cross_cov = vmap(lambda x: vmap(lambda y: self.kernel_fn(params, x, y))(y))(x) + return cross_cov + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the DiagonalKernelComputation", +) +class DiagonalKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a kernel with diagonal structure, compute the NxN gram matrix on + an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram matrix + should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. + """ + + diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + raise ValueError("Cross covariance not defined for diagonal kernels.") + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the ConstantDiagonalKernelComputation", +) +class ConstantDiagonalKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + + def gram( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> ConstantDiagonalLinearOperator: + """For a kernel with diagonal structure, compute the NxN gram matrix on + an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the Gram matrix + should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + CovarianceOperator: The computed square Gram matrix. 
+ """ + + value = self.kernel_fn(params, inputs[0], inputs[0]) + + return ConstantDiagonalLinearOperator(value=value, size=inputs.shape[0]) + + def diagonal( + self, + params: Dict, + inputs: Float[Array, "N D"], + ) -> DiagonalLinearOperator: + """For a given kernel, compute the elementwise diagonal of the + NxN gram matrix on an input matrix of shape NxD. + + Args: + kernel (AbstractKernel): The kernel for which the variance + vector should be computed for. + params (Dict): The kernel's parameter set. + inputs (Float[Array, "N D"]): The input matrix. + + Returns: + LinearOperator: The computed diagonal variance entries. + """ + + diag = vmap(lambda x: self.kernel_fn(params, x, x))(inputs) + + return DiagonalLinearOperator(diag=diag) + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + raise ValueError("Cross covariance not defined for constant diagonal kernels.") + + +########################################## +# Abtract classes +########################################## +class AbstractKernel(PyTree): + """ + Base kernel class""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "AbstractKernel", + ) -> None: + self.compute_engine = compute_engine + self.active_dims = active_dims + self.stationary = stationary + self.spectral = spectral + self.name = name + self.ndims = 1 if not self.active_dims else len(self.active_dims) + compute_engine = self.compute_engine(kernel_fn=self.__call__) + self.gram = compute_engine.gram + self.cross_covariance = compute_engine.cross_covariance + + @abc.abstractmethod + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs. + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + raise NotImplementedError + + def slice_input(self, x: Float[Array, "N D"]) -> Float[Array, "N Q"]: + """Select the relevant columns of the supplied matrix to be used within the kernel's evaluation. + + Args: + x (Float[Array, "N D"]): The matrix or vector that is to be sliced. + Returns: + Float[Array, "N Q"]: A sliced form of the input matrix. + """ + return x[..., self.active_dims] + + def __add__(self, other: "AbstractKernel") -> "AbstractKernel": + """Add two kernels together. + Args: + other (AbstractKernel): The kernel to be added to the current kernel. + + Returns: + AbstractKernel: A new kernel that is the sum of the two kernels. + """ + return SumKernel(kernel_set=[self, other]) + + def __mul__(self, other: "AbstractKernel") -> "AbstractKernel": + """Multiply two kernels together. + + Args: + other (AbstractKernel): The kernel to be multiplied with the current kernel. + + Returns: + AbstractKernel: A new kernel that is the product of the two kernels. + """ + return ProductKernel(kernel_set=[self, other]) + + @property + def ard(self): + """Boolean property as to whether the kernel is isotropic or of + automatic relevance determination form. + + Returns: + bool: True if the kernel is an ARD kernel. 
+ """ + return True if self.ndims > 1 else False + + @abc.abstractmethod + def _initialise_params(self, key: PRNGKeyType) -> Dict: + """A template dictionary of the kernel's parameter set. + + Args: + key (PRNGKeyType): A PRNG key to be used for initialising + the kernel's parameters. + + Returns: + Dict: A dictionary of the kernel's parameters. + """ + raise NotImplementedError + + +class CombinationKernel(AbstractKernel): + """A base class for products or sums of kernels.""" + + def __init__( + self, + kernel_set: List[AbstractKernel], + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "AbstractKernel", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + self.kernel_set = kernel_set + name: Optional[str] = "Combination kernel" + self.combination_fn: Optional[Callable] = None + + if not all(isinstance(k, AbstractKernel) for k in self.kernel_set): + raise TypeError("can only combine Kernel instances") # pragma: no cover + + self._set_kernels(self.kernel_set) + + def _set_kernels(self, kernels: Sequence[AbstractKernel]) -> None: + """Combine multiple kernels. Based on GPFlow's Combination kernel.""" + # add kernels to a list, flattening out instances of this class therein + kernels_list: List[AbstractKernel] = [] + for k in kernels: + if isinstance(k, self.__class__): + kernels_list.extend(k.kernel_set) + else: + kernels_list.append(k) + + self.kernel_set = kernels_list + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + """A template dictionary of the kernel's parameter set.""" + return [kernel._initialise_params(key) for kernel in self.kernel_set] + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + ) -> Float[Array, "1"]: + """Evaluate combination kernel on a pair of inputs. + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. 
+ """ + return self.combination_fn( + jnp.stack([k(p, x, y) for k, p in zip(self.kernel_set, params)]) + ) + + +class SumKernel(CombinationKernel): + """A kernel that is the sum of a set of kernels.""" + + def __init__( + self, + kernel_set: List[AbstractKernel], + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Sum kernel", + ) -> None: + super().__init__( + kernel_set, compute_engine, active_dims, stationary, spectral, name + ) + self.combination_fn: Optional[Callable] = jnp.sum + + +class ProductKernel(CombinationKernel): + """A kernel that is the product of a set of kernels.""" + + def __init__( + self, + kernel_set: List[AbstractKernel], + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Product kernel", + ) -> None: + super().__init__( + kernel_set, compute_engine, active_dims, stationary, spectral, name + ) + self.combination_fn: Optional[Callable] = jnp.prod + + +########################################## +# Euclidean kernels +########################################## +@deprecation.deprecated( + deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the RBF kernel" +) +class RBF(AbstractKernel): + """The Radial Basis Function (RBF) kernel.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Radial basis function kernel", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( \\frac{\\lVert x - y \\rVert^2_2}{2 \\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. 
+ """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + K = params["variance"] * jnp.exp(-0.5 * squared_distance(x, y)) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + params = { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + return jax.tree_util.tree_map(lambda x: jnp.atleast_1d(x), params) + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the Matern12 kernel", +) +class Matern12(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 0.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matérn 1/2 kernel", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( -\\frac{\\lvert x-y \\rvert}{2\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + K = params["variance"] * jnp.exp(-euclidean_distance(x, y)) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the Matern32 kernel", +) +class Matern32(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 1.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matern 3/2", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{3}\\lvert x-y \\rvert}{\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{3}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. 
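As implemented, both the RBF and Matérn-1/2 kernels rescale the inputs by the lengthscale before taking a distance, and the RBF exponent carries a negative sign, k(x, y) = σ² exp(−‖x−y‖²₂ / (2ℓ²)). A standalone numerical sketch of the two evaluations (the helper functions and parameter values are local to the example):

    import jax.numpy as jnp

    def squared_distance(x, y):
        return jnp.sum((x - y) ** 2)

    def euclidean_distance(x, y):
        # Clamped away from zero so the gradient stays finite at x == y.
        return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36))

    lengthscale = jnp.array([2.0])
    variance = jnp.array([1.5])
    x, y = jnp.array([0.0]), jnp.array([1.0])

    xs, ys = x / lengthscale, y / lengthscale

    rbf_value = variance * jnp.exp(-0.5 * squared_distance(xs, ys))   # sigma^2 exp(-tau^2 / 2)
    matern12_value = variance * jnp.exp(-euclidean_distance(xs, ys))  # sigma^2 exp(-tau)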
+ """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + tau = euclidean_distance(x, y) + K = ( + params["variance"] + * (1.0 + jnp.sqrt(3.0) * tau) + * jnp.exp(-jnp.sqrt(3.0) * tau) + ) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the Matern52 kernel", +) +class Matern52(AbstractKernel): + """The Matérn kernel with smoothness parameter fixed at 2.5.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Matern 5/2", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with + lengthscale parameter :math:`\ell` and variance :math:`\sigma^2` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg(1+ \\frac{\\sqrt{5}\\lvert x-y \\rvert}{\\ell^2} + \\frac{5\\lvert x - y \\rvert^2}{3\\ell^2} \\Bigg)\\exp\\Bigg(-\\frac{\\sqrt{5}\\lvert x-y\\rvert}{\\ell^2} \\Bigg) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + tau = euclidean_distance(x, y) + K = ( + params["variance"] + * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * jnp.square(tau)) + * jnp.exp(-jnp.sqrt(5.0) * tau) + ) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + } + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the PoweredExponential kernel", +) +class PoweredExponential(AbstractKernel): + """The powered exponential family of kernels. + + Key reference is Diggle and Ribeiro (2007) - "Model-based Geostatistics". + + """ + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Powered exponential", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell`, :math:`\sigma` and power :math:`\kappa`. + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( - \\Big( \\frac{\\lVert x - y \\rVert^2}{\\ell^2} \\Big)^\\kappa \\Bigg) + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. 
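Both half-integer Matérn kernels follow the same recipe: form the lengthscale-scaled distance τ, then multiply a low-order polynomial in τ by a decaying exponential. A hedged numerical sketch with made-up parameter values:

    import jax.numpy as jnp

    def euclidean_distance(x, y):
        return jnp.sqrt(jnp.maximum(jnp.sum((x - y) ** 2), 1e-36))

    lengthscale, variance = 1.0, 2.0
    x, y = jnp.array([0.5, 0.0]), jnp.array([0.0, 1.0])
    tau = euclidean_distance(x / lengthscale, y / lengthscale)

    # Matern 3/2: sigma^2 (1 + sqrt(3) tau) exp(-sqrt(3) tau)
    matern32_value = variance * (1.0 + jnp.sqrt(3.0) * tau) * jnp.exp(-jnp.sqrt(3.0) * tau)

    # Matern 5/2: sigma^2 (1 + sqrt(5) tau + 5 tau^2 / 3) exp(-sqrt(5) tau)
    matern52_value = (
        variance
        * (1.0 + jnp.sqrt(5.0) * tau + 5.0 / 3.0 * tau**2)
        * jnp.exp(-jnp.sqrt(5.0) * tau)
    )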
+ + Returns: + Array: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + K = params["variance"] * jnp.exp(-euclidean_distance(x, y) ** params["power"]) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + "power": jnp.array([1.0]), + } + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the Linear kernel", +) +class Linear(AbstractKernel): + """The linear kernel.""" + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Linear", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance parameter :math:`\sigma` + + .. math:: + k(x, y) = \\sigma^2 x^{T}y + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. + Returns: + Array: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) + y = self.slice_input(y) + K = params["variance"] * jnp.matmul(x.T, y) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return {"variance": jnp.array([1.0])} + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the Polynomial kernel", +) +class Polynomial(AbstractKernel): + """The Polynomial kernel with variable degree.""" + + def __init__( + self, + degree: int = 1, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Polynomial", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + self.degree = degree + self.name = f"Polynomial Degree: {self.degree}" + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with shift parameter :math:`\\alpha` and variance :math:`\sigma^2` through + + .. math:: + k(x, y) = \\Big( \\alpha + \\sigma^2 xy \\Big)^{d} + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. 
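The powered-exponential, linear and polynomial evaluations each reduce to a couple of jax.numpy operations. A standalone sketch with illustrative parameter values (a scalar variance is used here for brevity, whereas the polynomial implementation broadcasts a per-dimension variance over the input):

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0])
    y = jnp.array([0.5, -1.0])

    # Powered exponential: sigma^2 exp(-(||x - y|| / ell)^kappa)
    lengthscale, variance, power = 1.0, 1.0, 1.5
    tau = jnp.sqrt(jnp.maximum(jnp.sum(((x - y) / lengthscale) ** 2), 1e-36))
    powered_exp_value = variance * jnp.exp(-tau**power)

    # Linear: sigma^2 x^T y
    linear_value = 0.7 * jnp.dot(x, y)

    # Polynomial of degree d: (alpha + variance-weighted x . y)^d
    shift, poly_variance, degree = 1.0, 0.5, 2
    polynomial_value = (shift + jnp.dot(x * poly_variance, y)) ** degree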
+ """ + x = self.slice_input(x).squeeze() + y = self.slice_input(y).squeeze() + K = jnp.power(params["shift"] + jnp.dot(x * params["variance"], y), self.degree) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "shift": jnp.array([1.0]), + "variance": jnp.array([1.0] * self.ndims), + } + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the White kernel", +) +class White(AbstractKernel): + def __init__( + self, + compute_engine: AbstractKernelComputation = ConstantDiagonalKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "White", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __post_init__(self) -> None: + super(White, self).__post_init__() + + def __call__( + self, params: Dict, x: Float[Array, "1 D"], y: Float[Array, "1 D"] + ) -> Float[Array, "1"]: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with variance :math:`\sigma` + + .. math:: + k(x, y) = \\sigma^2 \delta(x-y) + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): The left hand argument of the kernel function's call. + y (Float[Array, "1 D"]): The right hand argument of the kernel function's call. + + Returns: + Float[Array, "1"]: The value of :math:`k(x, y)`. + """ + K = jnp.all(jnp.equal(x, y)) * params["variance"] + return K.squeeze() + + def _initialise_params(self, key: Float[Array, "1 D"]) -> Dict: + """Initialise the kernel parameters. + + Args: + key (Float[Array, "1 D"]): The key to initialise the parameters with. + + Returns: + Dict: The initialised parameters. + """ + return {"variance": jnp.array([1.0])} + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the RationalQuadratic kernel", +) +class RationalQuadratic(AbstractKernel): + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Rational Quadratic", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( 1 + \\frac{\\lVert x - y \\rVert^2_2}{2 \\alpha \\ell^2} \\Bigg) + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. 
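The white kernel is non-zero only when the two inputs coincide elementwise, and the rational quadratic takes the power form (1 + ‖x−y‖² / (2αℓ²))^(−α) as implemented below. A brief sketch with illustrative values:

    import jax.numpy as jnp

    # White noise kernel: sigma^2 when x == y elementwise, zero otherwise.
    noise_variance = jnp.array(0.1)
    white_same = jnp.all(jnp.equal(jnp.array([1.0]), jnp.array([1.0]))) * noise_variance  # -> 0.1
    white_diff = jnp.all(jnp.equal(jnp.array([1.0]), jnp.array([2.0]))) * noise_variance  # -> 0.0

    # Rational quadratic: sigma^2 (1 + ||x - y||^2 / (2 alpha ell^2))^(-alpha)
    lengthscale, variance, alpha = 1.0, 1.0, 2.0
    x, y = jnp.array([0.0]), jnp.array([1.5])
    d2 = jnp.sum(((x - y) / lengthscale) ** 2)
    rq_value = variance * (1.0 + 0.5 * d2 / alpha) ** (-alpha)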
+ Returns: + Array: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) / params["lengthscale"] + y = self.slice_input(y) / params["lengthscale"] + K = params["variance"] * ( + 1 + 0.5 * squared_distance(x, y) / params["alpha"] + ) ** (-params["alpha"]) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + "alpha": jnp.array([1.0]), + } + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the Periodic kernel", +) +class Periodic(AbstractKernel): + """The periodic kernel. + + Key reference is MacKay 1998 - "Introduction to Gaussian processes". + """ + + def __init__( + self, + compute_engine: AbstractKernelComputation = DenseKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Periodic", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + + def __call__(self, params: dict, x: jax.Array, y: jax.Array) -> Array: + """Evaluate the kernel on a pair of inputs :math:`(x, y)` with length-scale parameter :math:`\ell` and variance :math:`\sigma` + + .. math:: + k(x, y) = \\sigma^2 \\exp \\Bigg( -0.5 \\sum_{i=1}^{d} \\Bigg) + + Args: + x (jax.Array): The left hand argument of the kernel function's call. + y (jax.Array): The right hand argument of the kernel function's call + params (dict): Parameter set for which the kernel should be evaluated on. + Returns: + Array: The value of :math:`k(x, y)` + """ + x = self.slice_input(x) + y = self.slice_input(y) + sine_squared = ( + jnp.sin(jnp.pi * (x - y) / params["period"]) / params["lengthscale"] + ) ** 2 + K = params["variance"] * jnp.exp(-0.5 * jnp.sum(sine_squared, axis=0)) + return K.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + "period": jnp.array([1.0] * self.ndims), + } + + +########################################## +# Graph kernels +########################################## +class EigenKernelComputation(AbstractKernelComputation): + def __init__( + self, + kernel_fn: Callable[ + [Dict, Float[Array, "1 D"], Float[Array, "1 D"]], Array + ] = None, + ) -> None: + super().__init__(kernel_fn) + self._eigenvalues = None + self._eigenvectors = None + self._num_verticies = None + + # Define an eigenvalue setter and getter property + @property + def eigensystem(self) -> Float[Array, "N"]: + return self._eigenvalues, self._eigenvectors, self._num_verticies + + @eigensystem.setter + def eigensystem( + self, eigenvalues: Float[Array, "N"], eigenvectors: Float[Array, "N N"] + ) -> None: + self._eigenvalues = eigenvalues + self._eigenvectors = eigenvectors + + @property + def num_vertex(self) -> int: + return self._num_verticies + + @num_vertex.setter + def num_vertex(self, num_vertex: int) -> None: + self._num_verticies = num_vertex + + def _compute_S(self, params): + evals, evecs = self.eigensystem + S = jnp.power( + evals + + 2 * params["smoothness"] / params["lengthscale"] / params["lengthscale"], + -params["smoothness"], + ) + S = jnp.multiply(S, self.num_vertex / jnp.sum(S)) + S = jnp.multiply(S, params["variance"]) + return S + + def cross_covariance( + self, params: Dict, x: Float[Array, "N D"], y: Float[Array, "M D"] + ) -> Float[Array, "N M"]: + S = self._compute_S(params=params) + matrix = self.kernel_fn(params, 
x, y, S=S) + return matrix + + +@deprecation.deprecated( + deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the GraphKernel" +) +class GraphKernel(AbstractKernel): + def __init__( + self, + laplacian: Float[Array, "N N"], + compute_engine: EigenKernelComputation = EigenKernelComputation, + active_dims: Optional[List[int]] = None, + stationary: Optional[bool] = False, + spectral: Optional[bool] = False, + name: Optional[str] = "Graph kernel", + ) -> None: + super().__init__(compute_engine, active_dims, stationary, spectral, name) + self.laplacian = laplacian + evals, self.evecs = jnp.linalg.eigh(self.laplacian) + self.evals = evals.reshape(-1, 1) + self.compute_engine.eigensystem = self.evals, self.evecs + self.compute_engine.num_vertex = self.laplacian.shape[0] + + def __call__( + self, + params: Dict, + x: Float[Array, "1 D"], + y: Float[Array, "1 D"], + **kwargs, + ) -> Float[Array, "1"]: + """Evaluate the graph kernel on a pair of vertices :math:`v_i, v_j`. + + Args: + params (Dict): Parameter set for which the kernel should be evaluated on. + x (Float[Array, "1 D"]): Index of the ith vertex. + y (Float[Array, "1 D"]): Index of the jth vertex. + + Returns: + Float[Array, "1"]: The value of :math:`k(v_i, v_j)`. + """ + S = kwargs["S"] + Kxx = (jax_gather_nd(self.evecs, x) * S[None, :]) @ jnp.transpose( + jax_gather_nd(self.evecs, y) + ) # shape (n,n) + return Kxx.squeeze() + + def _initialise_params(self, key: PRNGKeyType) -> Dict: + return { + "lengthscale": jnp.array([1.0] * self.ndims), + "variance": jnp.array([1.0]), + "smoothness": jnp.array([1.0]), + } + + @property + def num_vertex(self) -> int: + return self.compute_engine.num_vertex + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the squared_distance function", +) +def squared_distance( + x: Float[Array, "1 D"], y: Float[Array, "1 D"] +) -> Float[Array, "1"]: + """Compute the squared distance between a pair of inputs. + + Args: + x (Float[Array, "1 D"]): First input. + y (Float[Array, "1 D"]): Second input. + + Returns: + Float[Array, "1"]: The squared distance between the inputs. + """ + + return jnp.sum((x - y) ** 2) + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the euclidean_distance function", +) +def euclidean_distance( + x: Float[Array, "1 D"], y: Float[Array, "1 D"] +) -> Float[Array, "1"]: + """Compute the euclidean distance between a pair of inputs. + + Args: + x (Float[Array, "1 D"]): First input. + y (Float[Array, "1 D"]): Second input. + + Returns: + Float[Array, "1"]: The euclidean distance between the inputs. 
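The graph kernel works entirely in the spectrum of the graph Laplacian: the eigenvalues are turned into Matérn-style spectral weights S, and the covariance between two vertices is the S-weighted inner product of their eigenvector coordinates. A small self-contained sketch on a three-vertex path graph, mirroring the spectral-weight and gather logic above (the graph and all parameter values are illustrative only):

    import jax.numpy as jnp

    # A 3-vertex path graph: Laplacian L = D - A.
    A = jnp.array([[0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])
    L = jnp.diag(jnp.sum(A, axis=1)) - A

    evals, evecs = jnp.linalg.eigh(L)
    num_vertex = L.shape[0]

    smoothness, lengthscale, variance = 1.0, 1.0, 1.0

    # Matern-style spectral weights, rescaled so their mean equals `variance`.
    S = jnp.power(evals + 2.0 * smoothness / lengthscale**2, -smoothness)
    S = S * num_vertex / jnp.sum(S)
    S = S * variance

    # Covariance between vertices 0 and 2: row_0(evecs) diag(S) row_2(evecs)^T.
    i, j = 0, 2
    k_ij = jnp.sum(evecs[i] * S * evecs[j])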
+ """ + + return jnp.sqrt(jnp.maximum(squared_distance(x, y), 1e-36)) + + +@deprecation.deprecated( + deprecated_in="0.5.5", + removed_in="0.6.0", + details="Use JaxKern for the jax_gather_nd function", +) +def jax_gather_nd(params, indices): + tuple_indices = tuple(indices[..., i] for i in range(indices.shape[-1])) + return params[tuple_indices] __all__ = [ From cf98c90135d2cbcc48317cda8d3339e0517b7128 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Thu, 5 Jan 2023 18:21:17 +0000 Subject: [PATCH 08/15] Fix failing test --- gpjax/kernels.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gpjax/kernels.py b/gpjax/kernels.py index 7800d1b1..460c5461 100644 --- a/gpjax/kernels.py +++ b/gpjax/kernels.py @@ -1001,9 +1001,9 @@ def cross_covariance( return matrix -@deprecation.deprecated( - deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the GraphKernel" -) +# @deprecation.deprecated( +# deprecated_in="0.5.5", removed_in="0.6.0", details="Use JaxKern for the GraphKernel" +# ) class GraphKernel(AbstractKernel): def __init__( self, From a7427480c387a411f610668c12a2fd4011c90f6c Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 10:03:03 +0000 Subject: [PATCH 09/15] Resolve cacheing --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index fa484290..d8cd6b7f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -92,6 +92,7 @@ jobs: - python/install-packages: pkg-manager: pip-dist path-args: .[dev] + pypi-cache: false - run: name: Run tests command: | From e57d0e88cb840ae571fb65b5d8673dab4e7d8c99 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 14:03:21 +0000 Subject: [PATCH 10/15] Fix cacheing --- .circleci/config.yml | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index d8cd6b7f..fa1de6ff 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -83,16 +83,15 @@ jobs: resource_class: large steps: - checkout - - restore_cache: - keys: - - pip-cache + - python/load-cache - run: - name: Update pip + name: Install dependencies command: pip install --upgrade pip - python/install-packages: pkg-manager: pip-dist path-args: .[dev] - pypi-cache: false + # pypi-cache: false + - python/save-cache - run: name: Run tests command: | @@ -104,10 +103,6 @@ jobs: curl -Os https://uploader.codecov.io/v0.1.0_4653/linux/codecov chmod +x codecov ./codecov -t ${CODECOV_TOKEN} - - save_cache: - key: pip-cache - paths: - - ~/.cache/pip - store_test_results: path: test-results - store_artifacts: From df42abcdbc231b63dc7805c614211b46632160d0 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 14:16:12 +0000 Subject: [PATCH 11/15] Fix cacheing --- .circleci/config.yml | 13 ++++--------- requirements/dev.txt | 7 +++++++ requirements/requirements.txt | 11 +++++++++++ 3 files changed, 22 insertions(+), 9 deletions(-) create mode 100644 requirements/dev.txt create mode 100644 requirements/requirements.txt diff --git a/.circleci/config.yml b/.circleci/config.yml index fa1de6ff..f8363fdc 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,7 +1,7 @@ version: 2.1 orbs: - python: circleci/python@2.1.1 + # python: circleci/python@2.1.1 codecov: codecov/codecov@3.2.2 commands: @@ -51,7 +51,6 @@ commands: - run: name: Upload to PyPI command: twine upload dist/* -r << parameters.pkgname >> --verbose - install_pandoc: description: "Install pandoc" parameters: @@ -83,15 +82,11 
@@ jobs: resource_class: large steps: - checkout - - python/load-cache - run: name: Install dependencies - command: pip install --upgrade pip - - python/install-packages: - pkg-manager: pip-dist - path-args: .[dev] - # pypi-cache: false - - python/save-cache + command: | + pip install --upgrade pip + pip install -r requirements/dev.txt - run: name: Run tests command: | diff --git a/requirements/dev.txt b/requirements/dev.txt new file mode 100644 index 00000000..b6081cc7 --- /dev/null +++ b/requirements/dev.txt @@ -0,0 +1,7 @@ +black +isort +pylint +flake8 +pytest +networkx +pytest-cov \ No newline at end of file diff --git a/requirements/requirements.txt b/requirements/requirements.txt new file mode 100644 index 00000000..a8b062aa --- /dev/null +++ b/requirements/requirements.txt @@ -0,0 +1,11 @@ +jax>=0.4.1 +jaxlib>=0.4.1 +optax +jaxutils +jaxkern +distrax>=0.1.2 +tqdm>=4.0.0 +ml-collections==0.1.0 +jaxtyping>=0.0.2 +jaxlinop>=0.0.3 +deprecation \ No newline at end of file From 3fb3567a2045853ba85aac87c17d0110cf3f3dbe Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 14:17:40 +0000 Subject: [PATCH 12/15] Fix cacheing --- .circleci/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/config.yml b/.circleci/config.yml index f8363fdc..6f8e9a23 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -87,6 +87,7 @@ jobs: command: | pip install --upgrade pip pip install -r requirements/dev.txt + python setup.py install - run: name: Run tests command: | From ac417ae21b3df4fac7a1f44d9b218ff8f0ab4984 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 14:19:44 +0000 Subject: [PATCH 13/15] Fix cacheing --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 6f8e9a23..727100a2 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -87,7 +87,7 @@ jobs: command: | pip install --upgrade pip pip install -r requirements/dev.txt - python setup.py install + pip install -e . 
- run: name: Run tests command: | From 7e7b2d1c1c2e8d49f4e8d839ba98a7e7049a33b8 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 14:34:56 +0000 Subject: [PATCH 14/15] Bump jaxutils --- setup.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/setup.py b/setup.py index 02a28816..9d9a2b83 100644 --- a/setup.py +++ b/setup.py @@ -22,20 +22,6 @@ def get_versions(): return versions -REQUIRES = [ - "jax>=0.4.1", - "jaxlib>=0.4.1", - "optax", - "jaxutils", - "jaxkern", - "distrax>=0.1.2", - "tqdm>=4.0.0", - "ml-collections==0.1.0", - "jaxtyping>=0.0.2", - "jaxlinop>=0.0.3", - "deprecation", -] - EXTRAS = { "dev": [ "black", @@ -65,7 +51,20 @@ def get_versions(): "Documentation": "https://gpjax.readthedocs.io/en/latest/", "Source": "https://github.com/thomaspinder/GPJax", }, - install_requires=REQUIRES, + python_requires=">=3.7", + install_requires=[ + "jax>=0.4.1", + "jaxlib>=0.4.1", + "optax", + "jaxutils>=0.0.5", + "jaxkern", + "distrax>=0.1.2", + "tqdm>=4.0.0", + "ml-collections==0.1.0", + "jaxtyping>=0.0.2", + "jaxlinop>=0.0.3", + "deprecation", + ], tests_require=EXTRAS["dev"], extras_require=EXTRAS, keywords=["gaussian-processes jax machine-learning bayesian"], From bb63fea273da20e7a517aa3063b00c76da2c76e5 Mon Sep 17 00:00:00 2001 From: Thomas Pinder Date: Sun, 8 Jan 2023 14:52:22 +0000 Subject: [PATCH 15/15] Bump JaxUtils --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 9d9a2b83..d98ddd57 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ def get_versions(): "jax>=0.4.1", "jaxlib>=0.4.1", "optax", - "jaxutils>=0.0.5", + "jaxutils>=0.0.6", "jaxkern", "distrax>=0.1.2", "tqdm>=4.0.0",
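With requirements/requirements.txt and requirements/dev.txt now checked in while setup.py keeps its own hard-coded dependency list, the two sources of truth can drift apart. If one wanted to derive the setup metadata from the requirements files instead, a minimal sketch might look like the following; this is an assumption-laden alternative, not what the patches above do, and the file paths simply follow the layout introduced earlier:

    # setup.py sketch: single-source the pinned dependencies from the
    # requirements files so CI installs and package metadata stay in sync.
    from pathlib import Path
    from setuptools import setup, find_packages

    def read_requirements(path: str) -> list:
        lines = Path(path).read_text().splitlines()
        return [ln.strip() for ln in lines if ln.strip() and not ln.startswith("#")]

    setup(
        name="gpjax",
        packages=find_packages(".", exclude=["tests"]),
        install_requires=read_requirements("requirements/requirements.txt"),
        extras_require={"dev": read_requirements("requirements/dev.txt")},
    )

Keeping the explicit list inside setup.py, as the patch does, avoids a runtime dependency on the requirements files at build time; reading them instead avoids maintaining the same pins in two places. Either choice is workable as long as the two lists are kept consistent.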