Skip to content

Commit

Permalink
Update experimental.py
Browse files Browse the repository at this point in the history
  • Loading branch information
WongKinYiu committed Mar 7, 2024
1 parent 51feedd commit 380284c
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions models/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,27 @@ def __init__(self, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=640, de
self.n_classes=n_classes

def forward(self, x):
boxes = x[:, :, :4]
conf = x[:, :, 4:5]
scores = x[:, :, 5:]
if self.n_classes == 1:
scores = conf # for models with one class, cls_loss is 0 and cls_conf is always 0.5,
# so there is no need to multiplicate.
else:
scores *= conf # conf = obj_conf * cls_conf
boxes @= self.convert_matrix
## https://github.com/thaitc-hust/yolov9-tensorrt/blob/main/torch2onnx.py
## thanks https://github.com/thaitc-hust
if isinstance(x, list): ## yolov9-c.pt and yolov9-e.pt return list
x = x[1]
x = x.permute(0, 2, 1)
bboxes_x = x[..., 0:1]
bboxes_y = x[..., 1:2]
bboxes_w = x[..., 2:3]
bboxes_h = x[..., 3:4]
bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
obj_conf = x[..., 4:]
scores = obj_conf
bboxes @= self.convert_matrix
max_score, category_id = scores.max(2, keepdim=True)
dis = category_id.float() * self.max_wh
nmsbox = boxes + dis
nmsbox = bboxes + dis
max_score_tp = max_score.transpose(1, 2).contiguous()
selected_indices = ORT_NMS.apply(nmsbox, max_score_tp, self.max_obj, self.iou_threshold, self.score_threshold)
X, Y = selected_indices[:, 0], selected_indices[:, 2]
selected_boxes = boxes[X, Y, :]
selected_boxes = bboxes[X, Y, :]
selected_categories = category_id[X, Y, :].float()
selected_scores = max_score[X, Y, :]
X = X.unsqueeze(1).float()
Expand Down

0 comments on commit 380284c

Please sign in to comment.