diff --git a/tests/conftest.py b/tests/conftest.py index 1d02c2a6b6..5eb22f62f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,21 @@ from __future__ import annotations +from typing import Iterable + import pytest from _pytest.config import Config from _pytest.config.argparsing import Parser +from check_shapes import ( + DocstringFormat, + ShapeCheckingState, + get_enable_check_shapes, + get_enable_function_call_precompute, + get_rewrite_docstrings, + set_enable_check_shapes, + set_enable_function_call_precompute, + set_rewrite_docstrings, +) def pytest_addoption(parser: Parser) -> None: @@ -74,3 +86,18 @@ def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]) -> N import tensorflow as tf tf.config.experimental_run_functions_eagerly(True) + + +@pytest.fixture(autouse=True) +def enable_shape_checks() -> Iterable[None]: + # ensure `check_shapes` is always enabled for tests + old_enable = get_enable_check_shapes() + old_rewrite_docstrings = get_rewrite_docstrings() + old_function_call_precompute = get_enable_function_call_precompute() + set_enable_check_shapes(ShapeCheckingState.ENABLED) + set_rewrite_docstrings(DocstringFormat.SPHINX) + set_enable_function_call_precompute(True) + yield + set_enable_function_call_precompute(old_function_call_precompute) + set_rewrite_docstrings(old_rewrite_docstrings) + set_enable_check_shapes(old_enable) diff --git a/tests/unit/acquisition/function/test_function.py b/tests/unit/acquisition/function/test_function.py index 986092681e..454ea41eb8 100644 --- a/tests/unit/acquisition/function/test_function.py +++ b/tests/unit/acquisition/function/test_function.py @@ -22,6 +22,7 @@ import pytest import tensorflow as tf import tensorflow_probability as tfp +from check_shapes.exceptions import ShapeMismatchError from tests.util.misc import ( TF_DEBUGGING_ERROR_TYPES, @@ -282,7 +283,7 @@ def test_expected_improvement_switches_to_improvement_on_feasible_points() -> No def test_expected_improvement_raises_for_invalid_batch_size(at: TensorType) -> None: ei = expected_improvement(QuadraticMeanAndRBFKernel(), tf.constant([1.0])) - with pytest.raises(TF_DEBUGGING_ERROR_TYPES): + with pytest.raises(ShapeMismatchError): ei(at) @@ -946,7 +947,7 @@ def test_expected_constrained_improvement_raises_for_invalid_batch_size(at: Tens eci = builder.prepare_acquisition_function({NA: QuadraticMeanAndRBFKernel()}, datasets=data) - with pytest.raises(TF_DEBUGGING_ERROR_TYPES): + with pytest.raises(ShapeMismatchError): eci(at) diff --git a/trieste/acquisition/function/function.py b/trieste/acquisition/function/function.py index f98af45fa3..7550f952c3 100644 --- a/trieste/acquisition/function/function.py +++ b/trieste/acquisition/function/function.py @@ -21,6 +21,7 @@ import tensorflow as tf import tensorflow_probability as tfp +from check_shapes import check_shapes from ...data import Dataset from ...models import ProbabilisticModel, ReparametrizationSampler @@ -211,12 +212,12 @@ def update(self, eta: TensorType) -> None: """Update the acquisition function with a new eta value.""" self._eta.assign(eta) + @check_shapes( + "x: [N..., 1, D] # This acquisition function only supports batch sizes of one", + "return: [N..., L]", + ) @tf.function def __call__(self, x: TensorType) -> TensorType: - tf.debugging.assert_shapes( - [(x, [..., 1, None])], - message="This acquisition function only supports batch sizes of one.", - ) mean, variance = self._model.predict(tf.squeeze(x, -2)) normal = tfp.distributions.Normal(mean, tf.sqrt(variance)) return (self._eta - mean) * normal.cdf(self._eta) + variance * normal.prob(self._eta)