diff --git a/tensorboard/plugins/image/summary_test.py b/tensorboard/plugins/image/summary_test.py index 6edfe0bb537..8eb9273701f 100644 --- a/tensorboard/plugins/image/summary_test.py +++ b/tensorboard/plugins/image/summary_test.py @@ -237,6 +237,25 @@ def test_floating_point_data(self): # range [0, 255] with 229 = 0.9 * 255, truncated. self.assertAllEqual([0, 0, 229, 255, 255], list(decoded.flat)) + def test_vector_data(self): + data = np.array( + [ + [-0.01, 0.0, 0.9, 1.0, 1.1], + [-0.01, 0.0, 1.0, 0.9, 1.1], + ] + ).reshape((2, -1, 1, 1)) + pb = self.image("mona_lisa", data) + + encoded0 = pb.value[0].tensor.string_val[2] # skip width, height + decoded0 = tf.image.decode_png(encoded0).numpy() + # Float values outside [0, 1) are truncated, and everything is scaled to the + # range [0, 255] with 229 = 0.9 * 255, truncated. + self.assertAllEqual([0, 0, 229, 255, 255], list(decoded0.flat)) + + encoded1 = pb.value[0].tensor.string_val[3] # skip width, height + decoded1 = tf.image.decode_png(encoded1).numpy() + self.assertAllEqual([0, 0, 255, 229, 255], list(decoded1.flat)) + if __name__ == "__main__": tf.test.main() diff --git a/tensorboard/plugins/image/summary_v2.py b/tensorboard/plugins/image/summary_v2.py index 480d5d88c3b..8f969df0db8 100644 --- a/tensorboard/plugins/image/summary_v2.py +++ b/tensorboard/plugins/image/summary_v2.py @@ -113,18 +113,24 @@ def lazy_tensor(): tf.debugging.assert_non_negative(max_outputs) images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True) limited_images = images[:max_outputs] - encoded_images = tf.map_fn( - tf.image.encode_png, - limited_images, - dtype=tf.string, - name="encode_each_image", - ) - # Workaround for map_fn returning float dtype for an empty elems input. - encoded_images = tf.cond( - tf.shape(input=encoded_images)[0] > 0, - lambda: encoded_images, - lambda: tf.constant([], tf.string), - ) + if tf.compat.forward_compatible(2023, 5, 1): + encoded_images = tf.image.encode_png(limited_images) + else: + # TODO(b/276803093): The kernel was updated around 2023/04/15. + # After 90 days (2023/07/15), please remove the False branch. + encoded_images = tf.map_fn( + tf.image.encode_png, + limited_images, + dtype=tf.string, + name="encode_each_image", + ) + # Workaround for map_fn returning float dtype for an empty + # elems input. + encoded_images = tf.cond( + tf.shape(input=encoded_images)[0] > 0, + lambda: encoded_images, + lambda: tf.constant([], tf.string), + ) image_shape = tf.shape(input=images) dimensions = tf.stack( [