From 36c88c1e54f5fa2b268f92eb7acac6c09bf0680a Mon Sep 17 00:00:00 2001
From: Justus Schock
Date: Mon, 13 Apr 2020 12:56:43 +0200
Subject: [PATCH] add sklearn metrics

---
 pytorch_lightning/metrics/sklearn.py | 385 ++++++++++++++++++++++++++-
 1 file changed, 373 insertions(+), 12 deletions(-)

diff --git a/pytorch_lightning/metrics/sklearn.py b/pytorch_lightning/metrics/sklearn.py
index d9a8c24d2dd62..d084f18e810b8 100644
--- a/pytorch_lightning/metrics/sklearn.py
+++ b/pytorch_lightning/metrics/sklearn.py
@@ -177,7 +177,7 @@ def forward(self, y_score: np.ndarray, y_true: np.ndarray,
                                sample_weight=sample_weight)


-class ConfusionMatric(SklearnMetric):
+class ConfusionMatrix(SklearnMetric):
     def __init__(self, labels: Optional[Sequence] = None,
                  reduce_group: Any = torch.distributed.group.WORLD,
                  reduce_op: Any = torch.distributed.ReduceOp.SUM
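The hunk above only fixes the class name. A minimal usage sketch of the renamed
metric (illustrative only; it assumes the `SklearnMetric` base class accepts
numpy arrays and forwards them to `sklearn.metrics.confusion_matrix`, and the
arrays below are hypothetical):

    import numpy as np
    from pytorch_lightning.metrics.sklearn import ConfusionMatrix

    y_pred = np.array([0, 1, 2, 2])   # hypothetical predictions
    y_true = np.array([0, 1, 1, 2])   # hypothetical targets

    metric = ConfusionMatrix()
    cm = metric(y_pred, y_true)       # wraps sklearn.metrics.confusion_matrix
    # cm is a 3x3 array of counts over the three observed classes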
+ """ + super().__init__('fbeta_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + beta=beta, + labels=labels, + pos_labels=pos_labels, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + + Returns: FBeta score of the positive class in binary classification or weighted + average of the FBeta scores of each class for the multiclass task. + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) class Precision(SklearnMetric): - pass + """ + Compute the precision + The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of + true positives and ``fp`` the number of false positives. The precision is + intuitively the ability of the classifier not to label as positive a sample + that is negative. + The best value is 1 and the worst value is 0. + + """ + def __init__(self, labels: Optional[Sequence] = None, + pos_labels: Union[str, int] = 1, + average: Optional[str] = 'binary', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + + Args: + labels: Integer array of labels. + pos_labels: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('precision_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + pos_labels=pos_labels, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + + Returns: Precision of the positive class in binary classification or weighted + average of the precision of each class for the multiclass task. 
+ + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) class Recall(SklearnMetric): - pass + """ + Compute the recall + The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of + true positives and ``fn`` the number of false negatives. The recall is + intuitively the ability of the classifier to find all the positive samples. + The best value is 1 and the worst value is 0. + + """ + + def __init__(self, labels: Optional[Sequence] = None, + pos_labels: Union[str, int] = 1, + average: Optional[str] = 'binary', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + + Args: + labels: Integer array of labels. + pos_labels: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('recall_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + pos_labels=pos_labels, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + + Returns: Recall of the positive class in binary classification or weighted + average of the recall of each class for the multiclass task. + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) class PrecisionRecallCurve(SklearnMetric): - pass + """ + Compute precision-recall pairs for different probability thresholds + + Note: + this implementation is restricted to the binary classification task. + + The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of + true positives and ``fp`` the number of false positives. The precision is + intuitively the ability of the classifier not to label as positive a sample + that is negative. + The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of + true positives and ``fn`` the number of false negatives. 


 class PrecisionRecallCurve(SklearnMetric):
-    pass
+    """
+    Compute precision-recall pairs for different probability thresholds.
+
+    Note:
+        This implementation is restricted to the binary classification task.
+
+    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of
+    true positives and ``fp`` the number of false positives. The precision is
+    intuitively the ability of the classifier not to label as positive a sample
+    that is negative.
+    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
+    true positives and ``fn`` the number of false negatives. The recall is
+    intuitively the ability of the classifier to find all the positive samples.
+    The last precision and recall values are 1. and 0. respectively and do not
+    have a corresponding threshold. This ensures that the graph starts on the
+    y axis.
+    """
+
+    def __init__(self,
+                 pos_label: Union[str, int] = 1,
+                 reduce_group: Any = torch.distributed.group.WORLD,
+                 reduce_op: Any = torch.distributed.ReduceOp.SUM):
+        """
+
+        Args:
+            pos_label: The label of the positive class.
+            reduce_group: the process group for DDP reduces (only needed for DDP training).
+                Defaults to all processes (world).
+            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
+                Defaults to sum.
+        """
+        super().__init__('precision_recall_curve',
+                         reduce_group=reduce_group,
+                         reduce_op=reduce_op,
+                         pos_label=pos_label)
+
+    def forward(self, probas_pred: np.ndarray, y_true: np.ndarray,
+                sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]:
+        """
+
+        Args:
+            probas_pred: Estimated probabilities or decision function.
+            y_true: Ground truth (correct) target values.
+            sample_weight: Sample weights.
+
+        Returns:
+            precision:
+                Precision values such that element i is the precision of
+                predictions with score >= thresholds[i] and the last element is 1.
+            recall:
+                Decreasing recall values such that element i is the recall of
+                predictions with score >= thresholds[i] and the last element is 0.
+            thresholds:
+                Increasing thresholds on the decision function used to compute
+                precision and recall.
+        """
+        return super().forward(probas_pred=probas_pred, y_true=y_true, sample_weight=sample_weight)
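Unlike the scalar metrics above, this class returns three arrays. An
illustrative sketch (hypothetical scores, same assumptions about the base
class):

    import numpy as np
    from pytorch_lightning.metrics.sklearn import PrecisionRecallCurve

    y_true = np.array([0, 0, 1, 1])                # hypothetical targets
    probas_pred = np.array([0.1, 0.4, 0.35, 0.8])  # hypothetical scores

    pr_curve = PrecisionRecallCurve()
    precision, recall, thresholds = pr_curve(probas_pred, y_true)
    # len(precision) == len(recall) == len(thresholds) + 1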
+ + """ + return super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight) + + +class AUROC(SklearnMetric): + """ + Compute Area Under the Curve (AUC) from prediction scores + Note: + this implementation is restricted to the binary classification task + or multilabel classification task in label indicator format. + """ + def __init__(self, average: Optional[str] = 'macro', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM + ): + """ + Args: + average: If None, the scores for each class are returned. Otherwise, this determines the type of + averaging performed on the data: + * If 'micro': Calculate metrics globally by considering each element of the label indicator + matrix as a label. + * If 'macro': Calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + * If 'weighted': Calculate metrics for each label, and find their average, weighted by + support (the number of true instances for each label). + * If 'samples': Calculate metrics for each instance, and find their average. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('roc_auc_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + average=average) + def forward(self, y_score: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> float: + """ -class Jaccard(SklearnMetric): - pass + Args: + y_score: Target scores, can either be probability estimates of the positive class, + confidence values, or binary decisions. + y_true: True binary labels in binary label indicators. + sample_weight: Sample weights. + Returns: + Area Under Receiver Operating Characteristic Curve + """ + return super().forward(y_score=y_score, y_true=y_true, + sample_weight=sample_weight)