diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 9f1d15e7fe185f..443725c1f27405 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -144,14 +144,19 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,
 
     Args:
 
         pred: predicted labels
+        target: ground truth labels
+        num_classes: number of classes
+        reduction: method for reducing accuracies over labels (default: takes the mean)
 
         Available reduction methods:
 
         - elementwise_mean: takes the mean
+        - none: pass array
+        - sum: add elements
 
     Output:
 
@@ -170,6 +175,20 @@
 def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
                      normalize: bool = False) -> torch.Tensor:
+    '''
+    Computes the confusion matrix C where each entry C_{i,j} is the number of observations
+    in group i that were predicted in group j.
+
+    Args:
+        pred: estimated targets
+
+        target: ground truth labels
+
+        normalize: normalizes confusion matrix
+
+    Output:
+        Tensor, confusion matrix C [num_classes, num_classes]
+    '''
     num_classes = get_num_classes(pred, target, None)
 
     d = target.size(-1)
 
@@ -190,6 +209,30 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
                      num_classes: Optional[int] = None,
                      reduction: str = 'elementwise_mean'
                      ) -> Tuple[torch.Tensor, torch.Tensor]:
+    '''
+    Computes precision and recall for different thresholds
+
+    Args:
+
+        pred: estimated probabilities
+
+        target: ground-truth labels
+
+        num_classes: number of classes
+
+        reduction: method for reducing precision-recall values (default: takes the mean)
+
+        Available reduction methods:
+
+        - elementwise_mean: takes the mean
+
+        - none: pass array
+
+        - sum: add elements
+
+    Output:
+        Tensor with precision and recall
+    '''
     tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
                                                       num_classes=num_classes)
 
@@ -209,6 +252,30 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
 def precision(pred: torch.Tensor, target: torch.Tensor,
               num_classes: Optional[int] = None,
               reduction: str = 'elementwise_mean') -> torch.Tensor:
+    '''
+    Computes precision score.
+
+    Args:
+
+        pred: estimated probabilities
+
+        target: ground-truth labels
+
+        num_classes: number of classes
+
+        reduction: method for reducing precision values (default: takes the mean)
+
+        Available reduction methods:
+
+        - elementwise_mean: takes the mean
+
+        - none: pass array
+
+        - sum: add elements
+
+    Output:
+        Tensor with precision.
+    '''
     return precision_recall(pred=pred, target=target,
                             num_classes=num_classes, reduction=reduction)[0]
 
@@ -216,6 +283,30 @@ def precision(pred: torch.Tensor, target: torch.Tensor,
 def recall(pred: torch.Tensor, target: torch.Tensor,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean') -> torch.Tensor:
+    '''
+    Computes recall score.
+
+    Args:
+
+        pred: estimated probabilities
+
+        target: ground-truth labels
+
+        num_classes: number of classes
+
+        reduction: method for reducing recall values (default: takes the mean)
+
+        Available reduction methods:
+
+        - elementwise_mean: takes the mean
+
+        - none: pass array
+
+        - sum: add elements
+
+    Output:
+        Tensor with recall.
+    '''
     return precision_recall(pred=pred, target=target,
                             num_classes=num_classes, reduction=reduction)[1]