Skip to content

Commit

Permalink
Ensure confusion matrix DDP reduction is before normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
jpblackburn committed Jul 27, 2020
1 parent 913c1e3 commit e3e0743
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
dice_score,
iou,
)
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric
from pytorch_lightning.metrics.metric import Metric, TensorMetric, TensorCollectionMetric


class Accuracy(TensorMetric):
Expand Down Expand Up @@ -74,11 +74,15 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
num_classes=self.num_classes, reduction=self.reduction)


class ConfusionMatrix(TensorMetric):
class ConfusionMatrix(Metric):
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.
This metric uses an internal class to perform the confusion matrix computation and
any DDP reduction. The normalization is performed on the full confusion matrix
after the reduction.
Example:
>>> pred = torch.tensor([0, 1, 2, 2])
Expand All @@ -90,7 +94,6 @@ class ConfusionMatrix(TensorMetric):
[0., 0., 2.]])
"""

def __init__(
self,
normalize: bool = False,
Expand All @@ -105,11 +108,13 @@ def __init__(
reduce_op: the operation to perform for ddp reduction
num_classes: number of classes if known. Important for DDP reduction.
"""
super().__init__(name='confusion_matrix',
reduce_group=reduce_group,
reduce_op=reduce_op)
super().__init__(name='confusion_matrix')
self._metric = self._ConfusionMatrixInternal(
reduce_group=reduce_group,
reduce_op=reduce_op,
num_classes=num_classes
)
self.normalize = normalize
self.num_classes = num_classes

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -122,9 +127,35 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Return:
A Tensor with the confusion matrix.
"""
return confusion_matrix(pred=pred, target=target,
normalize=self.normalize,
num_classes=self.num_classes)
cm = self._metric(pred=pred, target=target)

if self.normalize:
row_sum = cm.sum(-1, keepdim=True)
divisor = torch.max(row_sum, torch.tensor(1.0, device=cm.device))
cm = cm / divisor

return cm

class _ConfusionMatrixInternal(TensorMetric):
"""
Internal confusion matrix class to perform the computation and any DDP reduction
prior to the confusion matrix normalization.
"""
def __init__(
self,
reduce_group: Any = None,
reduce_op: Any = None,
num_classes: Optional[int] = None
):
super().__init__(name='confusion_matrix_internal',
reduce_group=reduce_group,
reduce_op=reduce_op)
self.num_classes = num_classes

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
return confusion_matrix(pred=pred, target=target,
normalize=False,
num_classes=self.num_classes)


class PrecisionRecall(TensorCollectionMetric):
Expand Down

0 comments on commit e3e0743

Please sign in to comment.