diff --git a/trieste/space.py b/trieste/space.py index 0de9214b5..04b95fad1 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -656,12 +656,20 @@ def one_hot_encoder(self) -> EncoderFunction: def binary_encoder(x: TensorType) -> TensorType: # no need to one-hot encode binary categories (but we should still validate) - tf.debugging.Assert(tf.reduce_all((x == 0) | (x == 1)), [tf.constant([])]) + tf.debugging.Assert( + tf.reduce_all((x == 0) | (x == 1)), + ["Invalid binary values for one-hot encoding:", x], + ) return x def encoder(x: TensorType) -> TensorType: flat_x, unflatten = flatten_leading_dims(x) - tf.debugging.assert_equal(flat_x.shape[-1], len(self.tags)) + tf.debugging.assert_equal( + flat_x.shape[-1], + len(self.tags), + message="Invalid input for one-hot encoding: " + f"expected {len(self.tags)} tags, got {flat_x.shape[-1]}", + ) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ (