diff --git a/val.py b/val.py index 78abbda8231a..d0c09719f6be 100644 --- a/val.py +++ b/val.py @@ -79,16 +79,17 @@ def process_batch(detections, labels, iouv): """ correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device) iou = box_iou(labels[:, 1:], detections[:, :4]) - x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match - if x[0].shape[0]: - matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou] - if x[0].shape[0] > 1: - matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 1], return_index=True)[1]] - # matches = matches[matches[:, 2].argsort()[::-1]] - matches = matches[np.unique(matches[:, 0], return_index=True)[1]] - matches = torch.Tensor(matches).to(iouv.device) - correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv + for i in range(len(iouv)): + x = torch.where((iou >= iouv[i]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + matches = torch.Tensor(matches).to(iouv.device) + correct[matches[:, 1].long(), i] = True return correct @@ -206,6 +207,7 @@ def run(data, # Metrics for si, pred in enumerate(out): + pred = pred[:100] labels = targets[targets[:, 0] == si, 1:] nl = len(labels) tcls = labels[:, 0].tolist() if nl else [] # target class