diff --git a/gpjax/decision_making/__init__.py b/gpjax/decision_making/__init__.py
index 6782f6b2..c3ff9226 100644
--- a/gpjax/decision_making/__init__.py
+++ b/gpjax/decision_making/__init__.py
@@ -21,6 +21,7 @@
     AbstractAcquisitionMaximizer,
     ContinuousAcquisitionMaximizer,
 )
+from gpjax.decision_making.posterior_handler import PosteriorHandler
 from gpjax.decision_making.search_space import (
     AbstractSearchSpace,
     ContinuousSearchSpace,
 )
@@ -42,6 +43,7 @@
     "AbstractContinuousTestFunction",
     "Forrester",
     "LogarithmicGoldsteinPrice",
+    "PosteriorHandler",
     "Quadratic",
     "ThompsonSampling",
 ]
diff --git a/gpjax/decision_making/posterior_handler.py b/gpjax/decision_making/posterior_handler.py
new file mode 100644
index 00000000..06b67217
--- /dev/null
+++ b/gpjax/decision_making/posterior_handler.py
@@ -0,0 +1,155 @@
+# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from dataclasses import dataclass
+
+from beartype.typing import (
+    Callable,
+    Optional,
+)
+import optax as ox
+
+import gpjax as gpx
+from gpjax.dataset import Dataset
+from gpjax.gps import (
+    AbstractLikelihood,
+    AbstractPosterior,
+    AbstractPrior,
+)
+from gpjax.objectives import AbstractObjective
+from gpjax.typing import KeyArray
+
+LikelihoodBuilder = Callable[[int], AbstractLikelihood]
+"""Type alias for likelihood builders, which take the number of datapoints as input and
+return a likelihood object initialised with the given number of datapoints."""
+
+
+@dataclass
+class PosteriorHandler:
+    """
+    Class for handling the creation and updating of a GP posterior as new data is
+    observed.
+
+    Attributes:
+        prior (AbstractPrior): Prior to use when forming the posterior.
+        likelihood_builder (LikelihoodBuilder): Function which takes the number of
+            datapoints as input and returns a likelihood object initialised with the given
+            number of datapoints.
+        optimization_objective (AbstractObjective): Objective to use for optimizing the
+            posterior hyperparameters.
+        optimizer (ox.GradientTransformation): Optax optimizer to use for optimizing the
+            posterior hyperparameters.
+        num_optimization_iters (int): Number of iterations to optimize
+            the posterior hyperparameters for.
+    """
+
+    prior: AbstractPrior
+    likelihood_builder: LikelihoodBuilder
+    optimization_objective: AbstractObjective
+    optimizer: ox.GradientTransformation
+    num_optimization_iters: int
+
+    def __post_init__(self):
+        if self.num_optimization_iters < 1:
+            raise ValueError("num_optimization_iters must be greater than 0.")
+
+    def get_posterior(
+        self, dataset: Dataset, optimize: bool, key: Optional[KeyArray] = None
+    ) -> AbstractPosterior:
+        """
+        Initialise (and optionally optimize) a posterior using the given dataset.
+
+        Args:
+            dataset (Dataset): Dataset to get posterior for.
+            optimize (bool): Whether to optimize the posterior hyperparameters.
+            key (Optional[KeyArray]): A JAX PRNG key which is used for optimizing the posterior
+                hyperparameters.
+
+        Returns:
+            Posterior for the given dataset.
+        """
+        posterior = self.prior * self.likelihood_builder(dataset.n)
+
+        if optimize:
+            if key is None:
+                raise ValueError(
+                    "A key must be provided in order to optimize the posterior."
+                )
+            posterior = self._optimize_posterior(posterior, dataset, key)
+
+        return posterior
+
+    def update_posterior(
+        self,
+        dataset: Dataset,
+        previous_posterior: AbstractPosterior,
+        optimize: bool,
+        key: Optional[KeyArray] = None,
+    ) -> AbstractPosterior:
+        """
+        Update the given posterior with the given dataset. This needs to be done when
+        the number of datapoints in the (training) dataset of the posterior changes, as
+        the `AbstractLikelihood` class requires the number of datapoints to be specified.
+        Hyperparameters may or may not be optimized, depending on the value of the
+        `optimize` parameter. Note that the updated posterior will be initialised with
+        the same prior hyperparameters as the previous posterior, but the likelihood
+        will be re-initialised with the new number of datapoints, and hyperparameters
+        set as in the `likelihood_builder` function.
+
+        Args:
+            dataset: Dataset to get posterior for.
+            previous_posterior: Posterior being updated. This is supplied as one may
+                wish to simply increase the number of datapoints in the likelihood, without
+                optimizing the posterior hyperparameters, in which case the previous
+                posterior can be used to obtain the previously set prior hyperparameters.
+            optimize: Whether to optimize the posterior hyperparameters.
+            key: A JAX PRNG key which is used for optimizing the posterior
+                hyperparameters.
+        """
+        posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)
+
+        if optimize:
+            if key is None:
+                raise ValueError(
+                    "A key must be provided in order to optimize the posterior."
+                )
+            posterior = self._optimize_posterior(posterior, dataset, key)
+        return posterior
+
+    def _optimize_posterior(
+        self, posterior: AbstractPosterior, dataset: Dataset, key: KeyArray
+    ) -> AbstractPosterior:
+        """
+        Takes a posterior and corresponding dataset and optimizes the posterior using the
+        GPJax `fit` method.
+
+        Args:
+            posterior: Posterior being optimized.
+            dataset: Dataset used for optimizing posterior.
+            key: A JAX PRNG key for generating random numbers.
+        Returns:
+            Optimized posterior.
+        """
+        opt_posterior, _ = gpx.fit(
+            model=posterior,
+            objective=self.optimization_objective,
+            train_data=dataset,
+            optim=self.optimizer,
+            num_iters=self.num_optimization_iters,
+            safe=True,
+            key=key,
+            verbose=False,
+        )
+
+        return opt_posterior
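
Reviewer's note: the snippet below is a minimal usage sketch of the `PosteriorHandler` lifecycle added above — build, fit, then rebuild as new data arrives. It uses only the API introduced in this patch together with existing GPJax components; the `Forrester` test function, `Matern52` kernel, and optimiser settings are illustrative, not prescriptive.

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.random as jr
import optax as ox

from gpjax.decision_making import PosteriorHandler
from gpjax.decision_making.test_functions import Forrester
from gpjax.gps import Prior
from gpjax.kernels import Matern52
from gpjax.likelihoods import Gaussian
from gpjax.mean_functions import Constant
from gpjax.objectives import ConjugateMLL

key = jr.PRNGKey(0)
function = Forrester()
dataset = function.generate_dataset(num_points=5, key=key)

handler = PosteriorHandler(
    prior=Prior(mean_function=Constant(), kernel=Matern52()),
    # The builder is called with dataset.n whenever the posterior is (re)built.
    likelihood_builder=lambda n: Gaussian(num_datapoints=n),
    optimization_objective=ConjugateMLL(negative=True),
    optimizer=ox.adam(learning_rate=0.01),
    num_optimization_iters=10,
)

# Initial posterior, with hyperparameters optimized via gpx.fit under the hood.
posterior = handler.get_posterior(dataset=dataset, optimize=True, key=key)

# A new observation arrives: the likelihood must be rebuilt with the new
# number of datapoints, while the prior hyperparameters carry over.
key, subkey = jr.split(key)
dataset = dataset + function.generate_dataset(num_points=1, key=subkey)
posterior = handler.update_posterior(
    dataset=dataset, previous_posterior=posterior, optimize=False
)
```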
diff --git a/gpjax/decision_making/test_functions/__init__.py b/gpjax/decision_making/test_functions/__init__.py
index 8fe075e1..c5016076 100644
--- a/gpjax/decision_making/test_functions/__init__.py
+++ b/gpjax/decision_making/test_functions/__init__.py
@@ -12,17 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 from gpjax.decision_making.test_functions.continuous_functions import (
     AbstractContinuousTestFunction,
     Forrester,
     LogarithmicGoldsteinPrice,
     Quadratic,
 )
+from gpjax.decision_making.test_functions.non_conjugate_functions import (
+    PoissonTestFunction,
+)

 __all__ = [
     "AbstractContinuousTestFunction",
     "Forrester",
     "LogarithmicGoldsteinPrice",
+    "PoissonTestFunction",
     "Quadratic",
 ]
diff --git a/gpjax/decision_making/test_functions/non_conjugate_functions.py b/gpjax/decision_making/test_functions/non_conjugate_functions.py
new file mode 100644
index 00000000..94fa4ecb
--- /dev/null
+++ b/gpjax/decision_making/test_functions/non_conjugate_functions.py
@@ -0,0 +1,88 @@
+# Copyright 2023 The GPJax Contributors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from dataclasses import dataclass
+
+import jax.numpy as jnp
+import jax.random as jr
+from jaxtyping import (
+    Array,
+    Float,
+    Integer,
+)
+
+from gpjax.dataset import Dataset
+from gpjax.decision_making.search_space import ContinuousSearchSpace
+from gpjax.typing import KeyArray
+
+
+@dataclass
+class PoissonTestFunction:
+    """
+    Test function for GPs utilising the Poisson likelihood. Function taken from
+    https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
+
+    Attributes:
+        search_space (ContinuousSearchSpace): Search space for the function.
+    """
+
+    search_space = ContinuousSearchSpace(
+        lower_bounds=jnp.array([-2.0]),
+        upper_bounds=jnp.array([2.0]),
+    )
+
+    def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
+        """
+        Generate a toy dataset from the test function.
+
+        Args:
+            num_points (int): Number of points to sample.
+            key (KeyArray): JAX PRNG key.
+
+        Returns:
+            Dataset: Dataset of points sampled from the test function.
+        """
+        X = self.search_space.sample(num_points=num_points, key=key)
+        y = self.evaluate(X)
+        return Dataset(X=X, y=y)
+
+    def generate_test_points(
+        self, num_points: int, key: KeyArray
+    ) -> Float[Array, "N D"]:
+        """
+        Generate test points from the search space of the test function.
+
+        Args:
+            num_points (int): Number of points to sample.
+            key (KeyArray): JAX PRNG key.
+
+        Returns:
+            Float[Array, 'N D']: Test points sampled from the search space.
+        """
+        return self.search_space.sample(num_points=num_points, key=key)
+
+    def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
+        """
+        Evaluate the test function at a set of points. Function taken from
+        https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
+
+        Args:
+            x (Float[Array, 'N 1']): Points to evaluate the test function at.
+
+        Returns:
+            Integer[Array, 'N 1']: Values of the test function at the points.
+ """ + key = jr.PRNGKey(42) + f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x + return jr.poisson(key, jnp.exp(f(x))) diff --git a/tests/test_decision_making/test_posterior_handler.py b/tests/test_decision_making/test_posterior_handler.py new file mode 100644 index 00000000..71bf9a4b --- /dev/null +++ b/tests/test_decision_making/test_posterior_handler.py @@ -0,0 +1,336 @@ +# Copyright 2023 The JaxGaussianProcesses Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from jax import config + +config.update("jax_enable_x64", True) + +from beartype.typing import ( + Callable, + Union, +) +import jax.numpy as jnp +import jax.random as jr +import optax as ox +import pytest + +from gpjax.decision_making.posterior_handler import PosteriorHandler +from gpjax.decision_making.test_functions import ( + Forrester, + PoissonTestFunction, +) +from gpjax.gps import Prior +from gpjax.kernels import Matern52 +from gpjax.likelihoods import ( + AbstractLikelihood, + Gaussian, + Poisson, +) +from gpjax.mean_functions import Constant +from gpjax.objectives import ( + AbstractObjective, + ConjugateMLL, + NonConjugateMLL, +) + + +def gaussian_likelihood_builder(num_datapoints: int) -> Gaussian: + return Gaussian(num_datapoints=num_datapoints) + + +def poisson_likelihood_builder(num_datapoints: int) -> Poisson: + return Poisson(num_datapoints=num_datapoints) + + +@pytest.mark.parametrize("num_optimization_iters", [0, -1, -10]) +def test_posterior_handler_erroneous_num_optimization_iterations_raises_error( + num_optimization_iters: int, +): + mean_function = Constant() + kernel = Matern52() + prior = Prior(mean_function=mean_function, kernel=kernel) + likelihood_builder = gaussian_likelihood_builder + training_objective = ConjugateMLL(negative=True) + with pytest.raises(ValueError): + PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=num_optimization_iters, + ) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_get_optimized_posterior_with_no_key_raises_error(): + mean_function = Constant() + kernel = Matern52() + prior = Prior(mean_function=mean_function, kernel=kernel) + likelihood_builder = gaussian_likelihood_builder + training_objective = ConjugateMLL(negative=True) + posterior_handler = PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=10, + ) + toy_function = Forrester() + dataset = toy_function.generate_dataset(num_points=5, key=jr.PRNGKey(42)) + with pytest.raises(ValueError): + posterior_handler.get_posterior(dataset=dataset, optimize=True) + + +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with 
tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_update_and_optimize_posterior_with_no_key_raises_error(): + mean_function = Constant() + kernel = Matern52() + prior = Prior(mean_function=mean_function, kernel=kernel) + likelihood_builder = gaussian_likelihood_builder + training_objective = ConjugateMLL(negative=True) + posterior_handler = PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=10, + ) + toy_function = Forrester() + dataset = toy_function.generate_dataset(num_points=5, key=jr.PRNGKey(42)) + initial_posterior = posterior_handler.get_posterior(dataset=dataset, optimize=False) + with pytest.raises(ValueError): + posterior_handler.update_posterior( + dataset=dataset, previous_posterior=initial_posterior, optimize=True + ) + + +@pytest.mark.parametrize("num_datapoints", [1, 50]) +@pytest.mark.parametrize( + "likelihood_builder, training_objective, test_function", + [ + (gaussian_likelihood_builder, ConjugateMLL(negative=True), Forrester()), + ( + poisson_likelihood_builder, + NonConjugateMLL(negative=True), + PoissonTestFunction(), + ), + ], +) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_get_posterior_no_optimization_correct_num_datapoints_and_not_optimized( + num_datapoints: int, + likelihood_builder: Callable[[int], AbstractLikelihood], + training_objective: AbstractObjective, + test_function: Union[Forrester, PoissonTestFunction], +): + mean_function = Constant(constant=jnp.array([1.0])) + kernel = Matern52(lengthscale=jnp.array([0.5]), variance=jnp.array(1.0)) + prior = Prior(mean_function=mean_function, kernel=kernel) + posterior_handler = PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=10, + ) + dataset = test_function.generate_dataset( + num_points=num_datapoints, key=jr.PRNGKey(42) + ) + posterior = posterior_handler.get_posterior(dataset=dataset, optimize=False) + assert posterior.likelihood.num_datapoints == num_datapoints + assert posterior.prior.mean_function.constant == jnp.array([1.0]) + assert posterior.prior.kernel.lengthscale == jnp.array([0.5]) + assert posterior.prior.kernel.variance == jnp.array(1.0) + + +@pytest.mark.parametrize("num_datapoints", [5, 50]) +@pytest.mark.parametrize( + "likelihood_builder, training_objective, test_function", + [ + (gaussian_likelihood_builder, ConjugateMLL(negative=True), Forrester()), + ( + poisson_likelihood_builder, + NonConjugateMLL(negative=True), + PoissonTestFunction(), + ), + ], +) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_get_posterior_with_optimization_correct_num_datapoints_and_optimized( + num_datapoints: int, + likelihood_builder: Callable[[int], AbstractLikelihood], + training_objective: AbstractObjective, + test_function: Union[Forrester, PoissonTestFunction], +): + mean_function = Constant(constant=jnp.array([1.0])) + kernel = Matern52(lengthscale=jnp.array([0.5]), variance=jnp.array(1.0)) + prior = Prior(mean_function=mean_function, kernel=kernel) + non_optimized_posterior = prior * likelihood_builder(num_datapoints) + posterior_handler = PosteriorHandler( + 
prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=10, + ) + dataset = test_function.generate_dataset( + num_points=num_datapoints, key=jr.PRNGKey(42) + ) + optimized_posterior = posterior_handler.get_posterior( + dataset=dataset, optimize=True, key=jr.PRNGKey(42) + ) + assert optimized_posterior.likelihood.num_datapoints == num_datapoints + assert optimized_posterior.prior.mean_function.constant != jnp.array([1.0]) + assert optimized_posterior.prior.kernel.lengthscale != jnp.array([0.5]) + assert optimized_posterior.prior.kernel.variance != jnp.array(1.0) + assert training_objective(optimized_posterior, dataset) < training_objective( + non_optimized_posterior, dataset + ) # Ensure optimization reduces training objective + + +@pytest.mark.parametrize("initial_num_datapoints", [5, 50]) +@pytest.mark.parametrize( + "likelihood_builder, training_objective, test_function", + [ + (gaussian_likelihood_builder, ConjugateMLL(negative=True), Forrester()), + ( + poisson_likelihood_builder, + NonConjugateMLL(negative=True), + PoissonTestFunction(), + ), + ], +) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_update_posterior_no_optimize_same_prior_parameters_and_different_num_datapoints( + initial_num_datapoints: int, + likelihood_builder: Callable[[int], AbstractLikelihood], + training_objective: AbstractObjective, + test_function: Union[Forrester, PoissonTestFunction], +): + mean_function = Constant(constant=jnp.array([1.0])) + kernel = Matern52(lengthscale=jnp.array([0.5]), variance=jnp.array(1.0)) + prior = Prior(mean_function=mean_function, kernel=kernel) + posterior_handler = PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=10, + ) + initial_dataset = test_function.generate_dataset( + num_points=initial_num_datapoints, key=jr.PRNGKey(42) + ) + initial_posterior = posterior_handler.get_posterior( + dataset=initial_dataset, optimize=False + ) + updated_dataset = initial_dataset + test_function.generate_dataset( + num_points=1, key=jr.PRNGKey(42) + ) + assert updated_dataset.n == initial_dataset.n + 1 + updated_posterior = posterior_handler.update_posterior( + dataset=updated_dataset, previous_posterior=initial_posterior, optimize=False + ) + assert ( + updated_posterior.prior.kernel.lengthscale + == initial_posterior.prior.kernel.lengthscale + ) + assert ( + updated_posterior.prior.kernel.variance + == initial_posterior.prior.kernel.variance + ) + assert ( + updated_posterior.prior.mean_function.constant + == initial_posterior.prior.mean_function.constant + ) + assert updated_posterior.likelihood.num_datapoints == updated_dataset.n + + +@pytest.mark.parametrize("initial_num_datapoints", [5, 50]) +@pytest.mark.parametrize( + "likelihood_builder, training_objective, test_function", + [ + (gaussian_likelihood_builder, ConjugateMLL(negative=True), Forrester()), + ( + poisson_likelihood_builder, + NonConjugateMLL(negative=True), + PoissonTestFunction(), + ), + ], +) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_update_posterior_with_optimization_updated_prior_parameters_and_different_num_datapoints( + 
initial_num_datapoints: int, + likelihood_builder: Callable[[int], AbstractLikelihood], + training_objective: AbstractObjective, + test_function: Union[Forrester, PoissonTestFunction], +): + mean_function = Constant(constant=jnp.array([1.0])) + kernel = Matern52(lengthscale=jnp.array([0.5]), variance=jnp.array(1.0)) + prior = Prior(mean_function=mean_function, kernel=kernel) + posterior_handler = PosteriorHandler( + prior=prior, + likelihood_builder=likelihood_builder, + optimization_objective=training_objective, + optimizer=ox.adam(learning_rate=0.01), + num_optimization_iters=10, + ) + initial_dataset = test_function.generate_dataset( + num_points=initial_num_datapoints, key=jr.PRNGKey(42) + ) + initial_posterior = posterior_handler.get_posterior( + dataset=initial_dataset, optimize=False + ) + updated_dataset = initial_dataset + test_function.generate_dataset( + num_points=1, key=jr.PRNGKey(42) + ) + assert updated_dataset.n == initial_dataset.n + 1 + non_optimized_updated_posterior = posterior_handler.update_posterior( + dataset=updated_dataset, previous_posterior=initial_posterior, optimize=False + ) + optimized_updated_posterior = posterior_handler.update_posterior( + dataset=updated_dataset, + previous_posterior=initial_posterior, + optimize=True, + key=jr.PRNGKey(42), + ) + assert ( + optimized_updated_posterior.prior.kernel.lengthscale + != initial_posterior.prior.kernel.lengthscale + ) + assert ( + optimized_updated_posterior.prior.kernel.variance + != initial_posterior.prior.kernel.variance + ) + assert ( + optimized_updated_posterior.prior.mean_function.constant + != initial_posterior.prior.mean_function.constant + ) + assert optimized_updated_posterior.likelihood.num_datapoints == updated_dataset.n + assert training_objective( + optimized_updated_posterior, updated_dataset + ) < training_objective( + non_optimized_updated_posterior, updated_dataset + ) # Ensure optimization reduces training objective diff --git a/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py b/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py new file mode 100644 index 00000000..837abdbf --- /dev/null +++ b/tests/test_decision_making/test_test_functions/test_non_conjugate_functions.py @@ -0,0 +1,127 @@ +# Copyright 2023 The GPJax Contributors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+from jax import config

config.update("jax_enable_x64", True)

+import jax.numpy as jnp
+import jax.random as jr
+import pytest
+
+from gpjax.decision_making.test_functions import PoissonTestFunction
+from gpjax.typing import KeyArray
+
+
+@pytest.mark.parametrize("test_function", [PoissonTestFunction()])
+@pytest.mark.filterwarnings(
+    "ignore::UserWarning"
+)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
+def test_correct_dtypes(test_function: PoissonTestFunction):
+    dataset = test_function.generate_dataset(10, jr.PRNGKey(42))
+    test_x = test_function.generate_test_points(10, jr.PRNGKey(42))
+    assert dataset.X.dtype == jnp.float64
+    assert jnp.issubdtype(dataset.y.dtype, jnp.integer)
+    assert test_x.dtype == jnp.float64
+
+
+@pytest.mark.parametrize(
+    "test_function, dimensionality",
+    [(PoissonTestFunction(), 1)],
+)
+@pytest.mark.parametrize("num_samples", [1, 10, 100])
+@pytest.mark.filterwarnings(
+    "ignore::UserWarning"
+)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
+def test_test_points_shape(
+    test_function: PoissonTestFunction, dimensionality: int, num_samples: int
+):
+    test_X = test_function.generate_test_points(num_samples, jr.PRNGKey(42))
+    assert test_X.shape == (num_samples, dimensionality)
+
+
+@pytest.mark.parametrize(
+    "test_function, dimensionality",
+    [(PoissonTestFunction(), 1)],
+)
+@pytest.mark.parametrize("num_samples", [1, 10, 100])
+@pytest.mark.filterwarnings(
+    "ignore::UserWarning"
+)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
+def test_dataset_shapes(
+    test_function: PoissonTestFunction, dimensionality: int, num_samples: int
+):
+    dataset = test_function.generate_dataset(num_samples, jr.PRNGKey(42))
+    assert dataset.X.shape == (num_samples, dimensionality)
+    assert dataset.y.shape == (num_samples, 1)
+
+
+@pytest.mark.parametrize("test_function", [PoissonTestFunction()])
+@pytest.mark.parametrize("num_samples", [1, 10, 100])
+@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)])
+@pytest.mark.filterwarnings(
+    "ignore::UserWarning"
+)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
+def test_same_key_same_dataset(
+    test_function: PoissonTestFunction, num_samples: int, key: KeyArray
+):
+    dataset_one = test_function.generate_dataset(num_samples, key)
+    dataset_two = test_function.generate_dataset(num_samples, key)
+    assert jnp.equal(dataset_one.X, dataset_two.X).all()
+    assert jnp.equal(dataset_one.y, dataset_two.y).all()
+
+
+@pytest.mark.parametrize("test_function", [PoissonTestFunction()])
+@pytest.mark.parametrize("num_samples", [10, 100])
+@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)])
+@pytest.mark.filterwarnings(
+    "ignore::UserWarning"
+)  # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort
+def test_different_key_different_dataset(
+    test_function: PoissonTestFunction, num_samples: int, key: KeyArray
+):
+    dataset_one = test_function.generate_dataset(num_samples, key)
+    key, _ = jr.split(key)
+    dataset_two = test_function.generate_dataset(num_samples, key)
+    assert not jnp.equal(dataset_one.X, dataset_two.X).all()
+    assert not jnp.equal(dataset_one.y, dataset_two.y).all()
+
+
+@pytest.mark.parametrize("test_function", [PoissonTestFunction()])
+@pytest.mark.parametrize("num_samples", [1, 10, 100]) +@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_same_key_same_test_points( + test_function: PoissonTestFunction, num_samples: int, key: KeyArray +): + test_points_one = test_function.generate_test_points(num_samples, key) + test_points_two = test_function.generate_test_points(num_samples, key) + assert jnp.equal(test_points_one, test_points_two).all() + + +@pytest.mark.parametrize("test_function", [PoissonTestFunction()]) +@pytest.mark.parametrize("num_samples", [1, 10, 100]) +@pytest.mark.parametrize("key", [jr.PRNGKey(42), jr.PRNGKey(10)]) +@pytest.mark.filterwarnings( + "ignore::UserWarning" +) # Sampling with tfp causes JAX to raise a UserWarning due to some internal logic around jnp.argsort +def test_different_key_different_test_points( + test_function: PoissonTestFunction, num_samples: int, key: KeyArray +): + test_points_one = test_function.generate_test_points(num_samples, key) + key, _ = jr.split(key) + test_points_two = test_function.generate_test_points(num_samples, key) + assert not jnp.equal(test_points_one, test_points_two).all()
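
To close the loop on the non-conjugate path these tests exercise, here is a hedged end-to-end sketch combining `PoissonTestFunction` with a `Poisson` likelihood and the `NonConjugateMLL` objective via `PosteriorHandler`. All names come from this patch or the existing GPJax API; the optimiser settings and loop length are illustrative only.

```python
from jax import config

config.update("jax_enable_x64", True)

import jax.random as jr
import optax as ox

from gpjax.decision_making import PosteriorHandler
from gpjax.decision_making.test_functions import PoissonTestFunction
from gpjax.gps import Prior
from gpjax.kernels import Matern52
from gpjax.likelihoods import Poisson
from gpjax.mean_functions import Constant
from gpjax.objectives import NonConjugateMLL

key = jr.PRNGKey(42)
function = PoissonTestFunction()

# X lies in the function's 1D search space [-2, 2]; y holds integer counts.
dataset = function.generate_dataset(num_points=50, key=key)

handler = PosteriorHandler(
    prior=Prior(mean_function=Constant(), kernel=Matern52()),
    likelihood_builder=lambda n: Poisson(num_datapoints=n),
    optimization_objective=NonConjugateMLL(negative=True),
    optimizer=ox.adam(learning_rate=0.01),
    num_optimization_iters=10,
)
posterior = handler.get_posterior(dataset=dataset, optimize=True, key=key)

# Sequentially observe one point at a time, re-optimizing after each update.
for _ in range(3):
    key, subkey = jr.split(key)
    dataset = dataset + function.generate_dataset(num_points=1, key=subkey)
    posterior = handler.update_posterior(
        dataset=dataset, previous_posterior=posterior, optimize=True, key=key
    )
```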