Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature visualization update #3920

Merged
merged 3 commits into from
Jul 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def run(weights='yolov5s.pt', # model.pt path(s)
classes=None, # filter by class: --class 0, or --class 0 2 3
agnostic_nms=False, # class-agnostic NMS
augment=False, # augmented inference
visualize=False, # visualize features
update=False, # update all models
project='runs/detect', # save results to project/name
name='exp', # save results to project/name
Expand Down Expand Up @@ -100,7 +101,9 @@ def run(weights='yolov5s.pt', # model.pt path(s)

# Inference
t1 = time_synchronized()
pred = model(img, augment=augment)[0]
pred = model(img,
augment=augment,
visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0]

# Apply NMS
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
Expand Down Expand Up @@ -201,6 +204,7 @@ def parse_opt():
parser.add_argument('--classes', nargs='+', type=int, help='filter by class: --class 0, or --class 0 2 3')
parser.add_argument('--agnostic-nms', action='store_true', help='class-agnostic NMS')
parser.add_argument('--augment', action='store_true', help='augmented inference')
parser.add_argument('--visualize', action='store_true', help='visualize features')
parser.add_argument('--update', action='store_true', help='update all models')
parser.add_argument('--project', default='runs/detect', help='save results to project/name')
parser.add_argument('--name', default='exp', help='save results to project/name')
Expand Down
11 changes: 5 additions & 6 deletions models/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
self.info()
logger.info('')

def forward(self, x, augment=False, profile=False):
def forward(self, x, augment=False, profile=False, visualize=False):
if augment:
return self.forward_augment(x) # augmented inference, None
else:
return self.forward_once(x, profile) # single-scale inference, train
return self.forward_once(x, profile, visualize) # single-scale inference, train

def forward_augment(self, x):
img_size = x.shape[-2:] # height, width
Expand All @@ -136,7 +135,7 @@ def forward_augment(self, x):
y.append(yi)
return torch.cat(y, 1), None # augmented inference, train

def forward_once(self, x, profile=False, feature_vis=False):
def forward_once(self, x, profile=False, visualize=False):
y, dt = [], [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
Expand All @@ -155,8 +154,8 @@ def forward_once(self, x, profile=False, feature_vis=False):
x = m(x) # run
y.append(x if m.i in self.save else None) # save output

if feature_vis and m.type == 'models.common.SPP':
feature_visualization(x, m.type, m.i)
if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize)

if profile:
logger.info('%.1fms total' % sum(dt))
Expand Down
39 changes: 18 additions & 21 deletions utils/plots.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Plotting utils

import glob
import math
import os
from copy import copy
from pathlib import Path

import cv2
import math
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
Expand All @@ -15,7 +15,6 @@
import torch
import yaml
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness
Expand Down Expand Up @@ -448,28 +447,26 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(x, module_type, stage, n=64):
def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')):
"""
x: Features to be visualized
module_type: Module type
stage: Module stage within model
n: Maximum number of feature maps to plot
save_dir: Directory to save results
"""
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure(tight_layout=True)
blocks = torch.chunk(x, channels, dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.axis('off')
plt.imshow(feature) # cmap='gray'

f = f"stage_{stage}_{module_type.split('.')[-1]}_features.png"
print(f'Saving {save_dir / f}...')
plt.savefig(save_dir / f, dpi=300)
if 'Detect' not in module_type:
batch, channels, height, width = x.shape # batch, channels, height, width
if height > 1 and width > 1:
f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename

plt.figure(tight_layout=True)
blocks = torch.chunk(x[0], channels, dim=0) # select batch index 0, block by channels
n = min(n, channels) # number of plots
ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols
for i in range(n):
ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
ax[i].axis('off')

print(f'Saving {save_dir / f}... ({n}/{channels})')
plt.savefig(save_dir / f, dpi=300)