Skip to content

Commit

Permalink
redo sklearn metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock authored and Borda committed Apr 17, 2020
1 parent 7a8f638 commit b3b6471
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 0 deletions.
2 changes: 2 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ dependencies:
- autopep8
- check-manifest
- twine==1.13.0
- pillow<7.0.0
- scikit-learn>=0.16.1

- pip:
- test-tube>=0.7.5
Expand Down
141 changes: 141 additions & 0 deletions pytorch_lightning/metrics/sklearn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from typing import Any, Optional, Union

import numpy as np

import torch

from pytorch_lightning import _logger as lightning_logger
from pytorch_lightning.metrics.metric import NumpyMetric


class SklearnMetric(NumpyMetric):
def __init__(self, metric_name: str,
reduce_group: Any = torch.distributed.group.WORLD,
reduce_op: Any = torch.distributed.ReduceOp.SUM, **kwargs):
"""
Bridge between PyTorch Lightning and scikit-learn metrics
.. warning::
Every metric call will cause a GPU synchronization, which may slow down your code
.. note::
The order of targets and predictions may be different from the order typically used in PyTorch
Args:
metric_name: the metric name to import anc compute from scikit-learn.metrics
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
**kwargs: additonal keyword arguments (will be forwarded to metric call)
"""
super().__init__(name=metric_name, reduce_group=reduce_group,
reduce_op=reduce_op)

self.metric_kwargs = kwargs

lightning_logger.debug(
'Every metric call will cause a GPU synchronization, which may slow down your code')

@property
def metric_fn(self):
import sklearn.metrics
return getattr(sklearn.metrics, self.name)

def forward(self, *args, **kwargs) -> Union[np.ndarray, int, float]:
"""
Carries the actual metric computation and therefore co
Args:
*args: Positional arguments forwarded to metric call (should be already converted to numpy)
**kwargs: keyword arguments forwarded to metric call (should be already converted to numpy)
Returns:
the metric value (will be converted to tensor by baseclass
"""
return self.metric_fn(*args, **kwargs)


# metrics : accuracy, auc, average_precision (AP), confusion_matrix, f1, fbeta, hamm, precision, recall, precision_recall_curve, roc, roc_auc, r2, jaccard

class Accuracy(SklearnMetric):
def __init__(self, normalize: bool = True,
reduce_group: Any = torch.distributed.group.WORLD,
reduce_op: Any = torch.distributed.ReduceOp.SUM):
"""
Calculates the Accuracy Score
.. warning::
Every metric call will cause a GPU synchronization, which may slow down your code
Args:
normalize: If ``False``, return the number of correctly classified samples.
Otherwise, return the fraction of correctly classified samples.
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""
super().__init__(metric_name='accuracy_score',
reduce_group=reduce_group,
reduce_op=reduce_op,
normalize=normalize)

def forward(self, y_pred: np.ndarray, y_true: np.ndarray,
sample_weight: Optional[np.ndarray] = None) -> float:
"""
Computes the accuracy
Args:
y_pred: the array containing the predictions (already in categorical form)
y_true: the array containing the targets (in categorical form)
sample_weight:
Returns:
Accuracy Score
"""
return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight)

class AUC(SklearnMetric):
def __init__(self, reorder: bool = False,
reduce_group: Any = torch.distributed.group.WORLD,
reduce_op: Any = torch.distributed.ReduceOp.SUM
):
"""
Calculates the Area Under the Curve using the trapoezoidal rule
.. warning::
Every metric call will cause a GPU synchronization, which may slow down your code
Args:
reorder: If ``True``, assume that the curve is ascending in the case of ties, as for an ROC curve.
If the curve is non-ascending, the result will be wrong.
reduce_group: the process group for DDP reduces (only needed for DDP training).
Defaults to all processes (world)
reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
Defaults to sum.
"""

super().__init__(metric_name='auc',
reduce_group=reduce_group,
reduce_op=reduce_op,
reorder=reorder)

def forward(self, x: np.ndarray, y: np.ndarray) -> float:
"""
Computes the AUC
Args:
x: x coordinates.
y: y coordinates.
Returns:
AUC calculated with trapezoidal rule
"""
return super().forward(x=x, y=y)





1 change: 1 addition & 0 deletions requirements-extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ mlflow>=1.0.0
test_tube>=0.7.5
wandb>=0.8.21
trains>=0.14.1
scikit-learn>=0.16.1

0 comments on commit b3b6471

Please sign in to comment.