From 8d9f53adaff8f934744ab5c1507f5d10978a1dfe Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 15 Jun 2020 19:22:34 -0400 Subject: [PATCH] fixes --- docs/source/conf.py | 2 +- pytorch_lightning/metrics/classification.py | 86 ++++---- .../metrics/functional/classification.py | 191 +++++++++++++++++- 3 files changed, 222 insertions(+), 57 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 621dd14b9c732..d81c0d00e7da5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -90,7 +90,7 @@ 'sphinx.ext.linkcode', 'sphinx.ext.autosummary', 'sphinx.ext.napoleon', - # 'sphinx.ext.imgmath', + 'sphinx.ext.imgmath', 'recommonmark', 'sphinx.ext.autosectionlabel', # 'm2r', diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index ad3f8779213c8..ab69a4fed598b 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -66,7 +66,7 @@ def __init__( >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = Accuracy() >>> metric(pred, target) - tensor([0.7500]) + tensor(0.7500) """ super().__init__(name='accuracy', @@ -111,14 +111,13 @@ def __init__( Example: - >>> pred = torch.tensor([0, 1, 2, 3]) + >>> pred = torch.tensor([0, 1, 2, 2]) >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = ConfusionMatrix() >>> metric(pred, target) - tensor([[1., 0., 0., 0.], - [0., 1., 0., 0.], - [0., 0., 1., 1.], - [0., 0., 0., 0.]]) + tensor([[1., 0., 0.], + [0., 1., 0.], + [0., 0., 2.]]) """ super().__init__(name='confusion_matrix', @@ -163,8 +162,11 @@ def __init__( >>> pred = torch.tensor([0, 1, 2, 3]) >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = PrecisionRecall() - >>> metric(pred, target) - (tensor([0.3333, 0.0000, 0.0000, 1.0000]), tensor([1., 0., 0., 0.]), tensor([1., 2., 3.])) + >>> pr, rc, th = metric(pred, target) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.3333, 0.0000, 0.0000, 1.0000]), + tensor([1., 0., 0., 0.]), + tensor([1., 2., 3.])) """ super().__init__(name='precision_recall_curve', @@ -226,7 +228,7 @@ def __init__( >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = PrecisionRecall() >>> metric(pred, target) - tensor(1.) + (tensor([0.3333, 0.0000, 0.0000, 1.0000]), tensor([1., 0., 0., 0.]), tensor([1., 2., 3.])) """ super().__init__(name='precision', @@ -548,6 +550,7 @@ def __init__( >>> target = torch.tensor([0, 1, 2, 2]) >>> metric = ROC() >>> fp, tp, thresholds = metric(pred, target) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE (tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), tensor([0., 0., 0., 1., 1.]), tensor([4., 3., 2., 1., 0.])) @@ -607,25 +610,18 @@ def __init__( Example: - .. testcode:: - - pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - [0.05, 0.85, 0.05, 0.05], - [0.05, 0.05, 0.85, 0.05], - [0.05, 0.05, 0.05, 0.85]]) - target = torch.tensor([0, 1, 3, 2]) - metric = MulticlassROC() - classes_roc = metric(pred, target) - - Out: - - .. testoutput:: - + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassROC() + >>> classes_roc = metric(pred, target) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) - """ super().__init__(name='multiclass_roc', reduce_group=reduce_group, @@ -678,20 +674,14 @@ def __init__( Example: - .. testcode:: - - pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - [0.05, 0.85, 0.05, 0.05], - [0.05, 0.05, 0.85, 0.05], - [0.05, 0.05, 0.05, 0.85]]) - target = torch.tensor([0, 1, 3, 2]) - metric = MulticlassPrecisionRecall() - classes_pr = metric(pred, target) - - Out: - - .. testoutput:: - + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = MulticlassPrecisionRecall() + >>> classes_pr = metric(pred, target) + >>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE ((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])), @@ -756,18 +746,14 @@ def __init__( .. testcode: - pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], - [0.05, 0.85, 0.05, 0.05], - [0.05, 0.05, 0.85, 0.05], - [0.05, 0.05, 0.05, 0.85]]) - target = torch.tensor([0, 1, 3, 2]) - metric = DiceCoefficient() - classes_pr = metric(pred, target) - - Out: - - .. testoutput: - + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> metric = DiceCoefficient() + >>> classes_pr = metric(pred, target) + >>> metric(pred, target) tensor(0.3333) """ super().__init__(name='dice', diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 24fc68013c101..9525f51df2203 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -22,7 +22,7 @@ def to_onehot( Output: A sparse label tensor with shape [N, C, d1, d2, ...] - Example:: + Example: >>> x = torch.tensor([1, 2, 3]) >>> to_onehot(x) @@ -50,6 +50,13 @@ def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: Return: A tensor with categorical labels [N, d2, ...] + + Example: + + >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + >>> to_categorical(x) + tensor([1, 0]) + """ return torch.argmax(tensor, dim=argmax_dim) @@ -74,7 +81,8 @@ def get_num_classes( if pred.ndim > target.ndim: num_classes = pred.size(1) else: - num_classes = int(target.max().detach().item() + 1) + num_target_classes = int(target.max().detach().item() + 1) + num_classes = num_target_classes return num_classes @@ -97,6 +105,18 @@ def stat_scores( Return: Tensors in the following order: True Positive, False Positive, True Negative, False Negative + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tp, fp, tn, fn, sup = stat_scores(x, y, class_index=1) + >>> stat_scores(x, y, class_index=1) # doctest: +NORMALIZE_WHITESPACE + (tensor(0), + tensor(1), + tensor(2), + tensor(0), + tensor(0)) + """ if pred.ndim == target.ndim + 1: pred = to_categorical(pred, argmax_dim=argmax_dim) @@ -131,6 +151,17 @@ def stat_scores_multiple_classes( Return: Returns tensors for: tp, fp, tn, fn + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> tps, fps, tns, fns, sups = stat_scores_multiple_classes(x, y) + >>> stat_scores_multiple_classes(x, y) # doctest: +NORMALIZE_WHITESPACE + (tensor([0., 0., 1., 1.]), + tensor([0., 1., 0., 0.]), + tensor([2., 2., 2., 2.]), + tensor([1., 0., 0., 0.]), + tensor([1., 0., 1., 1.])) """ num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) @@ -144,9 +175,7 @@ def stat_scores_multiple_classes( fns = torch.zeros((num_classes,), device=pred.device) sups = torch.zeros((num_classes,), device=pred.device) for c in range(num_classes): - tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, - target=target, - class_index=c) + tps[c], fps[c], tns[c], fns[c], sups[c] = stat_scores(pred=pred, target=target, class_index=c) return tps, fps, tns, fns, sups @@ -173,6 +202,13 @@ def accuracy( Return: A Tensor with the classification score. + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> accuracy(x, y) + tensor(0.6667) """ tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, num_classes=num_classes) @@ -202,6 +238,16 @@ def confusion_matrix( Return: Tensor, confusion matrix C [num_classes, num_classes ] + + Example: + + >>> x = torch.tensor([1, 2, 3]) + >>> y = torch.tensor([0, 2, 3]) + >>> confusion_matrix(x, y) + tensor([[0., 1., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 1., 0.], + [0., 0., 0., 1.]]) """ num_classes = get_num_classes(pred, target, None) @@ -238,6 +284,14 @@ def precision_recall( Return: Tensor with precision and recall + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision_recall(x, y) + (tensor(1.), tensor(0.8333)) + """ tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred=pred, target=target, @@ -277,6 +331,14 @@ def precision( Return: Tensor with precision. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> precision(x, y) + tensor(1.) + """ return precision_recall(pred=pred, target=target, num_classes=num_classes, reduction=reduction)[0] @@ -304,6 +366,13 @@ def recall( Return: Tensor with recall. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> recall(x, y) + tensor(0.8333) """ return precision_recall(pred=pred, target=target, num_classes=num_classes, reduction=reduction)[1] @@ -338,6 +407,13 @@ def fbeta_score( Return: Tensor with the value of F-score. It is a value between 0-1. + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> fbeta_score(x, y, 0.2) + tensor(0.9877) """ prec, rec = precision_recall(pred=pred, target=target, num_classes=num_classes, @@ -372,6 +448,13 @@ def f1_score( Return: Tensor containing F1-score + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> f1_score(x, y) + tensor(0.8889) """ return fbeta_score(pred=pred, target=target, beta=1., num_classes=num_classes, reduction=reduction) @@ -440,6 +523,16 @@ def roc( Return: [Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> fpr, tpr, thresholds = roc(x,y) + >>> roc(x,y) # doctest: +NORMALIZE_WHITESPACE + (tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]), + tensor([0., 0., 0., 1., 1.]), + tensor([4, 3, 2, 1, 0])) """ fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, @@ -482,6 +575,19 @@ def multiclass_roc( Return: [num_classes, Tensor, Tensor, Tensor]: returns roc for each class. number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE + ((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])), + (tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500]))) """ num_classes = get_num_classes(pred, target, num_classes) @@ -512,6 +618,15 @@ def precision_recall_curve( Return: [Tensor, Tensor, Tensor]: precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([0, 1, 2, 3]) + >>> target = torch.tensor([0, 1, 2, 2]) + >>> precision, recall, thresholds = precision_recall_curve(pred, target) + >>> precision_recall_curve(pred, target) + (tensor([0.3333, 0.0000, 0.0000, 1.0000]), tensor([1., 0., 0., 0.]), tensor([1., 2., 3.])) + """ fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target, sample_weight=sample_weight, @@ -556,7 +671,20 @@ def multiclass_precision_recall_curve( num_classes: number of classes Return: - [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds + [num_classes, Tensor, Tensor, Tensor]: number of classes, precision, recall, thresholds + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> classes_pr = multiclass_precision_recall_curve(pred, target) + >>> classes_pr # doctest: +NORMALIZE_WHITESPACE + ((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), (tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])), + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])), + (tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500]))) """ num_classes = get_num_classes(pred, target, num_classes) @@ -583,6 +711,13 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True): Return: AUC score (float) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auc(x, y) + tensor(4.) """ direction = 1. @@ -644,6 +779,13 @@ def auroc( target: ground-truth labels sample_weight: sample weights pos_label: the label for the positive class (default: 1.) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> auroc(x, y) + tensor(0.3333) """ @auc_decorator(reorder=True) @@ -659,6 +801,21 @@ def average_precision( sample_weight: Optional[Sequence] = None, pos_label: int = 1., ) -> torch.Tensor: + """ + + Args: + pred: estimated probabilities + target: ground-truth labels + sample_weight: sample weights + pos_label: the label for the positive class (default: 1.) + + Example: + + >>> x = torch.tensor([0, 1, 2, 3]) + >>> y = torch.tensor([0, 1, 2, 2]) + >>> average_precision(x, y) + tensor(0.3333) + """ precision, recall, _ = precision_recall_curve(pred=pred, target=target, sample_weight=sample_weight, pos_label=pos_label) @@ -676,6 +833,28 @@ def dice_score( no_fg_score: float = 0.0, reduction: str = 'elementwise_mean', ) -> torch.Tensor: + """ + @nicki to finish + + Args: + pred: + target: + bg: + nan_score: + no_fg_score: + reduction: + + Example: + + >>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05], + ... [0.05, 0.85, 0.05, 0.05], + ... [0.05, 0.05, 0.85, 0.05], + ... [0.05, 0.05, 0.05, 0.85]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> average_precision(pred, target) + tensor(0.2500) + + """ n_classes = pred.shape[1] bg = (1 - int(bool(bg))) scores = torch.zeros(n_classes - bg, device=pred.device, dtype=torch.float32)