Skip to content

Commit

Permalink
Prune metric: helpers and inputs 3/n (#6547)
Browse files Browse the repository at this point in the history
* _basic_input_validation

* _check_shape_and_type_consistency

* _check_num_classes_binary

* _check_num_classes_mc

* _check_num_classes_ml

* _check_top_k

* _check_classification_inputs

* _input_format_classification

* _reduce_stat_scores

* DataType

* rest

* flake8

* chlog
  • Loading branch information
Borda authored Mar 16, 2021
1 parent 0f07eaf commit a312219
Show file tree
Hide file tree
Showing 15 changed files with 20 additions and 549 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),

[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),

[#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),

)

Expand Down
535 changes: 0 additions & 535 deletions pytorch_lightning/metrics/classification/helpers.py

This file was deleted.

4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import Optional, Tuple

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType


def _accuracy_update(
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from typing import Optional, Sequence, Tuple

import torch
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing import Optional

import torch
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Tuple, Union

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from torchmetrics.classification.checks import _input_format_classification


def _hamming_distance_update(
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import Optional

import torch
from torchmetrics.classification.stat_scores import _reduce_stat_scores

from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores
from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update
from pytorch_lightning.utilities import rank_zero_warn

Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Optional, Tuple

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from torchmetrics.classification.checks import _input_format_classification


def _del_column(tensor: torch.Tensor, index: int):
Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable:
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
input arguments should be moved automatically to the correct device.
"""

@wraps(fn)
def insert_env_defaults(self, *args, **kwargs):
cls = self.__class__ # get the class
Expand Down
3 changes: 2 additions & 1 deletion tests/metrics/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import pytest
import torch
from sklearn.metrics import accuracy_score as sk_accuracy
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_hamming_distance.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest
import torch
from sklearn.metrics import hamming_loss as sk_hamming_loss
from torchmetrics.classification.checks import _input_format_classification

from pytorch_lightning.metrics import HammingDistance
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import hamming_distance
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
Expand Down
3 changes: 2 additions & 1 deletion tests/metrics/classification/test_inputs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
import torch
from torch import rand, randint
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.data import select_topk, to_onehot
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from tests.metrics.classification.inputs import _input_binary as _bin
from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob
from tests.metrics.classification.inputs import _input_multiclass as _mc
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from torchmetrics.classification.checks import _input_format_classification

from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest
import torch
from sklearn.metrics import multilabel_confusion_matrix
from torchmetrics.classification.checks import _input_format_classification

from pytorch_lightning.metrics import StatScores
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import stat_scores
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def test_lr_find_with_bs_scale(tmpdir):
""" Test that lr_find runs with batch_size_scaling """

class BoringModelTune(BoringModel):

def __init__(self, learning_rate=0.1, batch_size=2):
super().__init__()
self.save_hyperparameters()
Expand Down

0 comments on commit a312219

Please sign in to comment.