diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index f8d7a2ffe3a239..3961586f4946af 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -20,8 +20,8 @@ import torch from torch import Tensor +from torchmetrics import Metric -from pytorch_lightning.metrics import Metric from pytorch_lightning.utilities.distributed import sync_ddp_if_available diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 49d5f7aa472bfc..13e2ad8b4d3fa3 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.accuracy import _accuracy_compute, _accuracy_update -from pytorch_lightning.metrics.metric import Metric class Accuracy(Metric): diff --git a/pytorch_lightning/metrics/classification/auc.py b/pytorch_lightning/metrics/classification/auc.py index 6c5a29173d20a0..76c1959a8603a7 100644 --- a/pytorch_lightning/metrics/classification/auc.py +++ b/pytorch_lightning/metrics/classification/auc.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.auc import _auc_compute, _auc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/auroc.py b/pytorch_lightning/metrics/classification/auroc.py index ece2452938b5b0..9e0771e41590d0 100644 --- a/pytorch_lightning/metrics/classification/auroc.py +++ b/pytorch_lightning/metrics/classification/auroc.py @@ -15,9 +15,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.auroc import _auroc_compute, _auroc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py index f9c7bde158383f..adcdd86ed1ca80 100644 --- a/pytorch_lightning/metrics/classification/average_precision.py +++ b/pytorch_lightning/metrics/classification/average_precision.py @@ -14,9 +14,9 @@ from typing import Any, List, Optional, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.average_precision import _average_precision_compute, _average_precision_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py index c3defc82bc92db..112fb4940e6e2d 100644 --- a/pytorch_lightning/metrics/classification/confusion_matrix.py +++ b/pytorch_lightning/metrics/classification/confusion_matrix.py @@ -14,9 +14,9 @@ from typing import Any, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_compute, _confusion_matrix_update -from pytorch_lightning.metrics.metric import Metric class ConfusionMatrix(Metric): diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py index 9a580e02cf8ae8..4b6c1fdec268d5 100644 --- a/pytorch_lightning/metrics/classification/f_beta.py +++ b/pytorch_lightning/metrics/classification/f_beta.py @@ -14,9 +14,9 @@ from typing import Any, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.f_beta import _fbeta_compute, _fbeta_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/hamming_distance.py b/pytorch_lightning/metrics/classification/hamming_distance.py index 1737b25e5455c6..78c2bc192c460e 100644 --- a/pytorch_lightning/metrics/classification/hamming_distance.py +++ b/pytorch_lightning/metrics/classification/hamming_distance.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.hamming_distance import _hamming_distance_compute, _hamming_distance_update -from pytorch_lightning.metrics.metric import Metric class HammingDistance(Metric): diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 9c6c4421cbb7c8..f6027061d041de 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -14,12 +14,12 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.precision_recall_curve import ( _precision_recall_curve_compute, _precision_recall_curve_update, ) -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py index 9452d59fb9e767..6ded27d01a38bf 100644 --- a/pytorch_lightning/metrics/classification/roc.py +++ b/pytorch_lightning/metrics/classification/roc.py @@ -14,9 +14,9 @@ from typing import Any, List, Optional, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.roc import _roc_compute, _roc_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/classification/stat_scores.py b/pytorch_lightning/metrics/classification/stat_scores.py index 3d956030a61403..65e381c1de07e7 100644 --- a/pytorch_lightning/metrics/classification/stat_scores.py +++ b/pytorch_lightning/metrics/classification/stat_scores.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional, Tuple import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_compute, _stat_scores_update -from pytorch_lightning.metrics.metric import Metric class StatScores(Metric): diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index 5961714209d400..975b8280f77d5c 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -21,10 +21,9 @@ class CompositionalMetric(__CompositionalMetric): - r""" - This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`. - - .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0. + """ + .. deprecated:: + Use :class:`torchmetrics.metric.CompositionalMetric`. Will be removed in v1.5.0. """ def __init__( @@ -34,7 +33,7 @@ def __init__( metric_b: Union[Metric, int, float, torch.Tensor, None], ): rank_zero_warn( - "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." - " It will be removed in v1.5.0", DeprecationWarning + "This `CompositionalMetric` was deprecated since v1.3.0 in favor of" + " `torchmetrics.metric.CompositionalMetric`. It will be removed in v1.5.0", DeprecationWarning ) super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index b7386c3b4a8203..b077532708dcdb 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -13,17 +13,17 @@ # limitations under the License. from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from torchmetrics import Metric as __Metric -from torchmetrics import MetricCollection as __MetricCollection +from torchmetrics import Metric as _Metric +from torchmetrics.collections import MetricCollection as _MetricCollection +from pytorch_lightning.utilities.deprecation import deprecated from pytorch_lightning.utilities.distributed import rank_zero_warn -class Metric(__Metric): +class Metric(_Metric): r""" - This implementation refers to :class:`~torchmetrics.Metric`. - - .. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0. + .. deprecated:: + Use :class:`torchmetrics.Metric`. Will be removed in v1.5.0. """ def __init__( @@ -45,16 +45,12 @@ def __init__( ) -class MetricCollection(__MetricCollection): - r""" - This implementation refers to :class:`~torchmetrics.MetricCollection`. - - .. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0. +class MetricCollection(_MetricCollection): + """ + .. deprecated:: + Use :class:`torchmetrics.MetricCollection`. Will be removed in v1.5.0. """ + @deprecated(target=_MetricCollection, ver_deprecate="1.3.0", ver_remove="1.5.0") def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - rank_zero_warn( - "This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." - " It will be removed in v1.5.0", DeprecationWarning - ) - super().__init__(metrics=metrics) + pass diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 467ac72cc3eda2..a72a67ddb02ce0 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.explained_variance import ( _explained_variance_compute, _explained_variance_update, ) -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index ca184daf736b84..484ccbe83284e9 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_absolute_error import ( _mean_absolute_error_compute, _mean_absolute_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanAbsoluteError(Metric): diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 09f275ded86385..c26371514e7cd6 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_squared_error import ( _mean_squared_error_compute, _mean_squared_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanSquaredError(Metric): diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 18105e687b0b1f..caaf09a3663ffa 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -14,12 +14,12 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.mean_squared_log_error import ( _mean_squared_log_error_compute, _mean_squared_log_error_update, ) -from pytorch_lightning.metrics.metric import Metric class MeanSquaredLogError(Metric): diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py index b07941f010c3ac..6f3e2a92d2937d 100644 --- a/pytorch_lightning/metrics/regression/psnr.py +++ b/pytorch_lightning/metrics/regression/psnr.py @@ -14,10 +14,10 @@ from typing import Any, Optional, Sequence, Tuple, Union import torch +from torchmetrics import Metric from pytorch_lightning import utilities from pytorch_lightning.metrics.functional.psnr import _psnr_compute, _psnr_update -from pytorch_lightning.metrics.metric import Metric class PSNR(Metric): diff --git a/pytorch_lightning/metrics/regression/r2score.py b/pytorch_lightning/metrics/regression/r2score.py index 77f6c1363a5665..f7a11caf65703e 100644 --- a/pytorch_lightning/metrics/regression/r2score.py +++ b/pytorch_lightning/metrics/regression/r2score.py @@ -14,9 +14,9 @@ from typing import Any, Callable, Optional import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.r2score import _r2score_compute, _r2score_update -from pytorch_lightning.metrics.metric import Metric class R2Score(Metric): diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py index 09b55fb2bb4566..a3bbab938ffad9 100644 --- a/pytorch_lightning/metrics/regression/ssim.py +++ b/pytorch_lightning/metrics/regression/ssim.py @@ -14,9 +14,9 @@ from typing import Any, Optional, Sequence import torch +from torchmetrics import Metric from pytorch_lightning.metrics.functional.ssim import _ssim_compute, _ssim_update -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 82f328a9274855..554f1d3faf9ed9 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -15,8 +15,7 @@ from typing import Any import torch - -from pytorch_lightning.metrics.metric import Metric +from torchmetrics import Metric class MetricsHolder: diff --git a/pytorch_lightning/utilities/deprecation.py b/pytorch_lightning/utilities/deprecation.py new file mode 100644 index 00000000000000..3e2034c6a0453d --- /dev/null +++ b/pytorch_lightning/utilities/deprecation.py @@ -0,0 +1,73 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +from functools import wraps +from typing import Any, Callable, List, Tuple + +from pytorch_lightning.utilities import rank_zero_warn + + +def get_func_arguments_and_types(func: Callable) -> List[Tuple[str, Tuple, Any]]: + """Parse function arguments, types and default values + + Example: + >>> get_func_arguments_and_types(get_func_arguments_and_types) + [('func', typing.Callable, )] + """ + func_default_params = inspect.signature(func).parameters + name_type_default = [] + for arg in func_default_params: + arg_type = func_default_params[arg].annotation + arg_default = func_default_params[arg].default + name_type_default.append((arg, arg_type, arg_default)) + return name_type_default + + +def deprecated(target: Callable, ver_deprecate: str = "", ver_remove: str = "") -> Callable: + """ + Decorate a function or class ``__init__`` with warning message + and pass all arguments directly to the target class/method. + """ + + def inner_function(func): + + @wraps(func) + def wrapped_fn(*args, **kwargs): + is_class = inspect.isclass(target) + target_func = target.__init__ if is_class else target + # warn user only once in lifetime + if not getattr(inner_function, 'warned', False): + target_str = f'{target.__module__}.{target.__name__}' + func_name = func.__qualname__.split('.')[-2] if is_class else func.__name__ + rank_zero_warn( + f"The `{func_name}` was deprecated since v{ver_deprecate} in favor of `{target_str}`." + f" It will be removed in v{ver_remove}.", DeprecationWarning + ) + inner_function.warned = True + + if args: # in case any args passed move them to kwargs + # parse only the argument names + cls_arg_names = [arg[0] for arg in get_func_arguments_and_types(func)] + # convert args to kwargs + kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + + target_args = [arg[0] for arg in get_func_arguments_and_types(target_func)] + assert all(arg in target_args for arg in kwargs), \ + "Failed mapping, arguments missing in target func: %s" % [arg not in target_args for arg in kwargs] + # all args were already moved to kwargs + return target_func(**kwargs) + + return wrapped_fn + + return inner_function diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py index b2fa4f69f74b99..7c8c9ad2964167 100644 --- a/tests/deprecated_api/test_remove_1-5_metrics.py +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -16,6 +16,7 @@ import pytest import torch +from pytorch_lightning.metrics import Accuracy, MetricCollection from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot @@ -34,3 +35,14 @@ def test_v1_5_0_metrics_utils(): x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) with pytest.deprecated_call(match="It will be removed in v1.5.0"): assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) + + +def test_v1_5_0_metrics_collection(): + target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) + preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) + with pytest.deprecated_call( + match="The `MetricCollection` was deprecated since v1.3.0 in favor" + " of `torchmetrics.collections.MetricCollection`. It will be removed in v1.5.0" + ): + metrics = MetricCollection([Accuracy()]) + assert metrics(preds, target) == {'Accuracy': torch.Tensor([0.1250])[0]} diff --git a/tests/metrics/test_metric_lightning.py b/tests/metrics/test_metric_lightning.py index 895305fa9da7e6..2e040a881d49fe 100644 --- a/tests/metrics/test_metric_lightning.py +++ b/tests/metrics/test_metric_lightning.py @@ -1,7 +1,7 @@ import torch +from torchmetrics import Metric, MetricCollection from pytorch_lightning import Trainer -from pytorch_lightning.metrics import Metric, MetricCollection from tests.helpers.boring_model import BoringModel diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 4bd6608ce3fcf4..f1f17d0624936f 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -8,8 +8,7 @@ import pytest import torch from torch.multiprocessing import Pool, set_start_method - -from pytorch_lightning.metrics import Metric +from torchmetrics import Metric try: set_start_method("spawn") diff --git a/tests/utilities/test_deprecation.py b/tests/utilities/test_deprecation.py new file mode 100644 index 00000000000000..7c653c07ad168d --- /dev/null +++ b/tests/utilities/test_deprecation.py @@ -0,0 +1,37 @@ +import pytest + +from pytorch_lightning.utilities.deprecation import deprecated +from tests.helpers.utils import no_warning_call + + +def my_sum(a, b=3): + return a + b + + +@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep_sum(a, b): + pass + + +@deprecated(target=my_sum, ver_deprecate="0.1", ver_remove="0.5") +def dep2_sum(a, b): + pass + + +def test_deprecated_func(): + with pytest.deprecated_call( + match='The `dep_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + ' It will be removed in v0.5.' + ): + assert dep_sum(2, b=5) == 7 + + # check that the warning is raised only once per function + with no_warning_call(DeprecationWarning): + assert dep_sum(2, b=5) == 7 + + # and does not affect other functions + with pytest.deprecated_call( + match='The `dep2_sum` was deprecated since v0.1 in favor of `tests.utilities.test_deprecation.my_sum`.' + ' It will be removed in v0.5.' + ): + assert dep2_sum(2) == 5