Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precision-Recall Curve Feature Addition #1107

Merged
merged 4 commits into from
Oct 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def test(data,
verbose=False,
model=None,
dataloader=None,
save_dir='',
merge=False,
save_txt=False):
save_dir=Path(''), # for saving images
save_txt=False, # for auto-labelling
plots=True):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
Expand All @@ -41,15 +41,15 @@ def test(data,
else: # called directly
set_logging()
device = select_device(opt.device, batch_size=batch_size)
merge, save_txt = opt.merge, opt.save_txt # use Merge NMS, save *.txt labels
save_txt = opt.save_txt # save *.txt labels
if save_txt:
out = Path('inference/output')
if os.path.exists(out):
shutil.rmtree(out) # delete output folder
os.makedirs(out) # make new output folder

# Remove previous
for f in glob.glob(str(Path(save_dir) / 'test_batch*.jpg')):
for f in glob.glob(str(save_dir / 'test_batch*.jpg')):
os.remove(f)

# Load model
Expand Down Expand Up @@ -110,7 +110,7 @@ def test(data,

# Run NMS
t = time_synchronized()
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, merge=merge)
output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres)
t1 += time_synchronized() - t

# Statistics per image
Expand Down Expand Up @@ -186,16 +186,16 @@ def test(data,
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls))

# Plot images
if batch_i < 1:
f = Path(save_dir) / ('test_batch%g_gt.jpg' % batch_i) # filename
if plots and batch_i < 1:
f = save_dir / ('test_batch%g_gt.jpg' % batch_i) # filename
plot_images(img, targets, paths, str(f), names) # ground truth
f = Path(save_dir) / ('test_batch%g_pred.jpg' % batch_i)
f = save_dir / ('test_batch%g_pred.jpg' % batch_i)
plot_images(img, output_to_target(output, width, height), paths, str(f), names) # predictions

# Compute statistics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
p, r, ap, f1, ap_class = ap_per_class(*stats)
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, fname=save_dir / 'precision-recall_curve.png')
p, r, ap50, ap = p[:, 0], r[:, 0], ap[:, 0], ap.mean(1) # [P, R, AP@0.5, AP@0.5:0.95]
mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
nt = np.bincount(stats[3].astype(np.int64), minlength=nc) # number of targets per class
Expand Down Expand Up @@ -261,7 +261,6 @@ def test(data,
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--merge', action='store_true', help='use Merge NMS')
parser.add_argument('--verbose', action='store_true', help='report mAP by class')
parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
opt = parser.parse_args()
Expand Down
6 changes: 2 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import glob
import logging
import math
import os
Expand Down Expand Up @@ -309,15 +308,14 @@ def train(hyp, opt, device, tb_writer=None):
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride'])
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
if final_epoch: # replot predictions
[os.remove(x) for x in glob.glob(str(log_dir / 'test_batch*_pred.jpg')) if os.path.exists(x)]
results, maps, times = test.test(opt.data,
batch_size=total_batch_size,
imgsz=imgsz_test,
model=ema.ema,
single_cls=opt.single_cls,
dataloader=testloader,
save_dir=log_dir)
save_dir=log_dir,
plots=epoch == 0 or final_epoch) # plot first and last

# Write
with open(results_file, 'a') as f:
Expand Down
37 changes: 21 additions & 16 deletions utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,14 +245,16 @@ def clip_coords(boxes, img_shape):
boxes[:, 3].clamp_(0, img_shape[0]) # y2


def ap_per_class(tp, conf, pred_cls, target_cls):
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, fname='precision-recall_curve.png'):
""" Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
tp: True positives (nparray, nx1 or nx10).
conf: Objectness value from 0-1 (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
pred_cls: Predicted object classes (nparray).
target_cls: True object classes (nparray).
plot: Plot precision-recall curve at mAP@0.5
fname: Plot filename
# Returns
The average precision as computed in py-faster-rcnn.
"""
Expand All @@ -265,6 +267,7 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
unique_classes = np.unique(target_cls)

# Create Precision-Recall curve and compute AP for each class
px, py = np.linspace(0, 1, 1000), [] # for plotting
pr_score = 0.1 # score to evaluate P and R https://github.com/ultralytics/yolov3/issues/898
s = [unique_classes.shape[0], tp.shape[1]] # number class, number iou thresholds (i.e. 10 for mAP0.5...0.95)
ap, p, r = np.zeros(s), np.zeros(s), np.zeros(s)
Expand All @@ -289,22 +292,26 @@ def ap_per_class(tp, conf, pred_cls, target_cls):
p[ci] = np.interp(-pr_score, -conf[i], precision[:, 0]) # p at pr_score

# AP from recall-precision curve
py.append(np.interp(px, recall[:, 0], precision[:, 0])) # precision at mAP@0.5
for j in range(tp.shape[1]):
ap[ci, j] = compute_ap(recall[:, j], precision[:, j])

# Plot
# fig, ax = plt.subplots(1, 1, figsize=(5, 5))
# ax.plot(recall, precision)
# ax.set_xlabel('Recall')
# ax.set_ylabel('Precision')
# ax.set_xlim(0, 1.01)
# ax.set_ylim(0, 1.01)
# fig.tight_layout()
# fig.savefig('PR_curve.png', dpi=300)

# Compute F1 score (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + 1e-16)

if plot:
py = np.stack(py, axis=1)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.plot(px, py, linewidth=0.5, color='grey') # plot(recall, precision)
ax.plot(px, py.mean(1), linewidth=2, color='blue', label='all classes')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
plt.legend()
fig.tight_layout()
fig.savefig(fname, dpi=200)

return p, r, ap, f1, unique_classes.astype('int32')


Expand Down Expand Up @@ -1011,8 +1018,6 @@ def plot_wh_methods(): # from utils.general import *; plot_wh_methods()
def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=640, max_subplots=16):
tl = 3 # line thickness
tf = max(tl - 1, 1) # font thickness
if os.path.isfile(fname): # do not overwrite
return None

if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
Expand Down