From 0f2d9b25683784e34bae28234c3d50c79763ff7a Mon Sep 17 00:00:00 2001
From: Glenn Jocher
Date: Wed, 27 Jan 2021 13:04:22 -0800
Subject: [PATCH] Metric-Confidence plots feature addition

---
 utils/metrics.py | 36 ++++++++++++++++++++++++++++++------
 1 file changed, 30 insertions(+), 6 deletions(-)

diff --git a/utils/metrics.py b/utils/metrics.py
index 99d5bcfaf2af..f97ed3268d60 100644
--- a/utils/metrics.py
+++ b/utils/metrics.py
@@ -38,9 +38,9 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
 
     # Create Precision-Recall curve and compute AP for each class
     px, py = np.linspace(0, 1, 1000), []  # for plotting
-    pr_score = 0.1  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
+    # pr_score = np.linspace(0, 1, 100)  # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
     s = [unique_classes.shape[0], tp.shape[1]]  # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
-    ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
+    ap, p, r = np.zeros(s), np.zeros(s + [1000]), np.zeros(s + [1000])
     for ci, c in enumerate(unique_classes):
         i = pred_cls == c
         n_l = (target_cls == c).sum()  # number of labels
@@ -55,11 +55,11 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
 
             # Recall
             recall = tpc / (n_l + 1e-16)  # recall curve
-            r[ci] = np.interp(-pr_score, -conf[i], recall[:, 0])  # r at pr_score, negative x, xp because xp decreases
+            r[ci, 0] = np.interp(-px, -conf[i], recall[:, 0], left=0)  # negative x, xp because xp decreases
 
             # Precision
             precision = tpc / (tpc + fpc)  # precision curve
-            p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0])  # p at pr_score
+            p[ci, 0] = np.interp(-px, -conf[i], precision[:, 0], left=1)  # precision at each confidence
 
             # AP from recall-precision curve
             for j in range(tp.shape[1]):
@@ -72,6 +72,9 @@ def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='precision
 
     if plot:
         plot_pr_curve(px, py, ap, save_dir, names)
+        plot_mc_curve(px, f1[:, 0].T, Path(save_dir) / 'F1_curve.png', names, ylabel='F1')
+        plot_mc_curve(px, p[:, 0].T, Path(save_dir) / 'P_curve.png', names, ylabel='Precision')
+        plot_mc_curve(px, r[:, 0].T, Path(save_dir) / 'R_curve.png', names, ylabel='Recall')
 
     return p, r, ap, f1, unique_classes.astype('int32')
 
@@ -182,12 +185,13 @@ def print(self):
 # Plots ----------------------------------------------------------------------------------------------------------------
 
 def plot_pr_curve(px, py, ap, save_dir='.', names=()):
+    # Precision-recall curve
     fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
     py = np.stack(py, axis=1)
 
-    if 0 < len(names) < 21:  # show mAP in legend if < 10 classes
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
         for i, y in enumerate(py.T):
-            ax.plot(px, y, linewidth=1, label=f'{names[i]} %.3f' % ap[i, 0])  # plot(recall, precision)
+            ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}')  # plot(recall, precision)
     else:
         ax.plot(px, py, linewidth=1, color='grey')  # plot(recall, precision)
 
@@ -198,3 +202,23 @@ def plot_pr_curve(px, py, ap, save_dir='.', names=()):
     ax.set_ylim(0, 1)
     plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
     fig.savefig(Path(save_dir) / 'precision_recall_curve.png', dpi=250)
+
+
+def plot_mc_curve(px, py, save_dir='.', names=(), xlabel='Confidence', ylabel='Metric'):
+    # Metric-confidence curve
+    fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
+
+    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
+        for i, y in enumerate(py.T):
+            ax.plot(px, y, linewidth=1, label=f'{names[i]}')  # plot(confidence, metric)
+    else:
+        ax.plot(px, py, linewidth=1, color='grey')  # plot(confidence, metric)
+
+    y = py.mean(1)
+    ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}')
+    ax.set_xlabel(xlabel)
+    ax.set_ylabel(ylabel)
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
+    fig.savefig(save_dir, dpi=250)
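
Aside on the np.interp change (illustration only, not part of the patch): np.interp requires an increasing xp array, but conf[i] arrives sorted in descending order (ap_per_class sorts predictions by confidence upstream of these hunks, hence the "xp decreases" comment), so both the query grid px and the sample points conf[i] are negated to reverse the axis. The left=0 / left=1 arguments pin recall to 0 and precision to 1 for thresholds above the highest predicted confidence. A minimal standalone sketch with made-up numbers:

    import numpy as np

    conf = np.array([0.9, 0.7, 0.4, 0.2])    # synthetic confidences, sorted descending
    recall = np.array([0.1, 0.3, 0.6, 0.8])  # cumulative recall at each confidence
    px = np.linspace(0, 1, 5)                # coarse confidence grid (the patch uses 1000 points)

    # Negate x and xp so xp increases, as np.interp requires;
    # left=0 returns recall 0 for thresholds above the highest confidence.
    r = np.interp(-px, -conf, recall, left=0)
    print(r)  # [0.8, 0.75, 0.5, 0.25, 0.0], recall falls as the threshold rises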
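
Usage sketch (illustration only, not part of the patch): with the patch applied, a single ap_per_class(..., plot=True) call writes F1_curve.png, P_curve.png and R_curve.png next to precision_recall_curve.png. The synthetic arrays, class names and the runs/example output directory below are hypothetical placeholders; only ap_per_class comes from the patched utils/metrics.py:

    from pathlib import Path

    import numpy as np

    from utils.metrics import ap_per_class

    rng = np.random.default_rng(0)
    n_pred, n_iou = 500, 10                 # predictions x IoU thresholds (0.5:0.95)
    tp = rng.random((n_pred, n_iou)) > 0.5  # synthetic true-positive flags per prediction
    conf = rng.random(n_pred)               # synthetic confidences
    pred_cls = rng.integers(0, 3, n_pred)   # synthetic predicted class indices
    target_cls = rng.integers(0, 3, 400)    # synthetic ground-truth class indices

    save_dir = Path('runs/example')         # placeholder output directory
    save_dir.mkdir(parents=True, exist_ok=True)
    # plot=True draws the PR curve plus the three new metric-confidence curves
    p, r, ap, f1, classes = ap_per_class(tp, conf, pred_cls, target_cls, plot=True,
                                         save_dir=save_dir, names=['person', 'car', 'dog'])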