Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prune metric: helpers and inputs 3/n #6547

Merged
merged 13 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),

[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),

[#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),

)

Expand Down
535 changes: 0 additions & 535 deletions pytorch_lightning/metrics/classification/helpers.py

This file was deleted.

4 changes: 2 additions & 2 deletions pytorch_lightning/metrics/functional/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import Optional, Tuple

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType


def _accuracy_update(
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from typing import Optional, Sequence, Tuple

import torch
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional.auc import auc
from pytorch_lightning.metrics.functional.roc import roc
from pytorch_lightning.utilities import LightningEnum
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/metrics/functional/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing import Optional

import torch
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.utilities import rank_zero_warn


Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/hamming_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Tuple, Union

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from torchmetrics.classification.checks import _input_format_classification


def _hamming_distance_update(
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/functional/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from typing import Optional

import torch
from torchmetrics.classification.stat_scores import _reduce_stat_scores

from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores
from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update
from pytorch_lightning.utilities import rank_zero_warn

Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
from typing import Optional, Tuple

import torch

from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from torchmetrics.classification.checks import _input_format_classification


def _del_column(tensor: torch.Tensor, index: int):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable:
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
input arguments should be moved automatically to the correct device.
"""

@wraps(fn)
def insert_env_defaults(self, *args, **kwargs):
cls = self.__class__ # get the class
Expand Down
3 changes: 2 additions & 1 deletion tests/metrics/classification/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import pytest
import torch
from sklearn.metrics import accuracy_score as sk_accuracy
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from pytorch_lightning.metrics.functional import accuracy
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_hamming_distance.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import pytest
import torch
from sklearn.metrics import hamming_loss as sk_hamming_loss
from torchmetrics.classification.checks import _input_format_classification

from pytorch_lightning.metrics import HammingDistance
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import hamming_distance
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
Expand Down
3 changes: 2 additions & 1 deletion tests/metrics/classification/test_inputs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
import torch
from torch import rand, randint
from torchmetrics.classification.checks import _input_format_classification
from torchmetrics.utilities.data import select_topk, to_onehot
from torchmetrics.utilities.enums import DataType

from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType
from tests.metrics.classification.inputs import _input_binary as _bin
from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob
from tests.metrics.classification.inputs import _input_multiclass as _mc
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest
import torch
from sklearn.metrics import precision_score, recall_score
from torchmetrics.classification.checks import _input_format_classification

from pytorch_lightning.metrics import Metric, Precision, Recall
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import precision, precision_recall, recall
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob
from tests.metrics.classification.inputs import _input_multiclass as _input_mcls
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import pytest
import torch
from sklearn.metrics import multilabel_confusion_matrix
from torchmetrics.classification.checks import _input_format_classification

from pytorch_lightning.metrics import StatScores
from pytorch_lightning.metrics.classification.helpers import _input_format_classification
from pytorch_lightning.metrics.functional import stat_scores
from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass
from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob
Expand Down
1 change: 1 addition & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def test_lr_find_with_bs_scale(tmpdir):
""" Test that lr_find runs with batch_size_scaling """

class BoringModelTune(BoringModel):

def __init__(self, learning_rate=0.1, batch_size=2):
super().__init__()
self.save_hyperparameters()
Expand Down