From 143afe5397f8cd0be2a34a6bf16f60c2f58d3b5e Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Tue, 15 Oct 2024 11:53:00 +0100 Subject: [PATCH] Re-add logs for new debugging asserts only --- trieste/space.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/trieste/space.py b/trieste/space.py index 0de9214b5..04b95fad1 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -656,12 +656,20 @@ def one_hot_encoder(self) -> EncoderFunction: def binary_encoder(x: TensorType) -> TensorType: # no need to one-hot encode binary categories (but we should still validate) - tf.debugging.Assert(tf.reduce_all((x == 0) | (x == 1)), [tf.constant([])]) + tf.debugging.Assert( + tf.reduce_all((x == 0) | (x == 1)), + ["Invalid binary values for one-hot encoding:", x], + ) return x def encoder(x: TensorType) -> TensorType: flat_x, unflatten = flatten_leading_dims(x) - tf.debugging.assert_equal(flat_x.shape[-1], len(self.tags)) + tf.debugging.assert_equal( + flat_x.shape[-1], + len(self.tags), + message="Invalid input for one-hot encoding: " + f"expected {len(self.tags)} tags, got {flat_x.shape[-1]}", + ) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ (