diff --git a/train.py b/train.py index 49bd5d6b27bf..8234025212b8 100644 --- a/train.py +++ b/train.py @@ -205,7 +205,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency # model._initialize_biases(cf.to(device)) if plots: - Thread(target=plot_labels, args=(labels, save_dir, loggers), daemon=True).start() + plot_labels(labels, save_dir, loggers) if tb_writer: tb_writer.add_histogram('classes', c, 0) diff --git a/utils/plots.py b/utils/plots.py index e3c981b4fe0c..3a4dccdc34c5 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -11,6 +11,8 @@ import matplotlib import matplotlib.pyplot as plt import numpy as np +import pandas as pd +import seaborn as sns import torch import yaml from PIL import Image, ImageDraw @@ -253,34 +255,24 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx def plot_labels(labels, save_dir=Path(''), loggers=None): # plot dataset labels + print('Plotting labels... ') c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes nc = int(c.max() + 1) # number of classes colors = color_list() + x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) # seaborn correlogram - try: - import seaborn as sns - import pandas as pd - x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) - sns.pairplot(x, corner=True, diag_kind='hist', kind='scatter', markers='o', - plot_kws=dict(s=3, edgecolor=None, linewidth=1, alpha=0.02), - diag_kws=dict(bins=50)) - plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) - plt.close() - except Exception as e: - pass + sns.pairplot(x, corner=True, diag_kind='auto', kind='hist', diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9)) + plt.savefig(save_dir / 'labels_correlogram.jpg', dpi=200) + plt.close() # matplotlib labels matplotlib.use('svg') # faster ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel() ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8) ax[0].set_xlabel('classes') - ax[2].scatter(b[0], b[1], c=hist2d(b[0], b[1], 90), cmap='jet') - ax[2].set_xlabel('x') - ax[2].set_ylabel('y') - ax[3].scatter(b[2], b[3], c=hist2d(b[2], b[3], 90), cmap='jet') - ax[3].set_xlabel('width') - ax[3].set_ylabel('height') + sns.histplot(x, x='x', y='y', ax=ax[2], bins=50, pmax=0.9) + sns.histplot(x, x='width', y='height', ax=ax[3], bins=50, pmax=0.9) # rectangles labels[:, 1:3] = 0.5 # center