Skip to content

Commit

Permalink
Fixed ROI Pooling Output Shape to Consider Multiple ROIs (#2350) (#2360)
Browse files Browse the repository at this point in the history
* Fixed indentation and output shape in roi pooling to consider multiple ROIs

* Formatted code
  • Loading branch information
VarunS1997 authored Feb 27, 2024
1 parent e170bfe commit 18b8d79
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 16 deletions.
19 changes: 11 additions & 8 deletions keras_cv/layers/object_detection/roi_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,12 @@ def _pool_single_sample(self, args):
feature_map: [H, W, C] float Tensor
rois: [N, 4] float Tensor
Returns:
pooled_feature_map: [target_size, C] float Tensor
pooled_feature_map: [N, target_height, target_width, C] float Tensor
"""
feature_map, rois = args
num_rois = rois.get_shape().as_list()[0]
height, width, channel = feature_map.get_shape().as_list()
regions = []
# TODO (consider vectorize it for better performance)
for n in range(num_rois):
# [4]
Expand All @@ -127,7 +128,7 @@ def _pool_single_sample(self, args):
region_width = width * (roi[3] - roi[1])
h_step = region_height / self.target_height
w_step = region_width / self.target_width
regions = []
region_steps = []
for i in range(self.target_height):
for j in range(self.target_width):
height_start = y_start + i * h_step
Expand All @@ -147,16 +148,18 @@ def _pool_single_sample(self, args):
1, width_end - width_start
)
# [h_step, w_step, C]
region = feature_map[
region_step = feature_map[
height_start:height_end, width_start:width_end, :
]
# target_height * target_width * [C]
regions.append(tf.reduce_max(region, axis=[0, 1]))
regions = tf.reshape(
tf.stack(regions),
[self.target_height, self.target_width, channel],
region_steps.append(tf.reduce_max(region_step, axis=[0, 1]))
regions.append(
tf.reshape(
tf.stack(region_steps),
[self.target_height, self.target_width, channel],
)
)
return regions
return tf.stack(regions)

def get_config(self):
config = {
Expand Down
50 changes: 42 additions & 8 deletions keras_cv/layers/object_detection/roi_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_no_quantize(self):
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([27, 31, 59, 63]), [1, 2, 2, 1]
tf.constant([27, 31, 59, 63]), [1, 1, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -69,7 +69,7 @@ def test_roi_quantize_y(self):
# | 56, 57, 58(max) | 59, 60, 61, 62(max) | 63 (removed)
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([26, 30, 58, 62]), [1, 2, 2, 1]
tf.constant([26, 30, 58, 62]), [1, 1, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -94,7 +94,7 @@ def test_roi_quantize_x(self):
# | 48, 49, 50, 51(max) | 52, 53, 54, 55(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([19, 23, 51, 55]), [1, 2, 2, 1]
tf.constant([19, 23, 51, 55]), [1, 1, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -121,7 +121,7 @@ def test_roi_quantize_h(self):
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([11, 15, 35, 39, 59, 63]), [1, 3, 2, 1]
tf.constant([11, 15, 35, 39, 59, 63]), [1, 1, 3, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -147,7 +147,7 @@ def test_roi_quantize_w(self):
# | 56, 57(max) | 58, 59, 60(max) | 61, 62, 63(max) |
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([25, 28, 31, 57, 60, 63]), [1, 2, 3, 1]
tf.constant([25, 28, 31, 57, 60, 63]), [1, 1, 2, 3, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -168,7 +168,8 @@ def test_roi_feature_map_height_smaller_than_roi(self):
# ------------------repeated----------------------
# | 12, 13(max) | 14, 15(max) |
expected_feature_map = tf.reshape(
tf.constant([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]), [1, 6, 2, 1]
tf.constant([1, 3, 1, 3, 5, 7, 9, 11, 9, 11, 13, 15]),
[1, 1, 6, 2, 1],
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -189,7 +190,7 @@ def test_roi_feature_map_width_smaller_than_roi(self):
# --------------------------------------------
expected_feature_map = tf.reshape(
tf.constant([4, 4, 5, 6, 6, 7, 12, 12, 13, 14, 14, 15]),
[1, 2, 6, 1],
[1, 1, 2, 6, 1],
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

Expand All @@ -203,10 +204,43 @@ def test_roi_empty(self):
rois = tf.reshape(tf.constant([0.0, 0.0, 0.0, 0.0]), [1, 1, 4])
pooled_feature_map = roi_pooler(feature_map, rois)
# all outputs should be top-left pixel
self.assertAllClose(tf.ones([1, 2, 2, 1]), pooled_feature_map)
self.assertAllClose(tf.ones([1, 1, 2, 2, 1]), pooled_feature_map)

def test_invalid_image_shape(self):
with self.assertRaisesRegex(ValueError, "dynamic shape"):
_ = ROIPooler(
"rel_yxyx", target_size=[2, 2], image_shape=[None, 224, 3]
)

def test_multiple_rois(self):
feature_map = tf.expand_dims(
tf.reshape(tf.range(0, 64), [8, 8, 1]), axis=0
)

roi_pooler = ROIPooler(
bounding_box_format="yxyx",
target_size=[2, 2],
image_shape=[224, 224, 3],
)
rois = tf.constant(
[[[0.0, 0.0, 112.0, 112.0], [0.0, 112.0, 224.0, 224.0]]],
)

pooled_feature_map = roi_pooler(feature_map, rois)
# the maximum value would be at bottom-right at each block, roi sharded
# into 2x2 blocks
# | 0, 1, 2, 3 | 4, 5, 6, 7 |
# | 8, 9, 10, 11 | 12, 13, 14, 15 |
# | 16, 17, 18, 19 | 20, 21, 22, 23 |
# | 24, 25, 26, 27(max) | 28, 29, 30, 31(max) |
# --------------------------------------------
# | 32, 33, 34, 35 | 36, 37, 38, 39 |
# | 40, 41, 42, 43 | 44, 45, 46, 47 |
# | 48, 49, 50, 51 | 52, 53, 54, 55 |
# | 56, 57, 58, 59(max) | 60, 61, 62, 63(max) |
# --------------------------------------------

expected_feature_map = tf.reshape(
tf.constant([9, 11, 25, 27, 29, 31, 61, 63]), [1, 2, 2, 2, 1]
)
self.assertAllClose(expected_feature_map, pooled_feature_map)

0 comments on commit 18b8d79

Please sign in to comment.