Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use tf.debugging for search space assertions #299

Merged
merged 6 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions tests/unit/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __mul__(self, other: Integers) -> Integers:
@pytest.mark.parametrize("exponent", [0, -2])
def test_search_space___pow___raises_for_non_positive_exponent(exponent: int) -> None:
space = Integers(3)
with pytest.raises(ValueError):
with pytest.raises(tf.errors.InvalidArgumentError):
space ** exponent


Expand All @@ -59,7 +59,7 @@ def _points_in_2D_search_space() -> tf.Tensor:

@pytest.mark.parametrize("shape", various_shapes(excluding_ranks=[2]))
def test_discrete_search_space_raises_for_invalid_shapes(shape: ShapeLike) -> None:
with pytest.raises(ValueError):
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
DiscreteSearchSpace(tf.random.uniform(shape))


Expand Down Expand Up @@ -99,7 +99,7 @@ def test_discrete_search_space_contains_raises_for_invalid_shapes(
points: tf.Tensor, test_point: tf.Tensor
) -> None:
space = DiscreteSearchSpace(points)
with pytest.raises(ValueError):
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
_ = test_point in space


Expand Down Expand Up @@ -127,7 +127,7 @@ def test_discrete_search_space_sampling_raises_when_too_many_samples_are_request
) -> None:
search_space = DiscreteSearchSpace(_points_in_2D_search_space())

with pytest.raises(ValueError, match="samples"):
with pytest.raises(tf.errors.InvalidArgumentError):
search_space.sample(num_samples)


Expand Down Expand Up @@ -212,7 +212,7 @@ def test_box_raises_if_bounds_have_invalid_shape(
) -> None:
lower, upper = tf.zeros(lower_shape), tf.ones(upper_shape)

with pytest.raises(ValueError):
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
Box(lower, upper)


Expand Down Expand Up @@ -297,7 +297,7 @@ def test_box_contains_raises_on_point_of_different_shape(
box = Box(tf.zeros(bound_shape), tf.ones(bound_shape))
point = tf.zeros(point_shape)

with pytest.raises(ValueError):
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
_ = point in box


Expand Down Expand Up @@ -396,7 +396,7 @@ def test_box_discretize_returns_search_space_with_correct_number_of_points(

assert len(samples) == num_samples

with pytest.raises(ValueError):
with pytest.raises(tf.errors.InvalidArgumentError):
dss.sample(num_samples + 1)


Expand Down
52 changes: 22 additions & 30 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
""" This module contains implementations of various types of search space. """
from __future__ import annotations

import operator
from abc import ABC, abstractmethod
from functools import reduce
from typing import Optional, Sequence, TypeVar, overload

import tensorflow as tf
Expand Down Expand Up @@ -45,8 +47,8 @@ def __contains__(self, value: TensorType) -> bool | TensorType:
:param value: A point to check for membership of this :class:`SearchSpace`.
:return: `True` if ``value`` is a member of this search space, else `False`. May return a
scalar boolean `TensorType` instead of the `bool` itself.
:raise ValueError (or InvalidArgumentError): If ``value`` has a different dimensionality
from this :class:`SearchSpace`.
:raise ValueError (or tf.errors.InvalidArgumentError): If ``value`` has a different
dimensionality from this :class:`SearchSpace`.
"""

@abstractmethod
Expand All @@ -65,17 +67,10 @@ def __pow__(self: SP, other: int) -> SP:
:param other: The exponent, or number of instances of this search space to multiply
together. Must be strictly positive.
:return: The Cartesian product of ``other`` instances of this search space.
:raise ValueError: If the exponent ``other`` is less than 1.
:raise tf.errors.InvalidArgumentError: If the exponent ``other`` is less than 1.
"""
if other < 1:
raise ValueError(f"Exponent must be strictly positive, got {other}")

space = self

for _ in range(other - 1):
space *= self

return space
tf.debugging.assert_positive(other, message="Exponent must be strictly positive")
return reduce(operator.mul, [self] * other)
Copy link
Contributor Author

@joelberkeley joelberkeley Jul 15, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

felt like doing this, can undo if you'd prefer

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

definitely more elegant!
am I missing something or assert_positive means >0 while here we want >=1?

Copy link
Contributor Author

@joelberkeley joelberkeley Jul 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right. maybe i though positive meant strictly positive

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scrap that ... > 0 is the same as >=1 for integers, and that's also what assert_positive does



class DiscreteSearchSpace(SearchSpace):
Expand All @@ -95,7 +90,7 @@ class DiscreteSearchSpace(SearchSpace):
def __init__(self, points: TensorType):
"""
:param points: The points that define the discrete space, with shape ('N', 'D').
:raise ValueError (or InvalidArgumentError): If ``points`` has an invalid shape.
:raise ValueError (or tf.errors.InvalidArgumentError): If ``points`` has an invalid shape.
"""
tf.debugging.assert_shapes([(points, ("N", "D"))])
self._points = points
Expand All @@ -119,13 +114,11 @@ def sample(self, num_samples: int) -> TensorType:
:return: ``num_samples`` i.i.d. random points, sampled uniformly, and without replacement,
from this search space.
"""
num_points = self._points.shape[0]
if num_samples > num_points:
raise ValueError(
"Number of samples cannot be greater than the number of points"
f" {num_points} in discrete search space, got {num_samples}"
)

tf.debugging.assert_less_equal(
num_samples,
len(self._points),
message="Number of samples cannot be greater than the number of points in search space",
)
return tf.random.shuffle(self._points)[:num_samples, :]

def __mul__(self, other: DiscreteSearchSpace) -> DiscreteSearchSpace:
Expand Down Expand Up @@ -187,7 +180,7 @@ def __init__(
and if a tensor, must have float type.
:param upper: The upper (inclusive) bounds of the box. Must have shape [D] for positive D,
and if a tensor, must have float type.
:raise ValueError (or InvalidArgumentError): If any of the following are true:
:raise ValueError (or tf.errors.InvalidArgumentError): If any of the following are true:

- ``lower`` and ``upper`` have invalid shapes.
- ``lower`` and ``upper`` do not have the same floating point type.
Expand All @@ -198,8 +191,7 @@ def __init__(
tf.assert_rank(lower, 1)
tf.assert_rank(upper, 1)

if len(lower) == 0:
raise ValueError(f"Bounds must have shape [D] for positive D, got {tf.shape(lower)}.")
tf.debugging.assert_positive(len(lower), message="bounds cannot be empty")

if isinstance(lower, Sequence):
self._lower = tf.constant(lower, dtype=tf.float64)
Expand Down Expand Up @@ -235,14 +227,14 @@ def __contains__(self, value: TensorType) -> bool | TensorType:
:param value: A point to check for membership of this :class:`SearchSpace`.
:return: `True` if ``value`` is a member of this search space, else `False`. May return a
scalar boolean `TensorType` instead of the `bool` itself.
:raise ValueError (or InvalidArgumentError): If ``value`` has a different dimensionality
from the search space.
:raise ValueError (or tf.errors.InvalidArgumentError): If ``value`` has a different
dimensionality from the search space.
"""
if not shapes_equal(value, self._lower):
raise ValueError(
f"value must have same dimensionality as search space: {self._lower.shape},"
f" got shape {value.shape}"
)
tf.debugging.assert_equal(
shapes_equal(value, self._lower),
True,
message="value must have same dimensionality as search space",
)

return tf.reduce_all(value >= self._lower) and tf.reduce_all(value <= self._upper)

Expand Down