From 8f357c64bbb645ece9c0a5f7d6ae06c2055b1a42 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Fri, 11 Oct 2024 09:58:06 +0100 Subject: [PATCH 1/2] Make one hot encoder serialisable --- tests/unit/test_space.py | 4 ++-- trieste/space.py | 9 ++------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index 7cbc34248..47e4d8759 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -1830,7 +1830,7 @@ def test_categorical_search_space_one_hot_encoding( pytest.param( CategoricalSearchSpace(["Y", "N"]), tf.constant([[0], [2], [1]]), - ValueError, + InvalidArgumentError, id="Out of range binary input value", ), pytest.param( @@ -1842,7 +1842,7 @@ def test_categorical_search_space_one_hot_encoding( pytest.param( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), tf.constant([[0], [1], [1]]), - ValueError, + InvalidArgumentError, id="Wrong input shape", ), ], diff --git a/trieste/space.py b/trieste/space.py index 4b32181cc..0de9214b5 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -656,17 +656,12 @@ def one_hot_encoder(self) -> EncoderFunction: def binary_encoder(x: TensorType) -> TensorType: # no need to one-hot encode binary categories (but we should still validate) - if tf.reduce_any((x != 0) & (x != 1)): - raise ValueError(f"Invalid values {tf.boolean_mask(x, ((x != 0) & (x != 1)))}") + tf.debugging.Assert(tf.reduce_all((x == 0) | (x == 1)), [tf.constant([])]) return x def encoder(x: TensorType) -> TensorType: flat_x, unflatten = flatten_leading_dims(x) - if flat_x.shape[-1] != len(self.tags): - raise ValueError( - "Invalid input for one-hot encoding: " - f"expected {len(self.tags)} tags, got {flat_x.shape[-1]}" - ) + tf.debugging.assert_equal(flat_x.shape[-1], len(self.tags)) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ ( From 7bf3d8e196715081a1eef243960c7e5e16098d45 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Fri, 11 Oct 2024 09:59:05 +0100 Subject: [PATCH 2/2] Bump version number --- CITATION.cff | 4 ++-- trieste/VERSION | 2 +- versions.json | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index f3e308373..7b13aabd0 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -30,6 +30,6 @@ authors: - family-names: "Picheny" given-names: "Victor" title: "Trieste" -version: 4.2.0 -date-released: 2024-09-20 +version: 4.2.1 +date-released: 2024-10-11 url: "https://github.com/secondmind-labs/trieste" diff --git a/trieste/VERSION b/trieste/VERSION index 6aba2b245..fae6e3d04 100644 --- a/trieste/VERSION +++ b/trieste/VERSION @@ -1 +1 @@ -4.2.0 +4.2.1 diff --git a/versions.json b/versions.json index 92cbc810c..efb8d7121 100644 --- a/versions.json +++ b/versions.json @@ -4,8 +4,8 @@ "url": "https://secondmind-labs.github.io/trieste/develop/" }, { - "version": "4.2.0", - "url": "https://secondmind-labs.github.io/trieste/4.2.0/" + "version": "4.2.1", + "url": "https://secondmind-labs.github.io/trieste/4.2.1/" }, { "version": "4.1.0",