Skip to content

Commit

Permalink
Inherit remaining predict, sample, predict_y shape checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 19, 2023
1 parent 00a5503 commit 3c66343
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tests/unit/models/gpflux/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
from check_shapes import inherit_check_shapes
from gpflow.conditionals.util import sample_mvn
from gpflux.helpers import construct_basic_inducing_variables, construct_basic_kernel
from gpflux.layers import GPLayer
Expand Down Expand Up @@ -57,6 +58,7 @@ def model_keras(self) -> tf.keras.Model:
def optimizer(self) -> tf.keras.optimizers.Optimizer:
return self._optimizer

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
# Taken from GPflow implementation of `GPModel.predict_f_samples` in gpflow.models.model
mean, cov = self._model_gpflux.predict_f(query_points, full_cov=True)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
from check_shapes import inherit_check_shapes

from tests.util.misc import (
FixedAcquisitionRule,
Expand Down Expand Up @@ -596,6 +597,7 @@ def __init__(self, data: Dataset):
super().__init__()
self._data = data

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, var = super().predict(query_points)
return mean, var / len(self._data)
Expand Down
5 changes: 5 additions & 0 deletions tests/util/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import inherit_check_shapes
from gpflow.models import GPR, SGPR, SVGP, VGP, GPModel
from typing_extensions import Protocol

Expand Down Expand Up @@ -70,6 +71,7 @@ def optimize(self, dataset: Dataset) -> None:
class GaussianMarginal(ProbabilisticModel):
"""A probabilistic model with Gaussian marginal distribution. Assumes events of shape [N]."""

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
mean, var = self.predict(query_points)
samples = tfp.distributions.Normal(mean, tf.sqrt(var)).sample(num_samples)
Expand All @@ -95,6 +97,7 @@ def __init__(
def __repr__(self) -> str:
return f"GaussianProcess({self._mean_functions!r}, {self._kernels!r})"

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = self.predict_joint(query_points[..., None, :])
return tf.squeeze(mean, -2), tf.squeeze(cov, [-2, -1])
Expand Down Expand Up @@ -131,6 +134,7 @@ def __init__(
def __repr__(self) -> str:
return f"GaussianProcessWithoutNoise({self._mean_functions!r}, {self._kernels!r})"

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = self.predict_joint(query_points[..., None, :])
return tf.squeeze(mean, -2), tf.squeeze(cov, [-2, -1])
Expand Down Expand Up @@ -278,6 +282,7 @@ def covariance_with_top_fidelity(self, x: TensorType) -> TensorType:
mean, _ = self.predict(x)
return tf.ones_like(mean, dtype=mean.dtype) # dummy covariances of correct shape

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
fmean, fvar = self.predict(query_points)
yvar = fvar + tf.constant(1.0, dtype=fmean.dtype) # dummy noise variance
Expand Down
4 changes: 4 additions & 0 deletions trieste/acquisition/function/greedy_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import gpflow
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import inherit_check_shapes
from typing_extensions import Protocol, runtime_checkable

from ...data import Dataset
Expand Down Expand Up @@ -652,6 +653,7 @@ def update_fantasized_data(self, fantasized_data: Dataset) -> None:
self._fantasized_query_points.assign(fantasized_data.query_points)
self._fantasized_observations.assign(fantasized_data.observations)

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""
This function wraps conditional_predict_f. It cannot directly call
Expand Down Expand Up @@ -690,6 +692,7 @@ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (t

return _broadcast_predict(query_points, fun)

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
"""
This function wraps conditional_predict_f_sample. It cannot directly call
Expand All @@ -716,6 +719,7 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
) # [B, ..., S, L]
return _restore_leading_dim(samples, leading_dim)

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""
This function wraps conditional_predict_y. It cannot directly call
Expand Down
3 changes: 3 additions & 0 deletions trieste/models/gpflux/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from abc import ABC, abstractmethod

import tensorflow as tf
from check_shapes import inherit_check_shapes
from gpflow.base import Module

from ...types import TensorType
Expand Down Expand Up @@ -58,6 +59,7 @@ def optimizer(self) -> KerasOptimizer:
"""The optimizer wrapper for training the model."""
return self._optimizer

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""Note: unless otherwise noted, this returns the mean and variance of the last layer
conditioned on one sample from the previous layers."""
Expand All @@ -67,6 +69,7 @@ def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
raise NotImplementedError

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""Note: unless otherwise noted, this will return the prediction conditioned on one sample
from the lower layers."""
Expand Down
2 changes: 2 additions & 0 deletions trieste/models/gpflux/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dill
import gpflow
import tensorflow as tf
from check_shapes import inherit_check_shapes
from gpflow.inducing_variables import InducingPoints
from gpflux.layers import GPLayer, LatentVariableLayer
from gpflux.models import DeepGP
Expand Down Expand Up @@ -277,6 +278,7 @@ def model_gpflux(self) -> DeepGP:
def model_keras(self) -> tf.keras.Model:
return self._model_keras

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
trajectory = self.trajectory_sampler().get_trajectory()
expanded_query_points = tf.expand_dims(query_points, -2) # [N, 1, D]
Expand Down
5 changes: 4 additions & 1 deletion trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import gpflow
import tensorflow as tf
from check_shapes import check_shapes
from check_shapes import check_shapes, inherit_check_shapes
from typing_extensions import Protocol, runtime_checkable

from ..data import Dataset
Expand Down Expand Up @@ -373,6 +373,7 @@ def __init__(
"""
self._models, self._event_sizes = zip(*(model_with_event_size,) + models_with_event_sizes)

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
r"""
:param query_points: The points at which to make predictions, of shape [..., D].
Expand All @@ -384,6 +385,7 @@ def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
means, vars_ = zip(*[model.predict(query_points) for model in self._models])
return tf.concat(means, axis=-1), tf.concat(vars_, axis=-1)

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
r"""
:param query_points: The points at which to sample, with shape [..., N, D].
Expand All @@ -395,6 +397,7 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
samples = [model.sample(query_points, num_samples) for model in self._models]
return tf.concat(samples, axis=-1)

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
r"""
:param query_points: The points at which to make predictions, of shape [..., D].
Expand Down
3 changes: 3 additions & 0 deletions trieste/models/keras/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import inherit_check_shapes
from typing_extensions import Protocol, runtime_checkable

from ...types import TensorType
Expand Down Expand Up @@ -60,9 +61,11 @@ def optimizer(self) -> KerasOptimizer:
"""The optimizer wrapper for training the model."""
return self._optimizer

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
return self.model.predict(query_points)

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
raise NotImplementedError(
"""
Expand Down
3 changes: 3 additions & 0 deletions trieste/models/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_probability.python.distributions as tfd
from check_shapes import inherit_check_shapes
from tensorflow.python.keras.callbacks import Callback

from ... import logging
Expand Down Expand Up @@ -227,6 +228,7 @@ def ensemble_distributions(self, query_points: TensorType) -> tuple[tfd.Distribu
x_transformed: dict[str, TensorType] = self.prepare_query_points(query_points)
return self._model.model(x_transformed)

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
r"""
Returns mean and variance at ``query_points`` for the whole ensemble.
Expand Down Expand Up @@ -291,6 +293,7 @@ def predict_ensemble(self, query_points: TensorType) -> tuple[TensorType, Tensor

return predicted_means, predicted_vars

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
"""
Return ``num_samples`` samples at ``query_points``. We use the mixture approximation in
Expand Down

0 comments on commit 3c66343

Please sign in to comment.