Skip to content

Commit

Permalink
Cleanup non_maximum_suppression
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Jun 11, 2022
1 parent 2593050 commit b41d859
Showing 1 changed file with 7 additions and 26 deletions.
33 changes: 7 additions & 26 deletions yolort/relay/nms_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@
import torch
import torchvision
from torch import Tensor
from yolort.models.box_head import _decode_pred_logits


def batched_nms(
def non_maximum_suppression(
prediction: Tensor,
score_thresh: float = 0.25,
nms_thresh: float = 0.45,
agnostic: bool = False,
):
"""
Runs Non-Maximum Suppression (NMS) on inference results
Runs Non-Maximum Suppression (NMS) on inference results.
Returns:
list of detections, on (n, 6) tensor per image [xyxy, conf, cls]
Expand All @@ -24,37 +25,17 @@ def batched_nms(
output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
for xi, x in enumerate(prediction): # image index, image inference
x = x[xc[xi]] # confidence
# Compute conf
cx, cy, w, h = x[:, 0:1], x[:, 1:2], x[:, 2:3], x[:, 3:4]
obj_conf = x[:, 4:5]
cls_conf = x[:, 5:]
cls_conf = obj_conf * cls_conf # conf = obj_conf * cls_conf
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(cx, cy, w, h)
conf, j = cls_conf.max(1, keepdim=True)
box, score = _decode_pred_logits(x)
conf, j = score.max(1, keepdim=True)
# best class only
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > score_thresh]
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
i = torchvision.ops.nms(boxes, scores, nms_thresh) # NMS
i = torchvision.ops.nms(boxes, scores, nms_thresh)
output[xi] = x[i]
return output


def xywh2xyxy(cx, cy, w, h):
"""
This function is used while exporting ONNX models
Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
"""
halfw = w / 2
halfh = h / 2
xmin = cx - halfw # top left x
ymin = cy - halfh # top left y
xmax = cx + halfw # bottom right x
ymax = cy + halfh # bottom right y
return torch.cat((xmin, ymin, xmax, ymax), 1)


class NonMaxSupressionOp(torch.autograd.Function):
@staticmethod
def forward(ctx, boxes, scores, detections_per_class, iou_thresh, score_thresh):
Expand Down

0 comments on commit b41d859

Please sign in to comment.