Skip to content

Commit

Permalink
General cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Jun 28, 2021
1 parent 8d22cad commit 7b62f38
Showing 1 changed file with 17 additions and 26 deletions.
43 changes: 17 additions & 26 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms

from utils.general import xywh2xyxy, xyxy2xywh
from utils.general import increment_path, xywh2xyxy, xyxy2xywh
from utils.metrics import fitness

# Settings
Expand Down Expand Up @@ -300,7 +300,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
# [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # update colors bug #3195
ax[0].set_ylabel('instances')
if 0 < len(names) < 30:
ax[0].set_xticks(range(len(names)))
Expand Down Expand Up @@ -446,39 +446,30 @@ def plot_results(start=0, stop=0, bucket='', id=(), labels=(), save_dir=''):

ax[1].legend()
fig.savefig(Path(save_dir) / 'results.png', dpi=200)


def feature_visualization(features, model_type, model_id, feature_num=64):

def feature_visualization(features, module_type, module_idx, n=64):
"""
features: The feature map which you need to visualization
model_type: The type of feature map
model_id: The id of feature map
feature_num: The amount of visualization you need
features: Features to be visualized
module_type: module type
module_idx: module layer index within model
n: Maximum number of feature maps to plot
"""
save_dir = "features/"
if not os.path.exists(save_dir):
os.makedirs(save_dir)

# print(features.shape)
# block by channel dimension
blocks = torch.chunk(features, features.shape[1], dim=1)

# # size of feature
# size = features.shape[2], features.shape[3]
project, name = 'runs/features', 'exp'
save_dir = increment_path(Path(project) / name) # increment run
save_dir.mkdir(parents=True, exist_ok=True) # make dir

plt.figure()
for i in range(feature_num):
torch.squeeze(blocks[i])
blocks = torch.chunk(features, features.shape[1], dim=1) # block by channel dimension
n = min(n, len(blocks))
for i in range(n):
feature = transforms.ToPILImage()(blocks[i].squeeze())
# print(feature)
ax = plt.subplot(int(math.sqrt(feature_num)), int(math.sqrt(feature_num)), i+1)
ax = plt.subplot(int(math.sqrt(n)), int(math.sqrt(n)), i + 1)
ax.set_xticks([])
ax.set_yticks([])

plt.imshow(feature)
# gray feature
# plt.imshow(feature, cmap='gray')

# plt.show()
plt.savefig(save_dir + '{}_{}_feature_map_{}.png'
.format(model_type.split('.')[2], model_id, feature_num), dpi=300)
plt.savefig(save_dir / f"{module_type.split('.')[2]}_{module_idx}_feature_map_{n}.png", dpi=300)
print(f'Features from {module_type} module in layer {module_idx} saved to {save_dir}')

0 comments on commit 7b62f38

Please sign in to comment.