diff --git a/utils/metrics.py b/utils/metrics.py index 6bba4cfe2a42..9bf084c78854 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -139,6 +139,12 @@ def process_batch(self, detections, labels): Returns: None, updates confusion matrix accordingly """ + if detections is None: + gt_classes = labels.int() + for i, gc in enumerate(gt_classes): + self.matrix[self.nc, gc] += 1 # background FN + return + detections = detections[detections[:, 4] > self.conf] gt_classes = labels[:, 0].int() detection_classes = detections[:, 5].int() diff --git a/val.py b/val.py index b0cc8e7f1577..48207a1130a6 100644 --- a/val.py +++ b/val.py @@ -228,6 +228,8 @@ def run( if npr == 0: if nl: stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0])) + if plots: + confusion_matrix.process_batch(detections=None, labels=labels[:, 0]) continue # Predictions