Skip to content

Commit

Permalink
Start using the vectorized EncodePng Op (#6344)
Browse files Browse the repository at this point in the history
## Motivation for features / changes
The vectorization unblocks running tf.summary.image under DTensor.

Vectorization of EncodePNG is added in
tensorflow/tensorflow@7c42227

## Technical description of changes
Applied a forward compatibility wrapper for 14 days since the EncodePNG change landed.

## Detailed steps to verify changes work correctly (as executed by you)
Verified by both branches work without DTensor, and the new (vectorized)
branch works under DTensor.
Note that I manually applied the internal patch to the OSS format. 

## Verification: 
cl/488391097
  • Loading branch information
rainwoodman committed May 3, 2023
1 parent 9c3c048 commit 7474a8d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
19 changes: 19 additions & 0 deletions tensorboard/plugins/image/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,25 @@ def test_floating_point_data(self):
# range [0, 255] with 229 = 0.9 * 255, truncated.
self.assertAllEqual([0, 0, 229, 255, 255], list(decoded.flat))

def test_vector_data(self):
data = np.array(
[
[-0.01, 0.0, 0.9, 1.0, 1.1],
[-0.01, 0.0, 1.0, 0.9, 1.1],
]
).reshape((2, -1, 1, 1))
pb = self.image("mona_lisa", data)

encoded0 = pb.value[0].tensor.string_val[2] # skip width, height
decoded0 = tf.image.decode_png(encoded0).numpy()
# Float values outside [0, 1) are truncated, and everything is scaled to the
# range [0, 255] with 229 = 0.9 * 255, truncated.
self.assertAllEqual([0, 0, 229, 255, 255], list(decoded0.flat))

encoded1 = pb.value[0].tensor.string_val[3] # skip width, height
decoded1 = tf.image.decode_png(encoded1).numpy()
self.assertAllEqual([0, 0, 255, 229, 255], list(decoded1.flat))


if __name__ == "__main__":
tf.test.main()
30 changes: 18 additions & 12 deletions tensorboard/plugins/image/summary_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,24 @@ def lazy_tensor():
tf.debugging.assert_non_negative(max_outputs)
images = tf.image.convert_image_dtype(data, tf.uint8, saturate=True)
limited_images = images[:max_outputs]
encoded_images = tf.map_fn(
tf.image.encode_png,
limited_images,
dtype=tf.string,
name="encode_each_image",
)
# Workaround for map_fn returning float dtype for an empty elems input.
encoded_images = tf.cond(
tf.shape(input=encoded_images)[0] > 0,
lambda: encoded_images,
lambda: tf.constant([], tf.string),
)
if tf.compat.forward_compatible(2023, 5, 1):
encoded_images = tf.image.encode_png(limited_images)
else:
# TODO(b/276803093): The kernel was updated around 2023/04/15.
# After 90 days (2023/07/15), please remove the False branch.
encoded_images = tf.map_fn(
tf.image.encode_png,
limited_images,
dtype=tf.string,
name="encode_each_image",
)
# Workaround for map_fn returning float dtype for an empty
# elems input.
encoded_images = tf.cond(
tf.shape(input=encoded_images)[0] > 0,
lambda: encoded_images,
lambda: tf.constant([], tf.string),
)
image_shape = tf.shape(input=images)
dimensions = tf.stack(
[
Expand Down

0 comments on commit 7474a8d

Please sign in to comment.