Skip to content

Commit

Permalink
Fixed issue with confidence for single class detectors when exporting (
Browse files Browse the repository at this point in the history
…WongKinYiu#607)

* Fixed issue with confidence for single class detectors when exporting

* Typo
  • Loading branch information
DMLON committed Sep 12, 2022
1 parent 77c6304 commit c21fd85
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
2 changes: 1 addition & 1 deletion export.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@
if opt.grid:
if opt.end2end:
print('\nStarting export end2end onnx model for %s...' % 'TensorRT' if opt.max_wh is None else 'onnxruntime')
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device)
model = End2End(model,opt.topk_all,opt.iou_thres,opt.conf_thres,opt.max_wh,device,len(labels))
if opt.end2end and opt.max_wh is None:
output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
shapes = [opt.batch_size, 1, opt.batch_size, opt.topk_all, 4,
Expand Down
22 changes: 16 additions & 6 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def symbolic(g,

class ONNX_ORT(nn.Module):
'''onnx module with ONNX-Runtime NMS operation.'''
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None):
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, device=None, n_classes=80):
super().__init__()
self.device = device if device else torch.device("cpu")
self.max_obj = torch.tensor([max_obj]).to(device)
Expand All @@ -168,12 +168,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, de
self.convert_matrix = torch.tensor([[1, 0, 1, 0], [0, 1, 0, 1], [-0.5, 0, 0.5, 0], [0, -0.5, 0, 0.5]],
dtype=torch.float32,
device=self.device)
self.n_classes=n_classes

def forward(self, x):
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
scores *= conf
if self.n_classes == 1:
scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiplicate.
else:
scores *= conf # conf = obj_conf * cls_conf
boxes @= self.convert_matrix
max_score, category_id = scores.max(2, keepdim=True)
dis = category_id.float() * self.max_wh
Expand All @@ -189,7 +194,7 @@ def forward(self, x):

class ONNX_TRT(nn.Module):
'''onnx module with TensorRT NMS operation.'''
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None):
def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
super().__init__()
assert max_wh is None
self.device = device if device else torch.device('cpu')
Expand All @@ -200,12 +205,17 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,d
self.plugin_version = '1'
self.score_activation = 0
self.score_threshold = score_thres
self.n_classes=n_classes

def forward(self, x):
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
scores *= conf
if self.n_classes == 1:
scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiplicate.
else:
scores *= conf # conf = obj_conf * cls_conf
num_det, det_boxes, det_scores, det_classes = TRT_NMS.apply(boxes, scores, self.background_class, self.box_coding,
self.iou_threshold, self.max_obj,
self.plugin_version, self.score_activation,
Expand All @@ -215,14 +225,14 @@ def forward(self, x):

class End2End(nn.Module):
'''export onnx or tensorrt model with NMS operation.'''
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None):
def __init__(self, model, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
super().__init__()
device = device if device else torch.device('cpu')
assert isinstance(max_wh,(int)) or max_wh is None
self.model = model.to(device)
self.model.model[-1].end2end = True
self.patch_model = ONNX_TRT if max_wh is None else ONNX_ORT
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device)
self.end2end = self.patch_model(max_obj, iou_thres, score_thres, max_wh, device, n_classes)
self.end2end.eval()

def forward(self, x):
Expand Down

0 comments on commit c21fd85

Please sign in to comment.