From 768579dde1f6770703a97d6b2c9da015f158670f Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Wed, 10 Jun 2020 15:43:12 +0200 Subject: [PATCH] Rework of Sklearn Metrics (#1327) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Create utils.py * Create __init__.py * redo sklearn metrics * add some more metrics * add sklearn metrics * Create __init__.py * redo sklearn metrics * New metric classes (#1326) * Create metrics package * Create metric.py * Create utils.py * Create __init__.py * add tests for metric utils * add docstrings for metrics utils * add function to recursively apply other function to collection * add tests for this function * update test * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * update metric name * remove example docs * fix tests * add metric tests * fix to tensor conversion * fix apply to collection * Update CHANGELOG.md * Update pytorch_lightning/metrics/metric.py Co-Authored-By: Jirka Borovec * remove tests from init * add missing type annotations * rename utils to convertors * Create metrics.rst * Update index.rst * Update index.rst * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * Update pytorch_lightning/metrics/convertors.py Co-Authored-By: Jirka Borovec * add doctest example * rename file and fix imports * added parametrized test * replace lambda with inlined function * rename apply_to_collection to apply_func * Separated class description from init args * Apply suggestions from code review Co-Authored-By: Jirka Borovec * adjust random values * suppress output when seeding * remove gpu from doctest * Add requested changes and add ellipsis for doctest * forgot to push these files... * add explicit check for dtype to convert to * fix ddp tests * remove explicit ddp destruction Co-authored-by: Jirka Borovec * add sklearn metrics * start adding sklearn tests * fix typo * return x and y only for curves * fix typo * add missing tests for sklearn funcs * imports * __all__ * imports * fix sklearn arguments * fix imports * update requirements * Update CHANGELOG.md * Update test_sklearn_metrics.py * formatting * formatting * format * fix all warnings and formatting problems * Update environment.yml * Update requirements-extra.txt * Update environment.yml * Update requirements-extra.txt * fix all warnings and formatting problems * Update CHANGELOG.md * docs * inherit * docs inherit. * docs * Apply suggestions from code review Co-authored-by: Nicki Skafte * docs * req * min * Apply suggestions from code review Co-authored-by: Tullie Murrell Co-authored-by: Jirka Borovec Co-authored-by: Jirka Co-authored-by: Adrian Wälchli Co-authored-by: Nicki Skafte Co-authored-by: Tullie Murrell (cherry picked from commit bd49b07fbba09b1e7d8851ee5a1ffce3d5925e9e) --- .circleci/config.yml | 10 +- .github/workflows/ci-testing.yml | 6 +- CHANGELOG.md | 4 +- docs/source/conf.py | 1 + environment.yml | 4 + .../computer_vision_fine_tuning.py | 12 +- pytorch_lightning/core/grads.py | 3 +- pytorch_lightning/core/hooks.py | 3 +- pytorch_lightning/metrics/__init__.py | 6 + pytorch_lightning/metrics/metric.py | 5 +- pytorch_lightning/metrics/sklearn.py | 716 ++++++++++++++++++ pytorch_lightning/metrics/utils.py | 130 ++++ .../utilities/device_dtype_mixin.py | 17 +- requirements-extra.txt | 4 +- requirements.txt | 2 +- tests/metrics/test_sklearn_metrics.py | 86 +++ tests/requirements.txt | 2 +- 17 files changed, 983 insertions(+), 28 deletions(-) create mode 100644 pytorch_lightning/metrics/sklearn.py create mode 100644 pytorch_lightning/metrics/utils.py create mode 100644 tests/metrics/test_sklearn_metrics.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 7bd1de8f6c947..2b7f2ad578a32 100755 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -64,8 +64,12 @@ references: name: Make Documentation command: | # First run the same pipeline as Read-The-Docs - sudo apt-get update && sudo apt-get install -y cmake - sudo pip install -r docs/requirements.txt + # apt-get update && apt-get install -y cmake + # using: https://hub.docker.com/r/readthedocs/build + # we need to use py3.7 ot higher becase of an issue with metaclass inheritence + pyenv global 3.7.3 + python --version + pip install -r docs/requirements.txt cd docs; make clean; make html --debug --jobs 2 SPHINXOPTS="-W" test_docs: &test_docs @@ -81,7 +85,7 @@ jobs: Build-Docs: docker: - - image: circleci/python:3.7 + - image: readthedocs/build:latest steps: - checkout - *make_docs diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 29711c0c62295..d905df63dec61 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -68,9 +68,9 @@ jobs: - name: Set min. dependencies if: matrix.requires == 'minimal' run: | - python -c "req = open('requirements.txt').read().replace('>', '=') ; open('requirements.txt', 'w').write(req)" - python -c "req = open('requirements-extra.txt').read().replace('>', '=') ; open('requirements-extra.txt', 'w').write(req)" - python -c "req = open('tests/requirements-devel.txt').read().replace('>', '=') ; open('tests/requirements-devel.txt', 'w').write(req)" + python -c "req = open('requirements.txt').read().replace('>=', '==') ; open('requirements.txt', 'w').write(req)" + python -c "req = open('requirements-extra.txt').read().replace('>=', '==') ; open('requirements-extra.txt', 'w').write(req)" + python -c "req = open('tests/requirements-devel.txt').read().replace('>=', '==') ; open('tests/requirements-devel.txt', 'w').write(req)" # Note: This uses an internal pip API and may not always work # https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow diff --git a/CHANGELOG.md b/CHANGELOG.md index 80070aeb3dde6..ec6157d2bc1ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,6 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - ## [unreleased] - YYYY-MM-DD ### Added @@ -23,7 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added - Remove explicit flush from tensorboard logger ([#2126](https://github.com/PyTorchLightning/pytorch-lightning/pull/2126)) -- Add Metric Base Classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) +- Added metric Base classes ([#1326](https://github.com/PyTorchLightning/pytorch-lightning/pull/1326), [#1877](https://github.com/PyTorchLightning/pytorch-lightning/pull/1877)) +- Added Sklearn metrics classes ([#1327](https://github.com/PyTorchLightning/pytorch-lightning/pull/1327)) - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)) - Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907)) - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908) diff --git a/docs/source/conf.py b/docs/source/conf.py index 4133571c65635..a084e5e349e39 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -90,6 +90,7 @@ 'sphinx.ext.linkcode', 'sphinx.ext.autosummary', 'sphinx.ext.napoleon', + 'sphinx.ext.imgmath', 'recommonmark', 'sphinx.ext.autosectionlabel', # 'm2r', diff --git a/environment.yml b/environment.yml index f2718a99c3a45..98f5fb81e1cdd 100644 --- a/environment.yml +++ b/environment.yml @@ -26,6 +26,10 @@ dependencies: - autopep8 - check-manifest - twine==1.13.0 + - pillow<7.0.0 + - scipy>=0.13.3 + - scikit-learn>=0.20.0 + - pip: - test-tube>=0.7.5 diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index e2db1b98fdb09..703f1c9b02419 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -27,6 +27,8 @@ from tempfile import TemporaryDirectory from typing import Optional, Generator, Union +from torch.nn import Module + import pytorch_lightning as pl import torch import torch.nn.functional as F @@ -47,7 +49,7 @@ # --- Utility functions --- -def _make_trainable(module: torch.nn.Module) -> None: +def _make_trainable(module: Module) -> None: """Unfreezes a given module. Args: @@ -58,7 +60,7 @@ def _make_trainable(module: torch.nn.Module) -> None: module.train() -def _recursive_freeze(module: torch.nn.Module, +def _recursive_freeze(module: Module, train_bn: bool = True) -> None: """Freezes the layers of a given module. @@ -80,7 +82,7 @@ def _recursive_freeze(module: torch.nn.Module, _recursive_freeze(module=child, train_bn=train_bn) -def freeze(module: torch.nn.Module, +def freeze(module: Module, n: Optional[int] = None, train_bn: bool = True) -> None: """Freezes the layers up to index n (if n is not None). @@ -101,7 +103,7 @@ def freeze(module: torch.nn.Module, _make_trainable(module=child) -def filter_params(module: torch.nn.Module, +def filter_params(module: Module, train_bn: bool = True) -> Generator: """Yields the trainable parameters of a given module. @@ -124,7 +126,7 @@ def filter_params(module: torch.nn.Module, yield param -def _unfreeze_and_add_param_group(module: torch.nn.Module, +def _unfreeze_and_add_param_group(module: Module, optimizer: Optimizer, lr: Optional[float] = None, train_bn: bool = True): diff --git a/pytorch_lightning/core/grads.py b/pytorch_lightning/core/grads.py index cb2215002c7d8..f58bbdf25ec88 100644 --- a/pytorch_lightning/core/grads.py +++ b/pytorch_lightning/core/grads.py @@ -4,9 +4,10 @@ from typing import Dict, Union import torch +from torch.nn import Module -class GradInformation(torch.nn.Module): +class GradInformation(Module): def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]: """Compute each parameter's gradient's norm and their overall norm. diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 960c7124383b0..d3fea6d446845 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -2,6 +2,7 @@ import torch from torch import Tensor +from torch.nn import Module from torch.optim.optimizer import Optimizer from pytorch_lightning.utilities import move_data_to_device @@ -14,7 +15,7 @@ APEX_AVAILABLE = True -class ModelHooks(torch.nn.Module): +class ModelHooks(Module): # TODO: remove in v0.9.0 def on_sanity_check_start(self): diff --git a/pytorch_lightning/metrics/__init__.py b/pytorch_lightning/metrics/__init__.py index cd721851307df..83446d11701f9 100644 --- a/pytorch_lightning/metrics/__init__.py +++ b/pytorch_lightning/metrics/__init__.py @@ -22,3 +22,9 @@ """ + +from pytorch_lightning.metrics.metric import Metric, TensorMetric, NumpyMetric +from pytorch_lightning.metrics.sklearn import ( + SklearnMetric, Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta, + Precision, Recall, PrecisionRecallCurve, ROC, AUROC) +from pytorch_lightning.metrics.converters import numpy_metric, tensor_metric diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 5247084498559..bd14655f30fa3 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -1,8 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import Any, Optional import torch import torch.distributed +from torch.nn import Module from pytorch_lightning.metrics.converters import tensor_metric, numpy_metric from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -11,7 +12,7 @@ __all__ = ['Metric', 'TensorMetric', 'NumpyMetric'] -class Metric(DeviceDtypeModuleMixin, torch.nn.Module, ABC): +class Metric(ABC, DeviceDtypeModuleMixin, Module): """ Abstract base class for metric implementation. diff --git a/pytorch_lightning/metrics/sklearn.py b/pytorch_lightning/metrics/sklearn.py new file mode 100644 index 0000000000000..60cc98c2c329f --- /dev/null +++ b/pytorch_lightning/metrics/sklearn.py @@ -0,0 +1,716 @@ +from typing import Any, Optional, Union, Sequence + +import numpy as np +import torch + +from pytorch_lightning import _logger as lightning_logger +from pytorch_lightning.metrics.metric import NumpyMetric + +__all__ = [ + 'SklearnMetric', + 'Accuracy', + 'AveragePrecision', + 'AUC', + 'ConfusionMatrix', + 'F1', + 'FBeta', + 'Precision', + 'Recall', + 'PrecisionRecallCurve', + 'ROC', + 'AUROC' +] + + +class SklearnMetric(NumpyMetric): + """ + Bridge between PyTorch Lightning and scikit-learn metrics + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + + Note: + The order of targets and predictions may be different from the order typically used in PyTorch + """ + def __init__(self, metric_name: str, + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM, **kwargs): + """ + Args: + metric_name: the metric name to import and compute from scikit-learn.metrics + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + **kwargs: additonal keyword arguments (will be forwarded to metric call) + """ + super().__init__(name=metric_name, reduce_group=reduce_group, + reduce_op=reduce_op) + + self.metric_kwargs = kwargs + lightning_logger.debug( + f'Metric {self.__class__.__name__} is using Sklearn as backend, meaning that' + ' every metric call will cause a GPU synchronization, which may slow down your code' + ) + + @property + def metric_fn(self): + import sklearn.metrics + return getattr(sklearn.metrics, self.name) + + def forward(self, *args, **kwargs) -> Union[np.ndarray, int, float]: + """ + Carries the actual metric computation + + Args: + *args: Positional arguments forwarded to metric call (should be already converted to numpy) + **kwargs: keyword arguments forwarded to metric call (should be already converted to numpy) + + Return: + the metric value (will be converted to tensor by baseclass) + + """ + return self.metric_fn(*args, **kwargs, **self.metric_kwargs) + + +class Accuracy(SklearnMetric): + """ + Calculates the Accuracy Score + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + """ + def __init__(self, normalize: bool = True, + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + normalize: If ``False``, return the number of correctly classified samples. + Otherwise, return the fraction of correctly classified samples. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__(metric_name='accuracy_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + normalize=normalize) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> float: + """ + Computes the accuracy + + Args: + y_pred: the array containing the predictions (already in categorical form) + y_true: the array containing the targets (in categorical form) + sample_weight: Sample weights. + + Return: + Accuracy Score + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + +class AUC(SklearnMetric): + """ + Calculates the Area Under the Curve using the trapoezoidal rule + + Warning: + Every metric call will cause a GPU synchronization, which may slow down your code + """ + def __init__(self, + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM + ): + """ + Args: + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + + super().__init__(metric_name='auc', + reduce_group=reduce_group, + reduce_op=reduce_op, + ) + + def forward(self, x: np.ndarray, y: np.ndarray) -> float: + """ + Computes the AUC + + Args: + x: x coordinates. + y: y coordinates. + + Return: + AUC calculated with trapezoidal rule + + """ + return super().forward(x=x, y=y) + + +class AveragePrecision(SklearnMetric): + """ + Calculates the average precision (AP) score. + """ + def __init__(self, average: Optional[str] = 'macro', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM + ): + """ + Args: + average: If None, the scores for each class are returned. Otherwise, this determines the type of + averaging performed on the data: + + * If 'micro': Calculate metrics globally by considering each element of the label indicator + matrix as a label. + * If 'macro': Calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + * If 'weighted': Calculate metrics for each label, and find their average, weighted by + support (the number of true instances for each label). + * If 'samples': Calculate metrics for each instance, and find their average. + + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('average_precision_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + average=average) + + def forward(self, y_score: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> float: + """ + Args: + y_score: Target scores, can either be probability estimates of the positive class, + confidence values, or binary decisions. + y_true: True binary labels in binary label indicators. + sample_weight: Sample weights. + + Return: + average precision score + """ + return super().forward(y_score=y_score, y_true=y_true, + sample_weight=sample_weight) + + +class ConfusionMatrix(SklearnMetric): + """ + Compute confusion matrix to evaluate the accuracy of a classification + By definition a confusion matrix :math:`C` is such that :math:`C_{i, j}` + is equal to the number of observations known to be in group :math:`i` but + predicted to be in group :math:`j`. + """ + def __init__(self, labels: Optional[Sequence] = None, + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM + ): + """ + Args: + labels: List of labels to index the matrix. This may be used to reorder + or select a subset of labels. + If none is given, those that appear at least once + in ``y_true`` or ``y_pred`` are used in sorted order. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('confusion_matrix', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: + """ + Args: + y_pred: Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + + Return: + Confusion matrix (array of shape [n_classes, n_classes]) + + """ + return super().forward(y_pred=y_pred, y_true=y_true) + + +class F1(SklearnMetric): + r""" + Compute the F1 score, also known as balanced F-score or F-measure + The F1 score can be interpreted as a weighted average of the precision and + recall, where an F1 score reaches its best value at 1 and worst score at 0. + The relative contribution of precision and recall to the F1 score are + equal. The formula for the F1 score is: + + .. math:: + + F_1 = 2 \cdot \frac{precision \cdot recall}{precision + recall} + + In the multi-class and multi-label case, this is the weighted average of + the F1 score of each class. + + References + - [1] `Wikipedia entry for the F1-score + `_ + """ + + def __init__(self, labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = 'binary', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + labels: Integer array of labels. + pos_label: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + * ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + * ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + * ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + * ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + * ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('f1_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + pos_label=pos_label, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + F1 score of the positive class in binary classification or weighted + average of the F1 scores of each class for the multiclass task. + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + +class FBeta(SklearnMetric): + """ + Compute the F-beta score. The `beta` parameter determines the weight of precision in the combined + score. ``beta < 1`` lends more weight to precision, while ``beta > 1`` + favors recall (``beta -> 0`` considers only precision, ``beta -> inf`` + only recall). + + References: + - [1] R. Baeza-Yates and B. Ribeiro-Neto (2011). + Modern Information Retrieval. Addison Wesley, pp. 327-328. + - [2] `Wikipedia entry for the F1-score + `_ + """ + + def __init__(self, beta: float, labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = 'binary', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + beta: Weight of precision in harmonic mean. + labels: Integer array of labels. + pos_label: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + * ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + * ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + * ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + * ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + * ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('fbeta_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + beta=beta, + labels=labels, + pos_label=pos_label, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + + Return: + FBeta score of the positive class in binary classification or weighted + average of the FBeta scores of each class for the multiclass task. + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + +class Precision(SklearnMetric): + """ + Compute the precision + The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of + true positives and ``fp`` the number of false positives. The precision is + intuitively the ability of the classifier not to label as positive a sample + that is negative. + The best value is 1 and the worst value is 0. + """ + + def __init__(self, labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = 'binary', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + labels: Integer array of labels. + pos_label: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + * ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + * ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + * ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + * ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + * ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('precision_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + pos_label=pos_label, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Precision of the positive class in binary classification or weighted + average of the precision of each class for the multiclass task. + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + +class Recall(SklearnMetric): + """ + Compute the recall + The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of + true positives and ``fn`` the number of false negatives. The recall is + intuitively the ability of the classifier to find all the positive samples. + The best value is 1 and the worst value is 0. + """ + + def __init__(self, labels: Optional[Sequence] = None, + pos_label: Union[str, int] = 1, + average: Optional[str] = 'binary', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + labels: Integer array of labels. + pos_label: The class to report if ``average='binary'``. + average: This parameter is required for multiclass/multilabel targets. + If ``None``, the scores for each class are returned. Otherwise, this + determines the type of averaging performed on the data: + + * ``'binary'``: + Only report results for the class specified by ``pos_label``. + This is applicable only if targets (``y_{true,pred}``) are binary. + * ``'micro'``: + Calculate metrics globally by counting the total true positives, + false negatives and false positives. + * ``'macro'``: + Calculate metrics for each label, and find their unweighted + mean. This does not take label imbalance into account. + * ``'weighted'``: + Calculate metrics for each label, and find their average, weighted + by support (the number of true instances for each label). This + alters 'macro' to account for label imbalance; it can result in an + F-score that is not between precision and recall. + * ``'samples'``: + Calculate metrics for each instance, and find their average (only + meaningful for multilabel classification where this differs from + :func:`accuracy_score`). + + Note that if ``pos_label`` is given in binary classification with + `average != 'binary'`, only that positive class is reported. This + behavior is deprecated and will change in version 0.18. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('recall_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + labels=labels, + pos_label=pos_label, + average=average) + + def forward(self, y_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + Args: + y_pred : Estimated targets as returned by a classifier. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Return: + Recall of the positive class in binary classification or weighted + average of the recall of each class for the multiclass task. + + """ + return super().forward(y_pred=y_pred, y_true=y_true, sample_weight=sample_weight) + + +class PrecisionRecallCurve(SklearnMetric): + """ + Compute precision-recall pairs for different probability thresholds + + Note: + This implementation is restricted to the binary classification task. + + The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number of + true positives and ``fp`` the number of false positives. The precision is + intuitively the ability of the classifier not to label as positive a sample + that is negative. + The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of + true positives and ``fn`` the number of false negatives. The recall is + intuitively the ability of the classifier to find all the positive samples. + The last precision and recall values are 1. and 0. respectively and do not + have a corresponding threshold. This ensures that the graph starts on the + x axis. + """ + + def __init__(self, + pos_label: Union[str, int] = 1, + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + pos_label: The class to report if ``average='binary'``. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('precision_recall_curve', + reduce_group=reduce_group, + reduce_op=reduce_op, + pos_label=pos_label) + + def forward(self, probas_pred: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + Args: + probas_pred : Estimated probabilities or decision function. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Returns: + precision: + Precision values such that element i is the precision of + predictions with score >= thresholds[i] and the last element is 1. + recall: + Decreasing recall values such that element i is the recall of + predictions with score >= thresholds[i] and the last element is 0. + thresholds: + Increasing thresholds on the decision function used to compute + precision and recall. + + """ + # only return x and y here, since for now we cannot auto-convert elements of multiple length. + # Will be fixed in native implementation + return np.array( + super().forward(probas_pred=probas_pred, y_true=y_true, sample_weight=sample_weight)[:2]) + + +class ROC(SklearnMetric): + """ + Compute Receiver operating characteristic (ROC) + + Note: + this implementation is restricted to the binary classification task. + """ + + def __init__(self, + pos_label: Union[str, int] = 1, + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM): + """ + Args: + pos_labels: The class to report if ``average='binary'``. + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + + References: + - [1] `Wikipedia entry for the Receiver operating characteristic + `_ + """ + super().__init__('roc_curve', + reduce_group=reduce_group, + reduce_op=reduce_op, + pos_label=pos_label) + + def forward(self, y_score: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> Union[np.ndarray, float]: + """ + Args: + y_score : Target scores, can either be probability estimates of the positive + class or confidence values. + y_true: Ground truth (correct) target values. + sample_weight: Sample weights. + + Returns: + fpr: + Increasing false positive rates such that element i is the false + positive rate of predictions with score >= thresholds[i]. + tpr: + Increasing true positive rates such that element i is the true + positive rate of predictions with score >= thresholds[i]. + thresholds: + Decreasing thresholds on the decision function used to compute + fpr and tpr. `thresholds[0]` represents no instances being predicted + and is arbitrarily set to `max(y_score) + 1`. + + """ + return np.array(super().forward(y_score=y_score, y_true=y_true, sample_weight=sample_weight)[:2]) + + +class AUROC(SklearnMetric): + """ + Compute Area Under the Curve (AUC) from prediction scores + + Note: + this implementation is restricted to the binary classification task + or multilabel classification task in label indicator format. + """ + + def __init__(self, average: Optional[str] = 'macro', + reduce_group: Any = torch.distributed.group.WORLD, + reduce_op: Any = torch.distributed.ReduceOp.SUM + ): + """ + Args: + average: If None, the scores for each class are returned. Otherwise, this determines the type of + averaging performed on the data: + + * If 'micro': Calculate metrics globally by considering each element of the label indicator + matrix as a label. + * If 'macro': Calculate metrics for each label, and find their unweighted mean. + This does not take label imbalance into account. + * If 'weighted': Calculate metrics for each label, and find their average, weighted by + support (the number of true instances for each label). + * If 'samples': Calculate metrics for each instance, and find their average. + + reduce_group: the process group for DDP reduces (only needed for DDP training). + Defaults to all processes (world) + reduce_op: the operation to perform during reduction within DDP (only needed for DDP training). + Defaults to sum. + """ + super().__init__('roc_auc_score', + reduce_group=reduce_group, + reduce_op=reduce_op, + average=average) + + def forward(self, y_score: np.ndarray, y_true: np.ndarray, + sample_weight: Optional[np.ndarray] = None) -> float: + """ + Args: + y_score: Target scores, can either be probability estimates of the positive class, + confidence values, or binary decisions. + y_true: True binary labels in binary label indicators. + sample_weight: Sample weights. + + Return: + Area Under Receiver Operating Characteristic Curve + """ + return super().forward(y_score=y_score, y_true=y_true, + sample_weight=sample_weight) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py new file mode 100644 index 0000000000000..0829c9711cb44 --- /dev/null +++ b/pytorch_lightning/metrics/utils.py @@ -0,0 +1,130 @@ +import numbers +from typing import Union, Any, Optional + +import numpy as np +import torch +from torch.utils.data._utils.collate import default_convert + +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def _apply_to_inputs(func_to_apply, *dec_args, **dec_kwargs): + def decorator_fn(func_to_decorate): + def new_func(*args, **kwargs): + args = func_to_apply(args, *dec_args, **dec_kwargs) + kwargs = func_to_apply(kwargs, *dec_args, **dec_kwargs) + return func_to_decorate(*args, **kwargs) + + return new_func + + return decorator_fn + + +def _apply_to_outputs(func_to_apply, *dec_args, **dec_kwargs): + def decorator_fn(function_to_decorate): + def new_func(*args, **kwargs): + result = function_to_decorate(*args, **kwargs) + return func_to_apply(result, *dec_args, **dec_kwargs) + + return new_func + + return decorator_fn + + +def _convert_to_tensor(data: Any) -> Any: + """ + Maps all kind of collections and numbers to tensors + + Args: + data: the data to convert to tensor + + Returns: + the converted data + + """ + if isinstance(data, numbers.Number): + return torch.tensor([data]) + else: + return default_convert(data) + + +def _convert_to_numpy(data: Union[torch.Tensor, np.ndarray, numbers.Number]) -> np.ndarray: + """ + converts all tensors and numpy arrays to numpy arrays + Args: + data: the tensor or array to convert to numpy + + Returns: + the resulting numpy array + + """ + if isinstance(data, torch.Tensor): + return data.cpu().detach().numpy() + elif isinstance(data, numbers.Number): + return np.array([data]) + return data + + +def _numpy_metric_conversion(func_to_decorate): + # Applies collection conversion from tensor to numpy to all inputs + # we need to include numpy arrays here, since otherwise they will also be treated as sequences + func_convert_inputs = _apply_to_inputs( + apply_to_collection, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy)(func_to_decorate) + # converts all inputs back to tensors (device doesn't matter here, since this is handled by BaseMetric) + func_convert_in_out = _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + return func_convert_in_out + + +def _tensor_metric_conversion(func_to_decorate): + # Converts all inputs to tensor if possible + func_convert_inputs = _apply_to_inputs(_convert_to_tensor)(func_to_decorate) + # convert all outputs to tensor if possible + return _apply_to_outputs(_convert_to_tensor)(func_convert_inputs) + + +def _sync_ddp(result: Union[torch.Tensor], + group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM, + ) -> torch.Tensor: + """ + Function to reduce the tensors from several ddp processes to one master process + + Args: + result: the value to sync and reduce (typically tensor or number) + device: the device to put the synced and reduced value to + dtype: the datatype to convert the synced and reduced value to + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum + + Returns: + reduced value + + """ + + if torch.distributed.is_available() and torch.distributed.is_initialized(): + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, + async_op=False) + + return result + + +def numpy_metric(group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM): + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp, + group=group, + reduce_op=reduce_op)(_numpy_metric_conversion(func_to_decorate)) + + return decorator_fn + + +def tensor_metric(group: Any = torch.distributed.group.WORLD, + reduce_op: torch.distributed.ReduceOp = torch.distributed.ReduceOp.SUM): + def decorator_fn(func_to_decorate): + return _apply_to_outputs(apply_to_collection, torch.Tensor, _sync_ddp, + group=group, + reduce_op=reduce_op)(_tensor_metric_conversion(func_to_decorate)) + + return decorator_fn diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py index eb3faf54faf6e..48ccad5307552 100644 --- a/pytorch_lightning/utilities/device_dtype_mixin.py +++ b/pytorch_lightning/utilities/device_dtype_mixin.py @@ -1,9 +1,10 @@ from typing import Union, Optional import torch +from torch.nn import Module -class DeviceDtypeModuleMixin(torch.nn.Module): +class DeviceDtypeModuleMixin(Module): _device: ... _dtype: Union[str, torch.dtype] @@ -25,7 +26,7 @@ def device(self, new_device: Union[str, torch.device]): # Necessary to avoid infinite recursion raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).') - def to(self, *args, **kwargs) -> torch.nn.Module: + def to(self, *args, **kwargs) -> Module: """Moves and/or casts the parameters and buffers. This can be called as @@ -91,7 +92,7 @@ def to(self, *args, **kwargs) -> torch.nn.Module: return super().to(*args, **kwargs) - def cuda(self, device: Optional[int] = None) -> torch.nn.Module: + def cuda(self, device: Optional[int] = None) -> Module: """Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will @@ -108,7 +109,7 @@ def cuda(self, device: Optional[int] = None) -> torch.nn.Module: self._device = torch.device('cuda', index=device) return super().cuda(device=device) - def cpu(self) -> torch.nn.Module: + def cpu(self) -> Module: """Moves all model parameters and buffers to the CPU. Returns: Module: self @@ -116,7 +117,7 @@ def cpu(self) -> torch.nn.Module: self._device = torch.device('cpu') return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: + def type(self, dst_type: Union[str, torch.dtype]) -> Module: """Casts all parameters and buffers to :attr:`dst_type`. Arguments: @@ -128,7 +129,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> torch.nn.Module: self._dtype = dst_type return super().type(dst_type=dst_type) - def float(self) -> torch.nn.Module: + def float(self) -> Module: """Casts all floating point parameters and buffers to float datatype. Returns: @@ -137,7 +138,7 @@ def float(self) -> torch.nn.Module: self._dtype = torch.float return super().float() - def double(self) -> torch.nn.Module: + def double(self) -> Module: """Casts all floating point parameters and buffers to ``double`` datatype. Returns: @@ -146,7 +147,7 @@ def double(self) -> torch.nn.Module: self._dtype = torch.double return super().double() - def half(self) -> torch.nn.Module: + def half(self) -> Module: """Casts all floating point parameters and buffers to ``half`` datatype. Returns: diff --git a/requirements-extra.txt b/requirements-extra.txt index 30bc84ab5190b..0fcd2f8a1bd92 100644 --- a/requirements-extra.txt +++ b/requirements-extra.txt @@ -9,4 +9,6 @@ trains>=0.14.1 matplotlib>=3.1.1 # no need to install with [pytorch] as pytorch is already installed and torchvision is required only for Horovod examples horovod>=0.19.1 -omegaconf==2.0.0 \ No newline at end of file +omegaconf>=2.0.0 +# scipy>=0.13.3 +scikit-learn>=0.20.0 diff --git a/requirements.txt b/requirements.txt index 0aa44aae24f4c..62e723574a9bc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # the default package dependencies +numpy>=1.15 # because some BLAS compilation issues tqdm>=4.41.0 -numpy>=1.16.4 torch>=1.3 tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py diff --git a/tests/metrics/test_sklearn_metrics.py b/tests/metrics/test_sklearn_metrics.py new file mode 100644 index 0000000000000..e075330d60a3c --- /dev/null +++ b/tests/metrics/test_sklearn_metrics.py @@ -0,0 +1,86 @@ +import numbers +from collections import Mapping, Sequence +from functools import partial + +import numpy as np +import pytest +import torch +from sklearn.metrics import (accuracy_score, average_precision_score, auc, confusion_matrix, f1_score, + fbeta_score, precision_score, recall_score, precision_recall_curve, roc_curve, + roc_auc_score) + +from pytorch_lightning.metrics.converters import _convert_to_numpy +from pytorch_lightning.metrics.sklearn import (Accuracy, AveragePrecision, AUC, ConfusionMatrix, F1, FBeta, + Precision, Recall, PrecisionRecallCurve, ROC, AUROC) +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +def xy_only(func): + def new_func(*args, **kwargs): + return np.array(func(*args, **kwargs)[:2]) + + return new_func + + +@pytest.mark.parametrize(['metric_class', 'sklearn_func', 'inputs'], [ + pytest.param(Accuracy(), accuracy_score, + {'y_pred': torch.randint(low=0, high=10, size=(128,)), + 'y_true': torch.randint(low=0, high=10, size=(128,))}, id='Accuracy'), + pytest.param(AUC(), auc, {'x': torch.arange(10, dtype=torch.float) / 10, + 'y': torch.tensor([0.2, 0.2, 0.2, 0.2, 0.2, + 0.2, 0.3, 0.5, 0.6, 0.7])}, id='AUC'), + pytest.param(AveragePrecision(), average_precision_score, + {'y_score': torch.randint(2, size=(128,)), + 'y_true': torch.randint(2, size=(128,))}, id='AveragePrecision'), + pytest.param(ConfusionMatrix(), confusion_matrix, + {'y_pred': torch.randint(10, size=(128,)), + 'y_true': torch.randint(10, size=(128,))}, id='ConfusionMatrix'), + pytest.param(F1(average='macro'), partial(f1_score, average='macro'), + {'y_pred': torch.randint(10, size=(128,)), + 'y_true': torch.randint(10, size=(128,))}, id='F1'), + pytest.param(FBeta(beta=0.5, average='macro'), partial(fbeta_score, beta=0.5, average='macro'), + {'y_pred': torch.randint(10, size=(128,)), + 'y_true': torch.randint(10, size=(128,))}, id='FBeta'), + pytest.param(Precision(average='macro'), partial(precision_score, average='macro'), + {'y_pred': torch.randint(10, size=(128,)), + 'y_true': torch.randint(10, size=(128,))}, id='Precision'), + pytest.param(Recall(average='macro'), partial(recall_score, average='macro'), + {'y_pred': torch.randint(10, size=(128,)), + 'y_true': torch.randint(10, size=(128,))}, id='Recall'), + pytest.param(PrecisionRecallCurve(), xy_only(precision_recall_curve), + {'probas_pred': torch.rand(size=(128,)), + 'y_true': torch.randint(2, size=(128,))}, id='PrecisionRecallCurve'), + pytest.param(ROC(), xy_only(roc_curve), + {'y_score': torch.rand(size=(128,)), + 'y_true': torch.randint(2, size=(128,))}, id='ROC'), + pytest.param(AUROC(), roc_auc_score, + {'y_score': torch.rand(size=(128,)), + 'y_true': torch.randint(2, size=(128,))}, id='AUROC'), +]) +def test_sklearn_metric(metric_class, sklearn_func, inputs: dict): + numpy_inputs = apply_to_collection( + inputs, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy) + + sklearn_result = sklearn_func(**numpy_inputs) + lightning_result = metric_class(**inputs) + + sklearn_result = apply_to_collection( + sklearn_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy) + + lightning_result = apply_to_collection( + lightning_result, (torch.Tensor, np.ndarray, numbers.Number), _convert_to_numpy) + + assert isinstance(lightning_result, type(sklearn_result)) + + if isinstance(lightning_result, np.ndarray): + assert np.allclose(lightning_result, sklearn_result) + elif isinstance(lightning_result, Mapping): + for key in lightning_result.keys(): + assert np.allclose(lightning_result[key], sklearn_result[key]) + + elif isinstance(lightning_result, Sequence): + for val_lightning, val_sklearn in zip(lightning_result, sklearn_result): + assert np.allclose(val_lightning, val_sklearn) + + else: + raise TypeError diff --git a/tests/requirements.txt b/tests/requirements.txt index fdf2e83337acb..2945bc5f968d2 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -8,4 +8,4 @@ flake8-black check-manifest twine==1.13.0 black==19.10b0 -pre-commit>=1.21.0 +pre-commit>=1.0