From 47ed8114c509d4b5c8fc0c30d68778c82228bcb7 Mon Sep 17 00:00:00 2001 From: Zigarss <32835472+Zigars@users.noreply.github.com> Date: Mon, 28 Jun 2021 17:56:06 +0800 Subject: [PATCH 1/7] Add feature map visualization Add a feature_visualization function to visualize the mid feature map of the model. --- utils/plots.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/utils/plots.py b/utils/plots.py index 66a30918190e..f1d5bc0a5067 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -15,6 +15,7 @@ import torch import yaml from PIL import Image, ImageDraw, ImageFont +from torchvision import transforms from utils.general import xywh2xyxy, xyxy2xywh from utils.metrics import fitness @@ -445,3 +446,39 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): ax[1].legend() fig.savefig(Path(save_dir) / 'results.png', dpi=200) + + +def feature_visualization(features, model_type, model_id, feature_num=64): + """ + features: The feature map which you need to visualization + model_type: The type of feature map + model_id: The id of feature map + feature_num: The amount of visualization you need + """ + save_dir = "features/" + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + # print(features.shape) + # block by channel dimension + blocks = torch.chunk(features, features.shape[1], dim=1) + + # # size of feature + # size = features.shape[2], features.shape[3] + + plt.figure() + for i in range(feature_num): + torch.squeeze(blocks[i]) + feature = transforms.ToPILImage()(blocks[i].squeeze()) + # print(feature) + ax = plt.subplot(int(math.sqrt(feature_num)), int(math.sqrt(feature_num)), i+1) + ax.set_xticks([]) + ax.set_yticks([]) + + plt.imshow(feature) + # gray feature + # plt.imshow(feature, cmap='gray') + + # plt.show() + plt.savefig(save_dir + '{}_{}_feature_map_{}.png' + .format(model_type.split('.')[2], model_id, feature_num), dpi=300) From ada66fe402b34abd6bb7870312b48dd053889b61 Mon Sep 17 00:00:00 2001 From: Zigarss <32835472+Zigars@users.noreply.github.com> Date: Mon, 28 Jun 2021 18:10:33 +0800 Subject: [PATCH 2/7] Update yolo.py --- models/yolo.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/models/yolo.py b/models/yolo.py index 4a2514edd295..9952d8cdf514 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -17,6 +17,7 @@ from models.experimental import * from utils.autoanchor import check_anchor_order from utils.general import make_divisible, check_file, set_logging +from utils.plots import feature_visualization from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ select_device, copy_attr @@ -153,6 +154,11 @@ def forward_once(self, x, profile=False): x = m(x) # run y.append(x if m.i in self.save else None) # save output + + feature_vis = True + if m.type == 'models.common.SPP' and feature_vis: + print(m.type, m.i) + feature_visualization(x, m.type, m.i) if profile: logger.info('%.1fms total' % sum(dt)) From 00249bf2206abca02964520841a4241875392e4d Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Jun 2021 12:18:56 +0200 Subject: [PATCH 3/7] remove boolean from forward and reorder if statement --- models/yolo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/models/yolo.py b/models/yolo.py index 9952d8cdf514..1fa67b3b3fc1 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -136,7 +136,7 @@ def forward_augment(self, x): y.append(yi) return torch.cat(y, 1), None # augmented inference, train - def forward_once(self, x, profile=False): + def forward_once(self, x, profile=False, feature_vis=False): y, dt = [], [] # outputs for m in self.model: if m.f != -1: # if not from previous layer @@ -155,8 +155,7 @@ def forward_once(self, x, profile=False): x = m(x) # run y.append(x if m.i in self.save else None) # save output - feature_vis = True - if m.type == 'models.common.SPP' and feature_vis: + if feature_vis and m.type == 'models.common.SPP': print(m.type, m.i) feature_visualization(x, m.type, m.i) From 8d22cad8afc928c499c15a88e0bde4fbea8464f0 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Jun 2021 12:27:33 +0200 Subject: [PATCH 4/7] remove print from forward --- models/yolo.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/yolo.py b/models/yolo.py index 1fa67b3b3fc1..4c9456edd687 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -156,7 +156,6 @@ def forward_once(self, x, profile=False, feature_vis=False): y.append(x if m.i in self.save else None) # save output if feature_vis and m.type == 'models.common.SPP': - print(m.type, m.i) feature_visualization(x, m.type, m.i) if profile: From 7b62f38fa59f530ff72bcbb8ac9fee39a10106a3 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Jun 2021 12:39:56 +0200 Subject: [PATCH 5/7] General cleanup --- utils/plots.py | 43 +++++++++++++++++-------------------------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/utils/plots.py b/utils/plots.py index f1d5bc0a5067..ea2f3474a205 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -17,7 +17,7 @@ from PIL import Image, ImageDraw, ImageFont from torchvision import transforms -from utils.general import xywh2xyxy, xyxy2xywh +from utils.general import increment_path, xywh2xyxy, xyxy2xywh from utils.metrics import fitness # Settings @@ -300,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): matplotlib.use('svg') # faster ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) - # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 + # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 ax[0].set_ylabel('instances') if 0 < len(names) < 30: ax[0].set_xticks(range(len(names))) @@ -446,39 +446,30 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): ax[1].legend() fig.savefig(Path(save_dir) / 'results.png', dpi=200) - -def feature_visualization(features, model_type, model_id, feature_num=64): + +def feature_visualization(features, module_type, module_idx, n=64): """ - features: The feature map which you need to visualization - model_type: The type of feature map - model_id: The id of feature map - feature_num: The amount of visualization you need + features: Features to be visualized + module_type: module type + module_idx: module layer index within model + n: Maximum number of feature maps to plot """ - save_dir = "features/" - if not os.path.exists(save_dir): - os.makedirs(save_dir) - - # print(features.shape) - # block by channel dimension - blocks = torch.chunk(features, features.shape[1], dim=1) - - # # size of feature - # size = features.shape[2], features.shape[3] + project, name = 'runs/features', 'exp' + save_dir = increment_path(Path(project) / name) # increment run + save_dir.mkdir(parents=True, exist_ok=True) # make dir plt.figure() - for i in range(feature_num): - torch.squeeze(blocks[i]) + blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension + n = min(n, len(blocks)) + for i in range(n): feature = transforms.ToPILImage()(blocks[i].squeeze()) - # print(feature) - ax = plt.subplot(int(math.sqrt(feature_num)), int(math.sqrt(feature_num)), i+1) + ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) ax.set_xticks([]) ax.set_yticks([]) - plt.imshow(feature) - # gray feature # plt.imshow(feature, cmap='gray') # plt.show() - plt.savefig(save_dir + '{}_{}_feature_map_{}.png' - .format(model_type.split('.')[2], model_id, feature_num), dpi=300) + plt.savefig(save_dir / f"{module_type.split('.')[2]}_{module_idx}_feature_map_{n}.png", dpi=300) + print(f'Features from {module_type} module in layer {module_idx} saved to {save_dir}') From cd003e9c5a8bef8276145a6297aae042fe5eb879 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Jun 2021 12:40:47 +0200 Subject: [PATCH 6/7] Indent --- utils/plots.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils/plots.py b/utils/plots.py index ea2f3474a205..5f0c9d31d3c8 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -450,10 +450,10 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): def feature_visualization(features, module_type, module_idx, n=64): """ - features: Features to be visualized - module_type: module type - module_idx: module layer index within model - n: Maximum number of feature maps to plot + features: Features to be visualized + module_type: Module type + module_idx: Module layer index within model + n: Maximum number of feature maps to plot """ project, name = 'runs/features', 'exp' save_dir = increment_path(Path(project) / name) # increment run From e3c14b36d1ebaf28990b91ce4a9ed37f8a4ebef6 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 28 Jun 2021 13:15:44 +0200 Subject: [PATCH 7/7] Update plots.py --- utils/plots.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/utils/plots.py b/utils/plots.py index 5f0c9d31d3c8..36386371dbec 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -459,17 +459,15 @@ def feature_visualization(features, module_type, module_idx, n=64): save_dir = increment_path(Path(project) / name) # increment run save_dir.mkdir(parents=True, exist_ok=True) # make dir - plt.figure() + plt.figure(tight_layout=True) blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension n = min(n, len(blocks)) for i in range(n): feature = transforms.ToPILImage()(blocks[i].squeeze()) ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) - ax.set_xticks([]) - ax.set_yticks([]) - plt.imshow(feature) - # plt.imshow(feature, cmap='gray') - - # plt.show() - plt.savefig(save_dir / f"{module_type.split('.')[2]}_{module_idx}_feature_map_{n}.png", dpi=300) - print(f'Features from {module_type} module in layer {module_idx} saved to {save_dir}') + ax.axis('off') + plt.imshow(feature) # cmap='gray' + + f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png" + print(f'Saving {save_dir / f}...') + plt.savefig(save_dir / f, dpi=300)