From a364c097283cfcb76fea4cbd73089b7064962fb2 Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Fri, 11 Jun 2021 11:33:23 +0200 Subject: [PATCH] Add ConfusionMatrix `normalize=True` flag --- utils/metrics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/utils/metrics.py b/utils/metrics.py index 6b61d6d6ef02..09b994414ffc 100644 --- a/utils/metrics.py +++ b/utils/metrics.py @@ -158,11 +158,12 @@ def process_batch(self, detections, labels): def matrix(self): return self.matrix - def plot(self, save_dir='', names=()): + def plot(self, normalize=True, save_dir='', names=()): try: import seaborn as sn - - array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize + + if normalize: + array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize columns array[array < 0.005] = np.nan # don't annotate (would appear as 0.00) fig = plt.figure(figsize=(12, 9), tight_layout=True)