diff --git a/tests/unit/models/gpflux/test_interface.py b/tests/unit/models/gpflux/test_interface.py index a9a5be302c..910a208033 100644 --- a/tests/unit/models/gpflux/test_interface.py +++ b/tests/unit/models/gpflux/test_interface.py @@ -18,6 +18,7 @@ import numpy.testing as npt import pytest import tensorflow as tf +from check_shapes import inherit_check_shapes from gpflow.conditionals.util import sample_mvn from gpflux.helpers import construct_basic_inducing_variables, construct_basic_kernel from gpflux.layers import GPLayer @@ -57,6 +58,7 @@ def model_keras(self) -> tf.keras.Model: def optimizer(self) -> tf.keras.optimizers.Optimizer: return self._optimizer + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: # Taken from GPflow implementation of `GPModel.predict_f_samples` in gpflow.models.model mean, cov = self._model_gpflux.predict_f(query_points, full_cov=True) diff --git a/tests/unit/test_bayesian_optimizer.py b/tests/unit/test_bayesian_optimizer.py index dead19ac21..9e15032d6b 100644 --- a/tests/unit/test_bayesian_optimizer.py +++ b/tests/unit/test_bayesian_optimizer.py @@ -21,6 +21,7 @@ import numpy.testing as npt import pytest import tensorflow as tf +from check_shapes import inherit_check_shapes from tests.util.misc import ( FixedAcquisitionRule, @@ -596,6 +597,7 @@ def __init__(self, data: Dataset): super().__init__() self._data = data + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, var = super().predict(query_points) return mean, var / len(self._data) diff --git a/tests/util/models/gpflow/models.py b/tests/util/models/gpflow/models.py index 2f9e9f888a..e676f812ba 100644 --- a/tests/util/models/gpflow/models.py +++ b/tests/util/models/gpflow/models.py @@ -21,6 +21,7 @@ import numpy as np import tensorflow as tf import tensorflow_probability as tfp +from check_shapes import inherit_check_shapes from gpflow.models import GPR, SGPR, SVGP, VGP, GPModel from typing_extensions import Protocol @@ -70,6 +71,7 @@ def optimize(self, dataset: Dataset) -> None: class GaussianMarginal(ProbabilisticModel): """A probabilistic model with Gaussian marginal distribution. Assumes events of shape [N].""" + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: mean, var = self.predict(query_points) samples = tfp.distributions.Normal(mean, tf.sqrt(var)).sample(num_samples) @@ -95,6 +97,7 @@ def __init__( def __repr__(self) -> str: return f"GaussianProcess({self._mean_functions!r}, {self._kernels!r})" + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, cov = self.predict_joint(query_points[..., None, :]) return tf.squeeze(mean, -2), tf.squeeze(cov, [-2, -1]) @@ -131,6 +134,7 @@ def __init__( def __repr__(self) -> str: return f"GaussianProcessWithoutNoise({self._mean_functions!r}, {self._kernels!r})" + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: mean, cov = self.predict_joint(query_points[..., None, :]) return tf.squeeze(mean, -2), tf.squeeze(cov, [-2, -1]) @@ -278,6 +282,7 @@ def covariance_with_top_fidelity(self, x: TensorType) -> TensorType: mean, _ = self.predict(x) return tf.ones_like(mean, dtype=mean.dtype) # dummy covariances of correct shape + @inherit_check_shapes def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: fmean, fvar = self.predict(query_points) yvar = fvar + tf.constant(1.0, dtype=fmean.dtype) # dummy noise variance diff --git a/trieste/acquisition/function/greedy_batch.py b/trieste/acquisition/function/greedy_batch.py index afa57b64b7..336b9ea779 100644 --- a/trieste/acquisition/function/greedy_batch.py +++ b/trieste/acquisition/function/greedy_batch.py @@ -21,6 +21,7 @@ import gpflow import tensorflow as tf import tensorflow_probability as tfp +from check_shapes import inherit_check_shapes from typing_extensions import Protocol, runtime_checkable from ...data import Dataset @@ -652,6 +653,7 @@ def update_fantasized_data(self, fantasized_data: Dataset) -> None: self._fantasized_query_points.assign(fantasized_data.query_points) self._fantasized_observations.assign(fantasized_data.observations) + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """ This function wraps conditional_predict_f. It cannot directly call @@ -690,6 +692,7 @@ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (t return _broadcast_predict(query_points, fun) + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: """ This function wraps conditional_predict_f_sample. It cannot directly call @@ -716,6 +719,7 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType: ) # [B, ..., S, L] return _restore_leading_dim(samples, leading_dim) + @inherit_check_shapes def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """ This function wraps conditional_predict_y. It cannot directly call diff --git a/trieste/models/gpflux/interface.py b/trieste/models/gpflux/interface.py index 0ce8059b2b..1c2b5297c4 100644 --- a/trieste/models/gpflux/interface.py +++ b/trieste/models/gpflux/interface.py @@ -17,6 +17,7 @@ from abc import ABC, abstractmethod import tensorflow as tf +from check_shapes import inherit_check_shapes from gpflow.base import Module from ...types import TensorType @@ -58,6 +59,7 @@ def optimizer(self) -> KerasOptimizer: """The optimizer wrapper for training the model.""" return self._optimizer + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """Note: unless otherwise noted, this returns the mean and variance of the last layer conditioned on one sample from the previous layers.""" @@ -67,6 +69,7 @@ def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: def sample(self, query_points: TensorType, num_samples: int) -> TensorType: raise NotImplementedError + @inherit_check_shapes def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: """Note: unless otherwise noted, this will return the prediction conditioned on one sample from the lower layers.""" diff --git a/trieste/models/gpflux/models.py b/trieste/models/gpflux/models.py index 9560680845..1c4d3bf59a 100644 --- a/trieste/models/gpflux/models.py +++ b/trieste/models/gpflux/models.py @@ -19,6 +19,7 @@ import dill import gpflow import tensorflow as tf +from check_shapes import inherit_check_shapes from gpflow.inducing_variables import InducingPoints from gpflux.layers import GPLayer, LatentVariableLayer from gpflux.models import DeepGP @@ -277,6 +278,7 @@ def model_gpflux(self) -> DeepGP: def model_keras(self) -> tf.keras.Model: return self._model_keras + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: trajectory = self.trajectory_sampler().get_trajectory() expanded_query_points = tf.expand_dims(query_points, -2) # [N, 1, D] diff --git a/trieste/models/interfaces.py b/trieste/models/interfaces.py index 09f2b95b77..89c88c2461 100644 --- a/trieste/models/interfaces.py +++ b/trieste/models/interfaces.py @@ -19,7 +19,7 @@ import gpflow import tensorflow as tf -from check_shapes import check_shapes +from check_shapes import check_shapes, inherit_check_shapes from typing_extensions import Protocol, runtime_checkable from ..data import Dataset @@ -373,6 +373,7 @@ def __init__( """ self._models, self._event_sizes = zip(*(model_with_event_size,) + models_with_event_sizes) + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: r""" :param query_points: The points at which to make predictions, of shape [..., D]. @@ -384,6 +385,7 @@ def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: means, vars_ = zip(*[model.predict(query_points) for model in self._models]) return tf.concat(means, axis=-1), tf.concat(vars_, axis=-1) + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: r""" :param query_points: The points at which to sample, with shape [..., N, D]. @@ -395,6 +397,7 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType: samples = [model.sample(query_points, num_samples) for model in self._models] return tf.concat(samples, axis=-1) + @inherit_check_shapes def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]: r""" :param query_points: The points at which to make predictions, of shape [..., D]. diff --git a/trieste/models/keras/interface.py b/trieste/models/keras/interface.py index b7e24e9e22..8f3c28c927 100644 --- a/trieste/models/keras/interface.py +++ b/trieste/models/keras/interface.py @@ -19,6 +19,7 @@ import tensorflow as tf import tensorflow_probability as tfp +from check_shapes import inherit_check_shapes from typing_extensions import Protocol, runtime_checkable from ...types import TensorType @@ -60,9 +61,11 @@ def optimizer(self) -> KerasOptimizer: """The optimizer wrapper for training the model.""" return self._optimizer + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: return self.model.predict(query_points) + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: raise NotImplementedError( """ diff --git a/trieste/models/keras/models.py b/trieste/models/keras/models.py index 2020653636..91c2006487 100644 --- a/trieste/models/keras/models.py +++ b/trieste/models/keras/models.py @@ -21,6 +21,7 @@ import tensorflow as tf import tensorflow_probability as tfp import tensorflow_probability.python.distributions as tfd +from check_shapes import inherit_check_shapes from tensorflow.python.keras.callbacks import Callback from ... import logging @@ -227,6 +228,7 @@ def ensemble_distributions(self, query_points: TensorType) -> tuple[tfd.Distribu x_transformed: dict[str, TensorType] = self.prepare_query_points(query_points) return self._model.model(x_transformed) + @inherit_check_shapes def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]: r""" Returns mean and variance at ``query_points`` for the whole ensemble. @@ -291,6 +293,7 @@ def predict_ensemble(self, query_points: TensorType) -> tuple[TensorType, Tensor return predicted_means, predicted_vars + @inherit_check_shapes def sample(self, query_points: TensorType, num_samples: int) -> TensorType: """ Return ``num_samples`` samples at ``query_points``. We use the mixture approximation in