From c2e8fbf0c2cac85b0dcd9c958a7a45a8a5926317 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 7 Jul 2021 15:26:10 +0200 Subject: [PATCH 1/3] Feature visualization update --- detect.py | 6 +++++- models/yolo.py | 11 +++++------ utils/plots.py | 39 ++++++++++++++++++--------------------- 3 files changed, 28 insertions(+), 28 deletions(-) diff --git a/detect.py b/detect.py index a4542f7e8802..44b33eb42289 100644 --- a/detect.py +++ b/detect.py @@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s) classes=None, # filter by class: --class 0, or --class 0 2 3 agnostic_nms=False, # class-agnostic NMS augment=False, # augmented inference + visualize=False, # visualize features update=False, # update all models project='runs/detect', # save results to project/name name='exp', # save results to project/name @@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s) # Inference t1 = time_synchronized() - pred = model(img, augment=augment)[0] + pred = model(img, + augment=augment, + visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0] # Apply NMS pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) @@ -201,6 +204,7 @@ def parse_opt(): parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3') parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS') parser.add_argument('--augment', action='store_true', help='augmented inference') + parser.add_argument('--visualize', action='store_true', help='visualize features') parser.add_argument('--update', action='store_true', help='update all models') parser.add_argument('--project', default='runs/detect', help='save results to project/name') parser.add_argument('--name', default='exp', help='save results to project/name') diff --git a/models/yolo.py b/models/yolo.py index 826590bd9783..b11443377080 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -117,11 +117,10 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i self.info() logger.info('') - def forward(self, x, augment=False, profile=False): + def forward(self, x, augment=False, profile=False, visualize=False): if augment: return self.forward_augment(x) # augmented inference, None - else: - return self.forward_once(x, profile) # single-scale inference, train + return self.forward_once(x, profile, visualize) # single-scale inference, train def forward_augment(self, x): img_size = x.shape[-2:] # height, width @@ -136,7 +135,7 @@ def forward_augment(self, x): y.append(yi) return torch.cat(y, 1), None # augmented inference, train - def forward_once(self, x, profile=False, feature_vis=False): + def forward_once(self, x, profile=False, visualize=False): y, dt = [], [] # outputs for m in self.model: if m.f != -1: # if not from previous layer @@ -155,8 +154,8 @@ def forward_once(self, x, profile=False, feature_vis=False): x = m(x) # run y.append(x if m.i in self.save else None) # save output - if feature_vis and m.type == 'models.common.SPP': - feature_visualization(x, m.type, m.i) + if visualize: + feature_visualization(x, m.type, m.i, save_dir=visualize) if profile: logger.info('%.1fms total' % sum(dt)) diff --git a/utils/plots.py b/utils/plots.py index 4b6c63992ac7..1ab3bb6f21fe 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -1,12 +1,12 @@ # Plotting utils import glob -import math import os from copy import copy from pathlib import Path import cv2 +import math import matplotlib import matplotlib.pyplot as plt import numpy as np @@ -15,7 +15,6 @@ import torch import yaml from PIL import Image, ImageDraw, ImageFont -from torchvision import transforms from utils.general import increment_path, xywh2xyxy, xyxy2xywh from utils.metrics import fitness @@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): fig.savefig(Path(save_dir) / 'results.png', dpi=200) -def feature_visualization(x, module_type, stage, n=64): +def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')): """ x: Features to be visualized module_type: Module type stage: Module stage within model n: Maximum number of feature maps to plot + save_dir: Directory to save results """ - batch, channels, height, width = x.shape # batch, channels, height, width - if height > 1 and width > 1: - project, name = 'runs/features', 'exp' - save_dir = increment_path(Path(project) / name) # increment run - save_dir.mkdir(parents=True, exist_ok=True) # make dir - - plt.figure(tight_layout=True) - blocks = torch.chunk(x, channels, dim=1) # block by channel dimension - n = min(n, len(blocks)) - for i in range(n): - feature = transforms.ToPILImage()(blocks[i].squeeze()) - ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) - ax.axis('off') - plt.imshow(feature) # cmap='gray' - - f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png" - print(f'Saving {save_dir / f}...') - plt.savefig(save_dir / f, dpi=300) + if 'Detect' not in module_type: + batch, channels, height, width = x.shape # batch, channels, height, width + if height > 1 and width > 1: + f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename + + plt.figure(tight_layout=True) + blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels + n = min(n, channels) # number of plots + ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols + for i in range(n): + ax[i].imshow(blocks[i].squeeze()) # cmap='gray' + ax[i].axis('off') + + print(f'Saving {save_dir / f}... ({n}/{channels})') + plt.savefig(save_dir / f, dpi=300) From fcc4914e3689c2f195bd489e83853dbba84bbaba Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 7 Jul 2021 15:39:51 +0200 Subject: [PATCH 2/3] Save to jpg (faster) --- utils/plots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/plots.py b/utils/plots.py index 1ab3bb6f21fe..fc6b683aebde 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -458,7 +458,7 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec if 'Detect' not in module_type: batch, channels, height, width = x.shape # batch, channels, height, width if height > 1 and width > 1: - f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename + f = f"stage{stage}_{module_type.split('.')[-1]}_features.jpg" # filename plt.figure(tight_layout=True) blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels From 94bf4aa23d3d9ba64cc474fe88fc8bd47c2a0bf3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 7 Jul 2021 15:41:47 +0200 Subject: [PATCH 3/3] Save to png --- utils/plots.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/plots.py b/utils/plots.py index fc6b683aebde..1ab3bb6f21fe 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -458,7 +458,7 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec if 'Detect' not in module_type: batch, channels, height, width = x.shape # batch, channels, height, width if height > 1 and width > 1: - f = f"stage{stage}_{module_type.split('.')[-1]}_features.jpg" # filename + f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename plt.figure(tight_layout=True) blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels