From c09964c27cc275c8e32630715cca5be77078dae2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 19 Feb 2021 15:39:09 -0800 Subject: [PATCH] Update inference default to multi_label=False (#2252) * Update inference default to multi_label=False * bug fix * Update plots.py * Update plots.py --- models/common.py | 2 +- test.py | 8 ++++---- utils/general.py | 9 +++++---- utils/plots.py | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/models/common.py b/models/common.py index efcc6071af63..ad35f908d865 100644 --- a/models/common.py +++ b/models/common.py @@ -7,7 +7,7 @@ import requests import torch import torch.nn as nn -from PIL import Image, ImageDraw +from PIL import Image from utils.datasets import letterbox from utils.general import non_max_suppression, make_divisible, scale_coords, xyxy2xywh diff --git a/test.py b/test.py index 738764f15601..c30148dfb2f1 100644 --- a/test.py +++ b/test.py @@ -106,7 +106,7 @@ def test(data, with torch.no_grad(): # Run model t = time_synchronized() - inf_out, train_out = model(img, augment=augment) # inference and training outputs + out, train_out = model(img, augment=augment) # inference and training outputs t0 += time_synchronized() - t # Compute loss @@ -117,11 +117,11 @@ def test(data, targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling t = time_synchronized() - output = non_max_suppression(inf_out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb) + out = non_max_suppression(out, conf_thres=conf_thres, iou_thres=iou_thres, labels=lb, multi_label=True) t1 += time_synchronized() - t # Statistics per image - for si, pred in enumerate(output): + for si, pred in enumerate(out): labels = targets[targets[:, 0] == si, 1:] nl = len(labels) tcls = labels[:, 0].tolist() if nl else [] # target class @@ -209,7 +209,7 @@ def test(data, f = save_dir / 
f'test_batch{batch_i}_labels.jpg' # labels Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start() f = save_dir / f'test_batch{batch_i}_pred.jpg' # predictions - Thread(target=plot_images, args=(img, output_to_target(output), paths, f, names), daemon=True).start() + Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start() # Compute statistics stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy diff --git a/utils/general.py b/utils/general.py index 64b360fbe7df..3b5f4629b00a 100755 --- a/utils/general.py +++ b/utils/general.py @@ -390,11 +390,12 @@ def wh_iou(wh1, wh2): return inter / (wh1.prod(2) + wh2.prod(2) - inter) # iou = inter / (area1 + area2 - inter) -def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): - """Performs Non-Maximum Suppression (NMS) on inference results +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, + labels=()): + """Runs Non-Maximum Suppression (NMS) on inference results Returns: - detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + list of detections, one (n,6) tensor per image [xyxy, conf, cls] """ nc = prediction.shape[2] - 5 # number of classes @@ -406,7 +407,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non max_nms = 30000 # maximum number of boxes into torchvision.ops.nms() time_limit = 10.0 # seconds to quit after redundant = True # require redundant detections - multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img) merge = False # use merge-NMS t = time.time() diff --git a/utils/plots.py b/utils/plots.py index 94f46a9a4026..aa9a1cab81f0 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -54,7 +54,7 @@ def butter_lowpass(cutoff, fs, order): return filtfilt(b, a, data) # forward-backward filter -def 
plot_one_box(x, img, color=None, label=None, line_thickness=None): +def plot_one_box(x, img, color=None, label=None, line_thickness=3): # Plots one bounding box on image img tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness color = color or [random.randint(0, 255) for _ in range(3)]