Skip to content

Commit

Permalink
Fix normalization of confusion matrices for empty rows
Browse files Browse the repository at this point in the history
  • Loading branch information
jpblackburn committed Jul 27, 2020
1 parent 3f2c102 commit 8b8b635
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ def confusion_matrix(
cm = bins.reshape(num_classes, num_classes).squeeze().float()

if normalize:
cm = cm / cm.sum(-1)
row_sum = cm.sum(-1, keepdim=True)
divisor = torch.max(row_sum, torch.tensor(1.0, device=cm.device))
cm = cm / divisor

return cm

Expand Down

0 comments on commit 8b8b635

Please sign in to comment.