Skip to content

Commit

Permalink
Reduce val device transfers (ultralytics#7525)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored and Clay Januhowski committed Sep 8, 2022
1 parent 000f3bf commit f6454c1
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions val.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,14 @@ def run(
# Metrics
for si, pred in enumerate(out):
labels = targets[targets[:, 0] == si, 1:]
nl = len(labels)
tcls = labels[:, 0].tolist() if nl else [] # target class
nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
path, shape = Path(paths[si]), shapes[si][0]
correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
seen += 1

if len(pred) == 0:
if npr == 0:
if nl:
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls))
stats.append((correct, *torch.zeros((3, 0))))
continue

# Predictions
Expand All @@ -244,9 +244,7 @@ def run(
correct = process_batch(predn, labelsn, iouv)
if plots:
confusion_matrix.process_batch(predn, labelsn)
else:
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool)
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls)
stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)

# Save/log
if save_txt:
Expand All @@ -265,7 +263,7 @@ def run(
callbacks.run('on_val_batch_end')

# Compute metrics
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy
stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
if len(stats) and stats[0].any():
tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
Expand Down

0 comments on commit f6454c1

Please sign in to comment.