diff --git a/test.py b/test.py index 91b2b981c45b..2b9e90c05367 100644 --- a/test.py +++ b/test.py @@ -119,7 +119,7 @@ def test(data, targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling t = time_synchronized() - out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True) + out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls) t1 += time_synchronized() - t # Statistics per image @@ -136,6 +136,8 @@ def test(data, continue # Predictions + if single_cls: + pred[:, 5] = 0 predn = pred.clone() scale_coords(img[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred