diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index e196c107f5d74..3fd7c68385fae 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -9,7 +9,8 @@ def to_onehot(tensor: torch.Tensor,
               n_classes: Optional[int] = None) -> torch.Tensor:
-    """ Converts a dense label tensor to one-hot format
+    """
+    Converts a dense label tensor to one-hot format

     Args:
         tensor: dense label tensor, with shape [N, d1, d2, ...]
@@ -29,14 +30,31 @@ def to_categorical(tensor: torch.Tensor,
                    argmax_dim: int = 1) -> torch.Tensor:
-    """ Converts a tensor of probabilities to a dense label tensor """
+    """
+    Converts a tensor of probabilities to a dense label tensor
+
+    Args:
+        tensor: probabilities to get the categorical label [N, d1, d2, ...]
+        argmax_dim: dimension to apply (default: 1)
+
+    Return:
+        A tensor with categorical labels [N, d2, ...]
+    """
     return torch.argmax(tensor, dim=argmax_dim)


 def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
                     num_classes: Optional[int]) -> int:
-    """ Returns the number of classes for a given prediction and
-    target tensor
+    """
+    Returns the number of classes for a given prediction and target tensor.
+
+    Args:
+        pred: predicted values
+        target: true labels
+        num_classes: number of classes if known (default: None)
+
+    Return:
+        An integer that represents the number of classes.
     """
     if num_classes is None:
         if pred.ndim > target.ndim:
@@ -50,8 +68,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
                 class_index: int, argmax_dim: int = 1
                 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """ Calculates the number of true postive, false postive, true negative
-    and false negative for a specfic class
+    """
+    Calculates the number of true positive, false positive, true negative
+    and false negative for a specific class

     Args:
         pred: prediction tensor
@@ -63,6 +82,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
         argmax_dim: if pred is a tensor of probabilities, this indicates the
             axis the argmax transformation will be applied over
+    Return:
+        Tensors in the following order: True Positive, False Positive, True Negative, False Negative
+
     """
     if pred.ndim == target.ndim + 1:
         pred = to_categorical(pred, argmax_dim=argmax_dim)
@@ -80,20 +102,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
                                  argmax_dim: int = 1
                                  ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """ Calls the stat_scores function iteratively for all classes, thus
-    calculating the number of true postive, false postive, true negative
-    and false negative for each class
+    """
+    Calls the stat_scores function iteratively for all classes, thus
+    calculating the number of true positive, false positive, true negative
+    and false negative for each class

     Args:
         pred: prediction tensor
-        target: target tensor
-        class_index: class to calculate over
-        argmax_dim: if pred is a tensor of probabilities, this indicates the
-            axis the argmax transformation will be applied over
+    Return:
+        Returns tensors for: tp, fp, tn, fn
+
     """
     num_classes = get_num_classes(pred=pred, target=target,
                                   num_classes=num_classes)
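[review note, not part of the patch] A quick sketch of how the helpers documented above compose, assuming only the signatures shown in this diff; the toy tensors and hand-counted expectations are mine:

import torch
from pytorch_lightning.metrics.functional.classification import (
    stat_scores, to_categorical, to_onehot)

pred = torch.tensor([0, 1, 1, 2])    # dense predicted labels
target = torch.tensor([0, 1, 2, 2])  # dense ground-truth labels

# For class 1: one true positive (index 1), one false positive (index 2),
# two true negatives (indices 0 and 3) and no false negatives.
tp, fp, tn, fn = stat_scores(pred, target, class_index=1)

# to_onehot followed by to_categorical round-trips dense labels.
assert to_categorical(to_onehot(pred, n_classes=3)).equal(pred)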
@@ -116,6 +139,23 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,

 def accuracy(pred: torch.Tensor, target: torch.Tensor,
              num_classes: Optional[int] = None,
              reduction='elementwise_mean') -> torch.Tensor:
+    """
+    Computes the accuracy classification score
+
+    Args:
+        pred: predicted labels
+        target: ground truth labels
+        num_classes: number of classes
+        reduction: a method for reducing accuracies over labels (default: takes the mean)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
+
+    Return:
+        A Tensor with the classification score.
+    """
     tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
                                                       num_classes=num_classes)
@@ -129,6 +169,18 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,

 def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
                      normalize: bool = False) -> torch.Tensor:
+    """
+    Computes the confusion matrix C where each entry C_{i,j} is the number of observations
+    in group i that were predicted in group j.
+
+    Args:
+        pred: estimated targets
+        target: ground truth labels
+        normalize: normalizes confusion matrix
+
+    Return:
+        Tensor, confusion matrix C [num_classes, num_classes]
+    """
     num_classes = get_num_classes(pred, target, None)

     d = target.size(-1)
@@ -149,6 +201,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
                      num_classes: Optional[int] = None,
                      reduction: str = 'elementwise_mean'
                      ) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Computes precision and recall for different thresholds
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        num_classes: number of classes
+        reduction: method for reducing precision-recall values (default: takes the mean)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
+
+    Return:
+        Tensor with precision and recall
+    """
     tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
                                                       num_classes=num_classes)
@@ -168,6 +237,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,

 def precision(pred: torch.Tensor, target: torch.Tensor,
               num_classes: Optional[int] = None,
               reduction: str = 'elementwise_mean') -> torch.Tensor:
+    """
+    Computes precision score.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        num_classes: number of classes
+        reduction: method for reducing precision values (default: takes the mean)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
+
+    Return:
+        Tensor with precision.
+    """
     return precision_recall(pred=pred, target=target,
                             num_classes=num_classes, reduction=reduction)[0]
@@ -175,6 +261,23 @@ def precision(pred: torch.Tensor, target: torch.Tensor,

 def recall(pred: torch.Tensor, target: torch.Tensor,
            num_classes: Optional[int] = None,
            reduction: str = 'elementwise_mean') -> torch.Tensor:
+    """
+    Computes recall score.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        num_classes: number of classes
+        reduction: method for reducing recall values (default: takes the mean)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements
+
+    Return:
+        Tensor with recall.
+    """
     return precision_recall(pred=pred, target=target,
                             num_classes=num_classes, reduction=reduction)[1]
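[review note, not part of the patch] A sketch of the reduction behaviour documented above, again assuming only this diff's signatures; the example values are mine:

import torch
from pytorch_lightning.metrics.functional.classification import (
    accuracy, precision_recall)

pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])

acc = accuracy(pred, target, num_classes=3)  # 0-dim tensor, mean over classes
# reduction='none' keeps one precision/recall value per class, shape [3]
prec, rec = precision_recall(pred, target, reduction='none')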
@@ -182,6 +285,29 @@ def recall(pred: torch.Tensor, target: torch.Tensor,

 def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
                 num_classes: Optional[int] = None,
                 reduction: str = 'elementwise_mean') -> torch.Tensor:
+    """
+    Computes the F-beta score which is a weighted harmonic mean of precision and recall.
+    It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        beta: weights recall when combining the score.
+            beta < 1: more weight to precision.
+            beta > 1: more weight to recall
+            beta = 0: only precision
+            beta -> inf: only recall
+        num_classes: number of classes
+        reduction: method for reducing F-score (default: takes the mean)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements.
+
+    Return:
+        Tensor with the value of F-score. It is a value between 0-1.
+    """
     prec, rec = precision_recall(pred=pred, target=target,
                                  num_classes=num_classes, reduction='none')
@@ -196,6 +322,23 @@ def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,

 def f1_score(pred: torch.Tensor, target: torch.Tensor,
              num_classes: Optional[int] = None,
              reduction='elementwise_mean') -> torch.Tensor:
+    """
+    Computes the F1-score (a.k.a. F-measure).
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        num_classes: number of classes
+        reduction: method for reducing F1-score (default: takes the mean)
+            Available reduction methods:
+
+            - elementwise_mean: takes the mean
+            - none: pass array
+            - sum: add elements.
+
+    Return:
+        Tensor containing F1-score
+    """
     return fbeta_score(pred=pred, target=target, beta=1.,
                        num_classes=num_classes, reduction=reduction)
@@ -251,6 +394,18 @@ def roc(pred: torch.Tensor, target: torch.Tensor,
         pos_label: int = 1.) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        sample_weight: sample weights
+        pos_label: the label for the positive class (default: 1)
+
+    Return:
+        [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
+    """
     fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                              sample_weight=sample_weight,
                                              pos_label=pos_label)
@@ -282,6 +437,19 @@ def multiclass_roc(pred: torch.Tensor, target: torch.Tensor,
                    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """
+    Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        sample_weight: sample weights
+        num_classes: number of classes (default: None, computes automatically from data)
+
+    Return:
+        [num_classes, Tensor, Tensor, Tensor]: returns the ROC for each class.
+        number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
+    """
     num_classes = get_num_classes(pred, target, num_classes)

     class_roc_vals = []
@@ -301,6 +469,18 @@ def precision_recall_curve(pred: torch.Tensor,
                            pos_label: int = 1.) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Computes precision-recall pairs for different thresholds.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        sample_weight: sample weights
+        pos_label: the label for the positive class (default: 1.)
+
+    Return:
+        [Tensor, Tensor, Tensor]: precision, recall, thresholds
+    """
     fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
                                              sample_weight=sample_weight,
                                              pos_label=pos_label)
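[review note, not part of the patch] A sketch tying fbeta_score to f1_score (which, per the hunk above, delegates with beta=1.) and tracing a binary curve; the toy scores and the 0.5 cut-off are mine:

import torch
from pytorch_lightning.metrics.functional.classification import (
    f1_score, fbeta_score, precision_recall_curve)

scores = torch.tensor([0.1, 0.4, 0.35, 0.8])  # probabilities for the positive class
target = torch.tensor([0, 0, 1, 1])

precision, recall, thresholds = precision_recall_curve(scores, target, pos_label=1)

# Hard labels via an arbitrary 0.5 cut-off; beta=1. must reproduce f1_score.
labels = (scores > 0.5).long()
assert torch.allclose(fbeta_score(labels, target, beta=1.), f1_score(labels, target))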
@@ -334,6 +514,18 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
                                       ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """
+    Computes precision-recall pairs for different thresholds given multiclass scores.
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        sample_weight: sample weight
+        num_classes: number of classes
+
+    Return:
+        [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds
+    """
     num_classes = get_num_classes(pred, target, num_classes)

     class_pr_vals = []
@@ -350,6 +542,17 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,

 def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
+    """
+    Computes Area Under the Curve (AUC) using the trapezoidal rule
+
+    Args:
+        x: x-coordinates
+        y: y-coordinates
+        reorder: reorder coordinates so they are increasing
+
+    Return:
+        AUC score (float)
+    """
     direction = 1.

     if reorder:
@@ -400,6 +603,15 @@ def new_func(*args, **kwargs) -> torch.Tensor:

 def auroc(pred: torch.Tensor, target: torch.Tensor,
           sample_weight: Optional[Sequence] = None,
           pos_label: int = 1.) -> torch.Tensor:
+    """
+    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores
+
+    Args:
+        pred: estimated probabilities
+        target: ground-truth labels
+        sample_weight: sample weights
+        pos_label: the label for the positive class (default: 1.)
+    """
     return roc(pred=pred, target=target,
                sample_weight=sample_weight, pos_label=pos_label)
@@ -410,7 +622,6 @@ def average_precision(pred: torch.Tensor, target: torch.Tensor,
     precision, recall, _ = precision_recall_curve(pred=pred, target=target,
                                                   sample_weight=sample_weight,
                                                   pos_label=pos_label)
-
     # Return the step function integral
     # The following works because the last entry of precision is
     # guaranteed to be 1, as returned by precision_recall_curve
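[review note, not part of the patch] auc composes with roc as sketched below, assuming only this diff's signatures; the toy data are mine:

import torch
from pytorch_lightning.metrics.functional.classification import auc, roc

scores = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])

fpr, tpr, _ = roc(scores, target, pos_label=1)
area = auc(fpr, tpr)  # trapezoidal rule; reorder=True sorts the x-coordinates first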