Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update test.py profiling #3555

Merged
merged 3 commits into from
Jun 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 17 additions & 13 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test(data,
plots=True,
wandb_logger=None,
compute_loss=None,
half_precision=True,
half=True,
opt=None):
# Initialize/load model and set device
training = model is not None
Expand All @@ -63,7 +63,7 @@ def test(data,
# model = nn.DataParallel(model)

# Half
half = device.type != 'cpu' and half_precision # half precision only supported on CUDA
half &= device.type != 'cpu' # half precision only supported on CUDA
if half:
model.half()

Expand Down Expand Up @@ -95,20 +95,22 @@ def test(data,
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
coco91class = coco80_to_coco91_class()
s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0.
p, r, f1, mp, mr, map50, map, t0, t1, t2 = 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.
loss = torch.zeros(3, device=device)
jdict, stats, ap, ap_class, wandb_images = [], [], [], [], []
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
t_ = time_synchronized()
img = img.to(device, non_blocking=True)
img = img.half() if half else img.float() # uint8 to fp16/32
img /= 255.0 # 0 - 255 to 0.0 - 1.0
targets = targets.to(device)
nb, _, height, width = img.shape # batch size, channels, height, width
t = time_synchronized()
t0 += t - t_

# Run model
t = time_synchronized()
out, train_out = model(img, augment=augment) # inference and training outputs
t0 += time_synchronized() - t
t1 += time_synchronized() - t

# Compute loss
if compute_loss:
Expand All @@ -119,7 +121,7 @@ def test(data,
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
t = time_synchronized()
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
t1 += time_synchronized() - t
t2 += time_synchronized() - t

# Statistics per image
for si, pred in enumerate(out):
Expand Down Expand Up @@ -236,9 +238,10 @@ def test(data,
print(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))

# Print speeds
t = tuple(x / seen * 1E3 for x in (t0, t1, t0 + t1)) + (imgsz, imgsz, batch_size) # tuple
t = tuple(x / seen * 1E3 for x in (t0, t1, t2)) # speeds per image
if not training:
print('Speed: %.1f/%.1f/%.1f ms inference/NMS/total per %gx%g image at batch-size %g' % t)
shape = (batch_size, 3, imgsz, imgsz)
print(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)

# Plots
if plots:
Expand Down Expand Up @@ -327,24 +330,25 @@ def test(data,
save_txt=opt.save_txt | opt.save_hybrid,
save_hybrid=opt.save_hybrid,
save_conf=opt.save_conf,
half_precision=opt.half,
half=opt.half,
opt=opt
)

elif opt.task == 'speed': # speed benchmarks
for w in opt.weights:
test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, opt=opt)
for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
test(opt.data, w, opt.batch_size, opt.img_size, 0.25, 0.45, save_json=False, plots=False, half=True,
opt=opt)

elif opt.task == 'study': # run over a range of settings and save/plot
# python test.py --task study --data coco.yaml --iou 0.7 --weights yolov5s.pt yolov5m.pt yolov5l.pt yolov5x.pt
x = list(range(256, 1536 + 128, 128)) # x axis (image sizes)
for w in opt.weights:
for w in opt.weights if isinstance(opt.weights, list) else [opt.weights]:
f = f'study_{Path(opt.data).stem}_{Path(w).stem}.txt' # filename to save to
y = [] # y axis
for i in x: # img-size
print(f'\nRunning {f} point {i}...')
r, _, t = test(opt.data, w, opt.batch_size, i, opt.conf_thres, opt.iou_thres, opt.save_json,
plots=False, opt=opt)
plots=False, half=True, opt=opt)
y.append(r + t) # results and times
np.savetxt(f, y, fmt='%10.4g') # save
os.system('zip -r study.zip study_*.txt')
Expand Down
26 changes: 13 additions & 13 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def train(hyp, opt, device, tb_writer=None):
loggers['wandb'] = wandb_logger.wandb
data_dict = wandb_logger.data_dict
if wandb_logger.wandb:
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # may update weights, epochs if resuming

nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = ['item'] if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
Expand Down Expand Up @@ -354,18 +354,18 @@ def train(hyp, opt, device, tb_writer=None):
final_epoch = epoch + 1 == epochs
if not opt.notest or final_epoch: # Calculate mAP
wandb_logger.current_epoch = epoch + 1
results, maps, times = test.test(data_dict,
batch_size=batch_size * 2,
imgsz=imgsz_test,
model=ema.ema,
single_cls=single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
wandb_logger=wandb_logger,
compute_loss=compute_loss)
results, maps, _ = test.test(data_dict,
batch_size=batch_size * 2,
imgsz=imgsz_test,
model=ema.ema,
single_cls=single_cls,
dataloader=testloader,
save_dir=save_dir,
save_json=is_coco and final_epoch,
verbose=nc < 50 and final_epoch,
plots=plots and final_epoch,
wandb_logger=wandb_logger,
compute_loss=compute_loss)

# Write
with open(results_file, 'a') as f:
Expand Down
17 changes: 9 additions & 8 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import glob
import math
import os
import random
from copy import copy
from pathlib import Path

Expand Down Expand Up @@ -252,21 +251,23 @@ def plot_targets_txt(): # from utils.plots import *; plot_targets_txt()

def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_txt()
# Plot study.txt generated by test.py
fig, ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)
# ax = ax.ravel()
plot2 = False # plot additional results
if plot2:
ax = plt.subplots(2, 4, figsize=(10, 6), tight_layout=True)[1].ravel()

fig2, ax2 = plt.subplots(1, 1, figsize=(8, 4), tight_layout=True)
# for f in [Path(path) / f'study_coco_{x}.txt' for x in ['yolov5s6', 'yolov5m6', 'yolov5l6', 'yolov5x6']]:
for f in sorted(Path(path).glob('study*.txt')):
y = np.loadtxt(f, dtype=np.float32, usecols=[0, 1, 2, 3, 7, 8, 9], ndmin=2).T
x = np.arange(y.shape[1]) if x is None else np.array(x)
s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_inference (ms/img)', 't_NMS (ms/img)', 't_total (ms/img)']
# for i in range(7):
# ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
# ax[i].set_title(s[i])
if plot2:
s = ['P', 'R', 'mAP@.5', 'mAP@.5:.95', 't_preprocess (ms/img)', 't_inference (ms/img)', 't_NMS (ms/img)']
for i in range(7):
ax[i].plot(x, y[i], '.-', linewidth=2, markersize=8)
ax[i].set_title(s[i])

j = y[3].argmax() + 1
ax2.plot(y[6, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
ax2.plot(y[5, 1:j], y[3, 1:j] * 1E2, '.-', linewidth=2, markersize=8,
label=f.stem.replace('study_coco_', '').replace('yolo', 'YOLO'))

ax2.plot(1E3 / np.array([209, 140, 97, 58, 35, 18]), [34.6, 40.5, 43.0, 47.5, 49.7, 51.5],
Expand Down