Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Add a switch for POST_NMS per batch/image during training #695

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@
# all FPN levels
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
# Apply the post NMS per batch (default) or per image during training
# (default is True to be consistent with Detectron, see Issue #672)
_C.MODEL.RPN.FPN_POST_NMS_PER_BATCH = True
# Custom rpn head, empty to use default conv or separable conv
_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"

Expand Down
10 changes: 7 additions & 3 deletions maskrcnn_benchmark/modeling/rpn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
min_size,
box_coder=None,
fpn_post_nms_top_n=None,
fpn_post_nms_per_batch=True,
):
"""
Arguments:
Expand All @@ -47,6 +48,7 @@ def __init__(
if fpn_post_nms_top_n is None:
fpn_post_nms_top_n = post_nms_top_n
self.fpn_post_nms_top_n = fpn_post_nms_top_n
self.fpn_post_nms_per_batch = fpn_post_nms_per_batch

def add_gt_proposals(self, proposals, targets):
"""
Expand Down Expand Up @@ -154,9 +156,9 @@ def select_over_all_levels(self, boxlists):
# different behavior during training and during testing:
# during training, post_nms_top_n is over *all* the proposals combined, while
# during testing, it is over the proposals for each image
# TODO resolve this difference and make it consistent. It should be per image,
# and not per batch
if self.training:
# NOTE: it should be per image, and not per batch. However, to be consistent
# with Detectron, the default is per batch (see Issue #672)
if self.training and self.fpn_post_nms_per_batch:
objectness = torch.cat(
[boxlist.get_field("objectness") for boxlist in boxlists], dim=0
)
Expand Down Expand Up @@ -189,6 +191,7 @@ def make_rpn_postprocessor(config, rpn_box_coder, is_train):
if not is_train:
pre_nms_top_n = config.MODEL.RPN.PRE_NMS_TOP_N_TEST
post_nms_top_n = config.MODEL.RPN.POST_NMS_TOP_N_TEST
fpn_post_nms_per_batch = config.MODEL.RPN.FPN_POST_NMS_PER_BATCH
nms_thresh = config.MODEL.RPN.NMS_THRESH
min_size = config.MODEL.RPN.MIN_SIZE
box_selector = RPNPostProcessor(
Expand All @@ -198,5 +201,6 @@ def make_rpn_postprocessor(config, rpn_box_coder, is_train):
min_size=min_size,
box_coder=rpn_box_coder,
fpn_post_nms_top_n=fpn_post_nms_top_n,
fpn_post_nms_per_batch=fpn_post_nms_per_batch,
)
return box_selector