Skip to content

Commit

Permalink
Allow user of ConfusionMatrix to specify number of classes
Browse files Browse the repository at this point in the history
  • Loading branch information
jpblackburn committed Jul 27, 2020
1 parent 8b8b635 commit 83927a9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
normalize: bool = False,
reduce_group: Any = None,
reduce_op: Any = None,
num_classes: Optional[int] = None,
):
"""
Args:
Expand All @@ -107,6 +108,7 @@ def __init__(
reduce_group=reduce_group,
reduce_op=reduce_op)
self.normalize = normalize
self.num_classes = num_classes

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -120,7 +122,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
A Tensor with the confusion matrix.
"""
return confusion_matrix(pred=pred, target=target,
normalize=self.normalize)
normalize=self.normalize,
num_classes=self.num_classes)


class PrecisionRecall(TensorCollectionMetric):
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def confusion_matrix(
pred: torch.Tensor,
target: torch.Tensor,
normalize: bool = False,
num_classes: Optional[int] = None
) -> torch.Tensor:
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
Expand All @@ -258,7 +259,7 @@ def confusion_matrix(
[0., 0., 1., 0.],
[0., 0., 0., 1.]])
"""
num_classes = get_num_classes(pred, target, None)
num_classes = get_num_classes(pred, target, num_classes)

unique_labels = target.view(-1) * num_classes + pred.view(-1)

Expand Down

0 comments on commit 83927a9

Please sign in to comment.