diff --git a/utils/segment/general.py b/utils/segment/general.py index f1655e488944..075dce192ddf 100644 --- a/utils/segment/general.py +++ b/utils/segment/general.py @@ -20,31 +20,32 @@ def non_max_suppression_masks( max_det=300, mask_dim=32, ): - """Runs Non-Maximum Suppression (NMS) on inference results + """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections Returns: list of detections, on (n,6) tensor per image [xyxy, conf, cls] """ + bs = prediction.shape[0] # batch size nc = prediction.shape[2] - 5 # number of classes xc = prediction[..., 4] > conf_thres # candidates # Checks - assert (0 <= conf_thres <= 1), f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0" - assert (0 <= iou_thres <= 1), f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0" + assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0' + assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0' # Settings # min_wh = 2 # (pixels) minimum box width and height max_wh = 7680 # (pixels) maximum box width and height max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() - time_limit = 10.0 # seconds to quit after + time_limit = 0.6 + 0.06 * bs # seconds to quit after redundant = True # require redundant detections multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) merge = False # use merge-NMS nm = 5 + mask_dim t = time.time() - output = [torch.zeros((0, 6 + mask_dim), device=prediction.device)] * prediction.shape[0] + output = [torch.zeros((0, 6 + mask_dim), device=prediction.device)] * bs for xi, x in enumerate(prediction): # image index, image inference # Apply constraints # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height @@ -92,8 +93,6 @@ def non_max_suppression_masks( continue elif n > max_nms: # excess boxes x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence - else: - x = x[x[:, 4].argsort(descending=True)] # sort by confidence # Batched NMS c = x[:, 5:6] * (0 if agnostic else max_wh) # classes