From e8c52374035fd2fb5a0b0029eaa5e5705186df17 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 11 Jun 2021 11:46:05 +0200 Subject: [PATCH] ConfusionMatrix `normalize=True` fix (#3587) --- utils/metrics.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/utils/metrics.py b/utils/metrics.py index 09b994414ffc..8512197956e7 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -161,9 +161,8 @@ def matrix(self): def plot(self, normalize=True, save_dir='', names=()): try: import seaborn as sn - - if normalize: - array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns + + array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-6) if normalize else 1) # normalize columns array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) fig = plt.figure(figsize=(12, 9), tight_layout=True) @@ -178,7 +177,7 @@ def plot(self, normalize=True, save_dir='', names=()): fig.axes[0].set_ylabel('Predicted') fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250) except Exception as e: - pass + print(f'WARNING: ConfusionMatrix plot failure: {e}') def print(self): for i in range(self.nc + 1):