Skip to content

Commit

Permalink
Increase plot_labels() speed (#1736)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed Dec 19, 2020
1 parent 49abc72 commit 685d601
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 18 deletions.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
# model._initialize_biases(cf.to(device))
if plots:
Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start()
plot_labels(labels, save_dir, loggers)
if tb_writer:
tb_writer.add_histogram('classes', c, 0)

Expand Down
26 changes: 9 additions & 17 deletions utils/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from PIL import Image, ImageDraw
Expand Down Expand Up @@ -253,34 +255,24 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx

def plot_labels(labels, save_dir=Path(''), loggers=None):
# plot dataset labels
print('Plotting labels... ')
c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
nc = int(c.max() + 1) # number of classes
colors = color_list()
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])

# seaborn correlogram
try:
import seaborn as sns
import pandas as pd
x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height'])
sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o',
plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02),
diag_kws=dict(bins=50))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()
except Exception as e:
pass
sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200)
plt.close()

# matplotlib labels
matplotlib.use('svg') # faster
ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
ax[0].set_xlabel('classes')
ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet')
ax[2].set_xlabel('x')
ax[2].set_ylabel('y')
ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet')
ax[3].set_xlabel('width')
ax[3].set_ylabel('height')
sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9)
sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9)

# rectangles
labels[:, 1:3] = 0.5 # center
Expand Down

0 comments on commit 685d601

Please sign in to comment.