From f6454c116528b28110f7d003c07c83868e4d617f Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 21 Apr 2022 20:06:57 -0700 Subject: [PATCH] Reduce val device transfers (#7525) --- val.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/val.py b/val.py index a773ff3e4fa3..b2b3bc75911e 100644 --- a/val.py +++ b/val.py @@ -220,14 +220,14 @@ def run( # Metrics for si, pred in enumerate(out): labels = targets[targets[:, 0] == si, 1:] - nl = len(labels) - tcls = labels[:, 0].tolist() if nl else [] # target class + nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions path, shape = Path(paths[si]), shapes[si][0] + correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init seen += 1 - if len(pred) == 0: + if npr == 0: if nl: - stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) + stats.append((correct, *torch.zeros((3, 0)))) continue # Predictions @@ -244,9 +244,7 @@ def run( correct = process_batch(predn, labelsn, iouv) if plots: confusion_matrix.process_batch(predn, labelsn) - else: - correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool) - stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) # (correct, conf, pcls, tcls) + stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls) # Save/log if save_txt: @@ -265,7 +263,7 @@ def run( callbacks.run('on_val_batch_end') # Compute metrics - stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy + stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy if len(stats) and stats[0].any(): tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95