Skip to content

Commit

Permalink
Align NMS-seg closer to NMS
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Aug 20, 2022
1 parent 2d9394d commit 1a84f47
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions utils/segment/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,31 +20,32 @@ def non_max_suppression_masks(
max_det=300,
mask_dim=32,
):
"""Runs Non-Maximum Suppression (NMS) on inference results
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""

bs = prediction.shape[0] # batch size
nc = prediction.shape[2] - 5 # number of classes
xc = prediction[..., 4] > conf_thres # candidates

# Checks
assert (0 <= conf_thres <= 1), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert (0 <= iou_thres <= 1), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'

# Settings
# min_wh = 2 # (pixels) minimum box width and height
max_wh = 7680 # (pixels) maximum box width and height
max_nms = 30000 # maximum number of boxes into torchvision.ops.nms()
time_limit = 10.0 # seconds to quit after
time_limit = 0.6 + 0.06 * bs # seconds to quit after
redundant = True # require redundant detections
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
merge = False # use merge-NMS
nm = 5 + mask_dim

t = time.time()
output = [torch.zeros((0, 6 + mask_dim), device=prediction.device)] * prediction.shape[0]
output = [torch.zeros((0, 6 + mask_dim), device=prediction.device)] * bs
for xi, x in enumerate(prediction): # image index, image inference
# Apply constraints
# x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
Expand Down Expand Up @@ -92,8 +93,6 @@ def non_max_suppression_masks(
continue
elif n > max_nms: # excess boxes
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence
else:
x = x[x[:, 4].argsort(descending=True)] # sort by confidence

# Batched NMS
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
Expand Down

0 comments on commit 1a84f47

Please sign in to comment.