Skip to content

Commit

Permalink
Re-add logs for new debugging asserts only
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Oct 15, 2024
1 parent 8e22a02 commit 143afe5
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions trieste/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,12 +656,20 @@ def one_hot_encoder(self) -> EncoderFunction:

def binary_encoder(x: TensorType) -> TensorType:
# no need to one-hot encode binary categories (but we should still validate)
tf.debugging.Assert(tf.reduce_all((x == 0) | (x == 1)), [tf.constant([])])
tf.debugging.Assert(
tf.reduce_all((x == 0) | (x == 1)),
["Invalid binary values for one-hot encoding:", x],
)
return x

def encoder(x: TensorType) -> TensorType:
flat_x, unflatten = flatten_leading_dims(x)
tf.debugging.assert_equal(flat_x.shape[-1], len(self.tags))
tf.debugging.assert_equal(
flat_x.shape[-1],
len(self.tags),
message="Invalid input for one-hot encoding: "
f"expected {len(self.tags)} tags, got {flat_x.shape[-1]}",
)
columns = tf.split(flat_x, flat_x.shape[-1], axis=1)
encoders = [
(
Expand Down

0 comments on commit 143afe5

Please sign in to comment.