Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Confusion matrix #1474

Merged
merged 79 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
7b4fa5d
initial commit
glenn-jocher Nov 22, 2020
a156d41
add plotting
glenn-jocher Nov 22, 2020
8f4a71d
matrix to cpu
glenn-jocher Nov 22, 2020
08c7303
bug fix
glenn-jocher Nov 22, 2020
9ddafbd
update plot
glenn-jocher Nov 22, 2020
5e04f4a
update plot
glenn-jocher Nov 22, 2020
20cb0f9
update plot
glenn-jocher Nov 22, 2020
85febbe
update plot
glenn-jocher Nov 22, 2020
1da707b
update plot
glenn-jocher Nov 22, 2020
cdfbcd8
update plot
glenn-jocher Nov 22, 2020
abc018f
update plot
glenn-jocher Nov 22, 2020
88c5d72
update plot
glenn-jocher Nov 22, 2020
8a4c2ac
update plot
glenn-jocher Nov 22, 2020
5ed4547
update plot
glenn-jocher Nov 22, 2020
66be6d4
update plot
glenn-jocher Nov 22, 2020
4128c53
update plot
glenn-jocher Nov 22, 2020
3bca0ea
cleanup
glenn-jocher Nov 22, 2020
6c4e74f
cleanup
glenn-jocher Nov 22, 2020
73b7ab3
cleanup
glenn-jocher Nov 22, 2020
0f8d209
cleanup
glenn-jocher Nov 22, 2020
08fc9f2
cleanup
glenn-jocher Nov 22, 2020
53aa594
cleanup
glenn-jocher Nov 22, 2020
e0b6cab
cleanup
glenn-jocher Nov 22, 2020
28f0c8f
cleanup
glenn-jocher Nov 22, 2020
9ffc8f2
cleanup
glenn-jocher Nov 22, 2020
64ac278
cleanup
glenn-jocher Nov 22, 2020
e15bd16
cleanup
glenn-jocher Nov 22, 2020
a9de0dd
cleanup
glenn-jocher Nov 22, 2020
34b1479
cleanup
glenn-jocher Nov 22, 2020
e05efc6
seaborn pandas to requirements.txt
glenn-jocher Nov 22, 2020
b552b43
seaborn pandas to requirements.txt
glenn-jocher Nov 22, 2020
b602ad2
update wandb plotting
glenn-jocher Nov 22, 2020
0837504
remove pandas
glenn-jocher Nov 22, 2020
05d456d
if plots
glenn-jocher Nov 22, 2020
a846da1
if plots
glenn-jocher Nov 22, 2020
36c6cdb
if plots
glenn-jocher Nov 22, 2020
4faf5dc
if plots
glenn-jocher Nov 22, 2020
cdc2b18
if plots
glenn-jocher Nov 22, 2020
6ac4bff
initial commit
glenn-jocher Nov 22, 2020
3f1538c
add plotting
glenn-jocher Nov 22, 2020
556c7f4
matrix to cpu
glenn-jocher Nov 22, 2020
35f182a
bug fix
glenn-jocher Nov 22, 2020
d3f997f
update plot
glenn-jocher Nov 22, 2020
50ceabb
update plot
glenn-jocher Nov 22, 2020
24fde43
update plot
glenn-jocher Nov 22, 2020
130cfc9
update plot
glenn-jocher Nov 22, 2020
8004b45
update plot
glenn-jocher Nov 22, 2020
0fd0b43
update plot
glenn-jocher Nov 22, 2020
860cce7
update plot
glenn-jocher Nov 22, 2020
9a9845a
update plot
glenn-jocher Nov 22, 2020
a71a9a4
update plot
glenn-jocher Nov 22, 2020
3a9ffe7
update plot
glenn-jocher Nov 22, 2020
370537a
update plot
glenn-jocher Nov 22, 2020
9edbe6b
update plot
glenn-jocher Nov 22, 2020
526dfd3
cleanup
glenn-jocher Nov 22, 2020
76b74f3
cleanup
glenn-jocher Nov 22, 2020
331c002
cleanup
glenn-jocher Nov 22, 2020
69c6b59
cleanup
glenn-jocher Nov 22, 2020
8398cd5
cleanup
glenn-jocher Nov 22, 2020
df01a8e
cleanup
glenn-jocher Nov 22, 2020
b1f7e70
cleanup
glenn-jocher Nov 22, 2020
7fa0ac0
cleanup
glenn-jocher Nov 22, 2020
9ae0451
cleanup
glenn-jocher Nov 22, 2020
40b25bc
cleanup
glenn-jocher Nov 22, 2020
c146030
cleanup
glenn-jocher Nov 22, 2020
d73dcd3
cleanup
glenn-jocher Nov 22, 2020
9e2c3f9
cleanup
glenn-jocher Nov 22, 2020
cdc3ec7
seaborn pandas to requirements.txt
glenn-jocher Nov 22, 2020
8c4901b
seaborn pandas to requirements.txt
glenn-jocher Nov 22, 2020
dbff8b5
update wandb plotting
glenn-jocher Nov 22, 2020
8647b21
remove pandas
glenn-jocher Nov 22, 2020
658b2ef
if plots
glenn-jocher Nov 22, 2020
fb8dfaa
if plots
glenn-jocher Nov 22, 2020
3c3c780
if plots
glenn-jocher Nov 22, 2020
9ea0612
if plots
glenn-jocher Nov 22, 2020
7ec84d4
if plots
glenn-jocher Nov 22, 2020
f07630c
Merge remote-tracking branch 'origin/confusion_matrix' into confusion…
glenn-jocher Nov 23, 2020
3dab081
Cat apriori to autolabels
glenn-jocher Nov 23, 2020
8bbe0b5
cleanup
glenn-jocher Nov 23, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ tqdm>=4.41.0
# logging -------------------------------------
# wandb

