From 248504cf13c2cba9e211e6110089a3e6f916109c Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 9 Jul 2021 15:23:02 +0200 Subject: [PATCH] Feature visualization improvements 32 (#3947) --- detect.py | 2 +- utils/plots.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/detect.py b/detect.py index 44b33eb42289..be2c5969c6d7 100644 --- a/detect.py +++ b/detect.py @@ -103,7 +103,7 @@ def run(weights='yolov5s.pt', # model.pt path(s) t1 = time_synchronized() pred = model(img, augment=augment, - visualize=increment_path(save_dir / 'features', mkdir=True) if visualize else False)[0] + visualize=increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False)[0] # Apply NMS pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det) diff --git a/utils/plots.py b/utils/plots.py index 23a48620e6b5..4e6b001dcc2f 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -16,7 +16,7 @@ import yaml from PIL import Image, ImageDraw, ImageFont -from utils.general import increment_path, xywh2xyxy, xyxy2xywh +from utils.general import xywh2xyxy, xyxy2xywh from utils.metrics import fitness # Settings @@ -447,7 +447,7 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): fig.savefig(Path(save_dir) / 'results.png', dpi=200) -def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detect/exp')): +def feature_visualization(x, module_type, stage, n=32, save_dir=Path('runs/detect/exp')): """ x: Features to be visualized module_type: Module type @@ -460,13 +460,14 @@ def feature_visualization(x, module_type, stage, n=64, save_dir=Path('runs/detec if height > 1 and width > 1: f = f"stage{stage}_{module_type.split('.')[-1]}_features.png" # filename - plt.figure(tight_layout=True) blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels n = min(n, channels) # number of plots - ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True)[1].ravel() # 8 rows x n/8 cols + fig, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # 8 rows x n/8 cols + ax = ax.ravel() + plt.subplots_adjust(wspace=0.05, hspace=0.05) for i in range(n): ax[i].imshow(blocks[i].squeeze()) # cmap='gray' ax[i].axis('off') print(f'Saving {save_dir / f}... ({n}/{channels})') - plt.savefig(save_dir / f, dpi=300) + plt.savefig(save_dir / f, dpi=300, bbox_inches='tight')