diff --git a/models/yolo.py b/models/yolo.py index 4a2514edd295..4c9456edd687 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -17,6 +17,7 @@ from models.experimental import * from utils.autoanchor import check_anchor_order from utils.general import make_divisible, check_file, set_logging +from utils.plots import feature_visualization from utils.torch_utils import time_synchronized, fuse_conv_and_bn, model_info, scale_img, initialize_weights, \ select_device, copy_attr @@ -135,7 +136,7 @@ def forward_augment(self, x): y.append(yi) return torch.cat(y, 1), None # augmented inference, train - def forward_once(self, x, profile=False): + def forward_once(self, x, profile=False, feature_vis=False): y, dt = [], [] # outputs for m in self.model: if m.f != -1: # if not from previous layer @@ -153,6 +154,9 @@ def forward_once(self, x, profile=False): x = m(x) # run y.append(x if m.i in self.save else None) # save output + + if feature_vis and m.type == 'models.common.SPP': + feature_visualization(x, m.type, m.i) if profile: logger.info('%.1fms total' % sum(dt)) diff --git a/utils/plots.py b/utils/plots.py index 66a30918190e..36386371dbec 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -15,8 +15,9 @@ import torch import yaml from PIL import Image, ImageDraw, ImageFont +from torchvision import transforms -from utils.general import xywh2xyxy, xyxy2xywh +from utils.general import increment_path, xywh2xyxy, xyxy2xywh from utils.metrics import fitness # Settings @@ -299,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): matplotlib.use('svg') # faster ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) - # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 + # [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195 ax[0].set_ylabel('instances') if 0 < len(names) < 30: ax[0].set_xticks(range(len(names))) @@ -445,3 +446,28 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''): ax[1].legend() fig.savefig(Path(save_dir) / 'results.png', dpi=200) + + +def feature_visualization(features, module_type, module_idx, n=64): + """ + features: Features to be visualized + module_type: Module type + module_idx: Module layer index within model + n: Maximum number of feature maps to plot + """ + project, name = 'runs/features', 'exp' + save_dir = increment_path(Path(project) / name) # increment run + save_dir.mkdir(parents=True, exist_ok=True) # make dir + + plt.figure(tight_layout=True) + blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension + n = min(n, len(blocks)) + for i in range(n): + feature = transforms.ToPILImage()(blocks[i].squeeze()) + ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1) + ax.axis('off') + plt.imshow(feature) # cmap='gray' + + f = f"layer_{module_idx}_{module_type.split('.')[-1]}_features.png" + print(f'Saving {save_dir / f}...') + plt.savefig(save_dir / f, dpi=300)