# coco ----------------------------------------
# pycocotools>=2.0
# plotting ------------------------------------
seaborn
pandas

# export --------------------------------------
# coremltools==4.0
Expand All @@ -26,4 +27,4 @@ tqdm>=4.41.0

# extras --------------------------------------
# thop # FLOPS computation
# seaborn # plotting
# pycocotools>=2.0 # COCO mAP
15 changes: 10 additions & 5 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from utils.general import coco80_to_coco91_class, check_dataset, check_file, check_img_size, box_iou, \
non_max_suppression, scale_coords, xyxy2xywh, xywh2xyxy, set_logging, increment_path
from utils.loss import compute_loss
from utils.metrics import ap_per_class
from utils.metrics import ap_per_class, ConfusionMatrix
from utils.plots import plot_images, output_to_target
from utils.torch_utils import select_device, time_synchronized

Expand Down Expand Up @@ -89,6 +89,7 @@ def test(data,
dataloader = create_dataloader(path, imgsz, batch_size, model.stride.max(), opt, pad=0.5, rect=True)[0]

seen = 0
confusion_matrix = ConfusionMatrix(nc=nc)
names = {k: v for k, v in enumerate(model.names if hasattr(model, 'names') else model.module.names)}
coco91class = coco80_to_coco91_class()
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
Expand Down Expand Up @@ -176,6 +177,8 @@ def test(data,
# target boxes
tbox = xywh2xyxy(labels[:, 1:5])
scale_coords(img[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels
if plots:
confusion_matrix.process_batch(pred, torch.cat((labels[:, 0:1], tbox), 1))

# Per target class
for cls in torch.unique(tcls_tensor):
Expand Down Expand Up @@ -218,10 +221,12 @@ def test(data,
else:
nt = torch.zeros(1)

# W&B logging
if plots and wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(x), caption=x.name) for x in sorted(save_dir.glob('test*.jpg'))]})
# Plots
if plots:
confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
if wandb and wandb.run:
wandb.log({"Images": wandb_images})
wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in sorted(save_dir.glob('test*.jpg'))]})

# Print results
pf = '%20s' + '%12.3g' * 6 # print format
Expand Down
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None):
if plots:
plot_results(save_dir=save_dir) # save as results.png
if wandb:
wandb.log({"Results": [wandb.Image(str(save_dir / x), caption=x) for x in
['results.png', 'precision_recall_curve.png']]})
files = ['results.png', 'precision_recall_curve.png', 'confusion_matrix.png']
wandb.log({"Results": [wandb.Image(str(save_dir / f), caption=f) for f in files
if (save_dir / f).exists()]})
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600))
else:
dist.destroy_process_group()
Expand Down
81 changes: 81 additions & 0 deletions utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import matplotlib.pyplot as plt
import numpy as np
import torch

from . import general


def fitness(x):
Expand Down Expand Up @@ -102,6 +105,84 @@ def compute_ap(recall, precision):
return ap, mpre, mrec


class ConfusionMatrix:
# Updated version of https://github.com/kaanakan/object_detection_confusion_matrix
def __init__(self, nc, conf=0.25, iou_thres=0.45):
self.matrix = np.zeros((nc + 1, nc + 1))
self.nc = nc # number of classes
self.conf = conf
self.iou_thres = iou_thres

def process_batch(self, detections, labels):
"""
Return intersection-over-union (Jaccard index) of boxes.
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
Arguments:
detections (Array[N, 6]), x1, y1, x2, y2, conf, class
labels (Array[M, 5]), class, x1, y1, x2, y2
Returns:
None, updates confusion matrix accordingly
"""
detections = detections[detections[:, 4] > self.conf]
gt_classes = labels[:, 0].int()
detection_classes = detections[:, 5].int()
iou = general.box_iou(labels[:, 1:], detections[:, :4])

x = torch.where(iou > self.iou_thres)
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
else:
matches = np.zeros((0, 3))

n = matches.shape[0] > 0
m0, m1, _ = matches.transpose().astype(np.int16)
for i, gc in enumerate(gt_classes):
j = m0 == i
if n and sum(j) == 1:
self.matrix[gc, detection_classes[m1[j]]] += 1 # correct
else:
self.matrix[gc, self.nc] += 1 # background FP

if n:
for i, dc in enumerate(detection_classes):
if not any(m1 == i):
self.matrix[self.nc, dc] += 1 # background FN

def matrix(self):
return self.matrix

def plot(self, save_dir='', names=()):
try:
import seaborn as sn

array = self.matrix / (self.matrix.sum(0).reshape(1, self.nc + 1) + 1E-6) # normalize
array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

fig = plt.figure(figsize=(12, 9))
sn.set(font_scale=1.0 if self.nc < 50 else 0.8) # for label size
labels = (0 < len(names) < 99) and len(names) == self.nc # apply names to ticklabels
sn.heatmap(array, annot=self.nc < 30, annot_kws={"size": 8}, cmap='Blues', fmt='.2f', square=True,
xticklabels=names + ['background FN'] if labels else "auto",
yticklabels=names + ['background FP'] if labels else "auto").set_facecolor((1, 1, 1))
fig.axes[0].set_xlabel('True')
fig.axes[0].set_ylabel('Predicted')
fig.tight_layout()
fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
except Exception as e:
pass

def print(self):
for i in range(self.nc + 1):
print(' '.join(map(str, self.matrix[i])))


# Plots ----------------------------------------------------------------------------------------------------------------

def plot_pr_curve(px, py, ap, save_dir='.', names=()):
fig, ax = plt.subplots(1, 1, figsize=(9, 6))
py = np.stack(py, axis=1)
Expand Down