diff --git a/val.py b/val.py index 58e8170da86c..4c1d7d26b0de 100644 --- a/val.py +++ b/val.py @@ -50,26 +50,27 @@ def save_one_json(predn, jdict, path, class_map): 'score': round(p[4], 5)}) -def process_batch(predictions, labels, iouv): - # Evaluate 1 batch of predictions - correct = torch.zeros(predictions.shape[0], len(iouv), dtype=torch.bool, device=iouv.device) - detected = [] # label indices - tcls, pcls = labels[:, 0], predictions[:, 5] - nl = labels.shape[0] # number of labels - for cls in torch.unique(tcls): - ti = (cls == tcls).nonzero().view(-1) # label indices - pi = (cls == pcls).nonzero().view(-1) # prediction indices - if pi.shape[0]: # find detections - ious, i = box_iou(predictions[pi, 0:4], labels[ti, 1:5]).max(1) # best ious, indices - detected_set = set() - for j in (ious > iouv[0]).nonzero(): - d = ti[i[j]] # detected label - if d.item() not in detected_set: - detected_set.add(d.item()) - detected.append(d) # append detections - correct[pi[j]] = ious[j] > iouv # iou_thres is 1xn - if len(detected) == nl: # all labels already located in image - break +def process_batch(detections, labels, iouv): + """ + Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format. + Arguments: + detections (Array[N, 6]), x1, y1, x2, y2, conf, class + labels (Array[M, 5]), class, x1, y1, x2, y2 + Returns: + correct (Array[N, 10]), for 10 IoU levels + """ + correct = torch.zeros(detections.shape[0], iouv.shape[0], dtype=torch.bool, device=iouv.device) + iou = box_iou(labels[:, 1:], detections[:, :4]) + x = torch.where((iou >= iouv[0]) & (labels[:, 0:1] == detections[:, 5])) # IoU above threshold and classes match + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detection, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 1], return_index=True)[1]] + # matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[np.unique(matches[:, 0], return_index=True)[1]] + matches = torch.Tensor(matches).to(iouv.device) + correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv return correct