diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index f7b55c064d..b2638ca5c6 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -45,7 +45,7 @@ def __mul__(self, other: Integers) -> Integers: @pytest.mark.parametrize("exponent", [0, -2]) def test_search_space___pow___raises_for_non_positive_exponent(exponent: int) -> None: space = Integers(3) - with pytest.raises(ValueError): + with pytest.raises(tf.errors.InvalidArgumentError): space ** exponent @@ -59,7 +59,7 @@ def _points_in_2D_search_space() -> tf.Tensor: @pytest.mark.parametrize("shape", various_shapes(excluding_ranks=[2])) def test_discrete_search_space_raises_for_invalid_shapes(shape: ShapeLike) -> None: - with pytest.raises(ValueError): + with pytest.raises(TF_DEBUGGING_ERROR_TYPES): DiscreteSearchSpace(tf.random.uniform(shape)) @@ -99,7 +99,7 @@ def test_discrete_search_space_contains_raises_for_invalid_shapes( points: tf.Tensor, test_point: tf.Tensor ) -> None: space = DiscreteSearchSpace(points) - with pytest.raises(ValueError): + with pytest.raises(TF_DEBUGGING_ERROR_TYPES): _ = test_point in space @@ -127,7 +127,7 @@ def test_discrete_search_space_sampling_raises_when_too_many_samples_are_request ) -> None: search_space = DiscreteSearchSpace(_points_in_2D_search_space()) - with pytest.raises(ValueError, match="samples"): + with pytest.raises(tf.errors.InvalidArgumentError): search_space.sample(num_samples) @@ -212,7 +212,7 @@ def test_box_raises_if_bounds_have_invalid_shape( ) -> None: lower, upper = tf.zeros(lower_shape), tf.ones(upper_shape) - with pytest.raises(ValueError): + with pytest.raises(TF_DEBUGGING_ERROR_TYPES): Box(lower, upper) @@ -297,7 +297,7 @@ def test_box_contains_raises_on_point_of_different_shape( box = Box(tf.zeros(bound_shape), tf.ones(bound_shape)) point = tf.zeros(point_shape) - with pytest.raises(ValueError): + with pytest.raises(TF_DEBUGGING_ERROR_TYPES): _ = point in box @@ -396,7 +396,7 @@ def test_box_discretize_returns_search_space_with_correct_number_of_points( assert len(samples) == num_samples - with pytest.raises(ValueError): + with pytest.raises(tf.errors.InvalidArgumentError): dss.sample(num_samples + 1) diff --git a/trieste/space.py b/trieste/space.py index fe2a794b58..6367b4eca8 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -14,7 +14,9 @@ """ This module contains implementations of various types of search space. """ from __future__ import annotations +import operator from abc import ABC, abstractmethod +from functools import reduce from typing import Optional, Sequence, TypeVar, overload import tensorflow as tf @@ -45,8 +47,8 @@ def __contains__(self, value: TensorType) -> bool | TensorType: :param value: A point to check for membership of this :class:`SearchSpace`. :return: `True` if ``value`` is a member of this search space, else `False`. May return a scalar boolean `TensorType` instead of the `bool` itself. - :raise ValueError (or InvalidArgumentError): If ``value`` has a different dimensionality - from this :class:`SearchSpace`. + :raise ValueError (or tf.errors.InvalidArgumentError): If ``value`` has a different + dimensionality from this :class:`SearchSpace`. """ @abstractmethod @@ -65,17 +67,10 @@ def __pow__(self: SP, other: int) -> SP: :param other: The exponent, or number of instances of this search space to multiply together. Must be strictly positive. :return: The Cartesian product of ``other`` instances of this search space. - :raise ValueError: If the exponent ``other`` is less than 1. + :raise tf.errors.InvalidArgumentError: If the exponent ``other`` is less than 1. """ - if other < 1: - raise ValueError(f"Exponent must be strictly positive, got {other}") - - space = self - - for _ in range(other - 1): - space *= self - - return space + tf.debugging.assert_positive(other, message="Exponent must be strictly positive") + return reduce(operator.mul, [self] * other) class DiscreteSearchSpace(SearchSpace): @@ -95,7 +90,7 @@ class DiscreteSearchSpace(SearchSpace): def __init__(self, points: TensorType): """ :param points: The points that define the discrete space, with shape ('N', 'D'). - :raise ValueError (or InvalidArgumentError): If ``points`` has an invalid shape. + :raise ValueError (or tf.errors.InvalidArgumentError): If ``points`` has an invalid shape. """ tf.debugging.assert_shapes([(points, ("N", "D"))]) self._points = points @@ -119,13 +114,11 @@ def sample(self, num_samples: int) -> TensorType: :return: ``num_samples`` i.i.d. random points, sampled uniformly, and without replacement, from this search space. """ - num_points = self._points.shape[0] - if num_samples > num_points: - raise ValueError( - "Number of samples cannot be greater than the number of points" - f" {num_points} in discrete search space, got {num_samples}" - ) - + tf.debugging.assert_less_equal( + num_samples, + len(self._points), + message="Number of samples cannot be greater than the number of points in search space", + ) return tf.random.shuffle(self._points)[:num_samples, :] def __mul__(self, other: DiscreteSearchSpace) -> DiscreteSearchSpace: @@ -187,7 +180,7 @@ def __init__( and if a tensor, must have float type. :param upper: The upper (inclusive) bounds of the box. Must have shape [D] for positive D, and if a tensor, must have float type. - :raise ValueError (or InvalidArgumentError): If any of the following are true: + :raise ValueError (or tf.errors.InvalidArgumentError): If any of the following are true: - ``lower`` and ``upper`` have invalid shapes. - ``lower`` and ``upper`` do not have the same floating point type. @@ -198,8 +191,7 @@ def __init__( tf.assert_rank(lower, 1) tf.assert_rank(upper, 1) - if len(lower) == 0: - raise ValueError(f"Bounds must have shape [D] for positive D, got {tf.shape(lower)}.") + tf.debugging.assert_positive(len(lower), message="bounds cannot be empty") if isinstance(lower, Sequence): self._lower = tf.constant(lower, dtype=tf.float64) @@ -235,14 +227,14 @@ def __contains__(self, value: TensorType) -> bool | TensorType: :param value: A point to check for membership of this :class:`SearchSpace`. :return: `True` if ``value`` is a member of this search space, else `False`. May return a scalar boolean `TensorType` instead of the `bool` itself. - :raise ValueError (or InvalidArgumentError): If ``value`` has a different dimensionality - from the search space. + :raise ValueError (or tf.errors.InvalidArgumentError): If ``value`` has a different + dimensionality from the search space. """ - if not shapes_equal(value, self._lower): - raise ValueError( - f"value must have same dimensionality as search space: {self._lower.shape}," - f" got shape {value.shape}" - ) + tf.debugging.assert_equal( + shapes_equal(value, self._lower), + True, + message="value must have same dimensionality as search space", + ) return tf.reduce_all(value >= self._lower) and tf.reduce_all(value <= self._upper)