diff --git a/test.py b/test.py index 3bcbbd86f442..e0bb7726f7d1 100644 --- a/test.py +++ b/test.py @@ -30,9 +30,9 @@ def test(data, verbose=False, model=None, dataloader=None, - save_dir='', - merge=False, - save_txt=False): + save_dir=Path(''), # for saving images + save_txt=False, # for auto-labelling + plots=True): # Initialize/load model and set device training = model is not None if training: # called by train.py @@ -41,7 +41,7 @@ def test(data, else: # called directly set_logging() device = select_device(opt.device, batch_size=batch_size) - merge, save_txt = opt.merge, opt.save_txt # use Merge NMS, save *.txt labels + save_txt = opt.save_txt # save *.txt labels if save_txt: out = Path('inference/output') if os.path.exists(out): @@ -49,7 +49,7 @@ def test(data, os.makedirs(out) # make new output folder # Remove previous - for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')): + for f in glob.glob(str(save_dir / 'test_batch*.jpg')): os.remove(f) # Load model @@ -110,7 +110,7 @@ def test(data, # Run NMS t = time_synchronized() - output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge) + output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres) t1 += time_synchronized() - t # Statistics per image @@ -186,16 +186,16 @@ def test(data, stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # Plot images - if batch_i < 1: - f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename + if plots and batch_i < 1: + f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename plot_images(img, targets, paths, str(f), names) # ground truth - f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i) + f = save_dir / ('test_batch%g_pred.jpg' % batch_i) plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy if len(stats) and stats[0].any(): - p, r, ap, f1, ap_class = ap_per_class(*stats) + p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png') p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, AP@0.5, AP@0.5:0.95] mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class @@ -261,7 +261,6 @@ def test(data, parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') parser.add_argument('--augment', action='store_true', help='augmented inference') - parser.add_argument('--merge', action='store_true', help='use Merge NMS') parser.add_argument('--verbose', action='store_true', help='report mAP by class') parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') opt = parser.parse_args() diff --git a/train.py b/train.py index c6fa31107a02..4060a5701a8b 100644 --- a/train.py +++ b/train.py @@ -1,5 +1,4 @@ import argparse -import glob import logging import math import os @@ -309,15 +308,14 @@ def train(hyp, opt, device, tb_writer=None): ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride']) final_epoch = epoch + 1 == epochs if not opt.notest or final_epoch: # Calculate mAP - if final_epoch: # replot predictions - [os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)] results, maps, times = test.test(opt.data, batch_size=total_batch_size, imgsz=imgsz_test, model=ema.ema, single_cls=opt.single_cls, dataloader=testloader, - save_dir=log_dir) + save_dir=log_dir, + plots=epoch == 0 or final_epoch) # plot first and last # Write with open(results_file, 'a') as f: diff --git a/utils/general.py b/utils/general.py index b540a6ef6879..da016631e589 100755 --- a/utils/general.py +++ b/utils/general.py @@ -245,14 +245,16 @@ def clip_coords(boxes, img_shape): boxes[:, 3].clamp_(0, img_shape[0]) # y2 -def ap_per_class(tp, conf, pred_cls, target_cls): +def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'): """ Compute the average precision, given the recall and precision curves. Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. # Arguments - tp: True positives (nparray, nx1 or nx10). + tp: True positives (nparray, nx1 or nx10). conf: Objectness value from 0-1 (nparray). - pred_cls: Predicted object classes (nparray). - target_cls: True object classes (nparray). + pred_cls: Predicted object classes (nparray). + target_cls: True object classes (nparray). + plot: Plot precision-recall curve at mAP@0.5 + fname: Plot filename # Returns The average precision as computed in py-faster-rcnn. """ @@ -265,6 +267,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls): unique_classes = np.unique(target_cls) # Create Precision-Recall curve and compute AP for each class + px, py = np.linspace(0, 1, 1000), [] # for plotting pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898 s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95) ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s) @@ -289,22 +292,26 @@ def ap_per_class(tp, conf, pred_cls, target_cls): p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score # AP from recall-precision curve + py.append(np.interp(px, recall[:, 0], precision[:, 0])) # precision at mAP@0.5 for j in range(tp.shape[1]): ap[ci, j] = compute_ap(recall[:, j], precision[:, j]) - # Plot - # fig, ax = plt.subplots(1, 1, figsize=(5, 5)) - # ax.plot(recall, precision) - # ax.set_xlabel('Recall') - # ax.set_ylabel('Precision') - # ax.set_xlim(0, 1.01) - # ax.set_ylim(0, 1.01) - # fig.tight_layout() - # fig.savefig('PR_curve.png', dpi=300) - # Compute F1 score (harmonic mean of precision and recall) f1 = 2 * p * r / (p + r + 1e-16) + if plot: + py = np.stack(py, axis=1) + fig, ax = plt.subplots(1, 1, figsize=(5, 5)) + ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision) + ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes') + ax.set_xlabel('Recall') + ax.set_ylabel('Precision') + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + plt.legend() + fig.tight_layout() + fig.savefig(fname, dpi=200) + return p, r, ap, f1, unique_classes.astype('int32') @@ -1011,8 +1018,6 @@ def plot_wh_methods(): # from utils.general import *; plot_wh_methods() def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16): tl = 3 # line thickness tf = max(tl - 1, 1) # font thickness - if os.path.isfile(fname): # do not overwrite - return None if isinstance(images, torch.Tensor): images = images.cpu().float().numpy()