From 6453091b8ab3713e2d58bad7acc9a4345dc5d10b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 20:28:18 +0100 Subject: [PATCH] Prune metrics base classes 2/n (#6530) * base class * extensions * chlog * _stable_1d_sort * _check_same_shape * _input_format_classification_one_hot * utils * to_onehot * select_topk * to_categorical * get_num_classes * reduce * class_reduce * tests --- CHANGELOG.md | 6 +- .../basic_examples/conv_sequential_example.py | 2 +- pytorch_lightning/accelerators/gpu.py | 2 +- .../metrics/classification/helpers.py | 2 +- pytorch_lightning/metrics/compositional.py | 100 +-- pytorch_lightning/metrics/functional/auc.py | 3 +- .../metrics/functional/classification.py | 3 +- .../metrics/functional/explained_variance.py | 3 +- .../metrics/functional/f_beta.py | 4 +- pytorch_lightning/metrics/functional/iou.py | 3 +- .../metrics/functional/mean_absolute_error.py | 3 +- .../metrics/functional/mean_relative_error.py | 3 +- .../metrics/functional/mean_squared_error.py | 3 +- .../functional/mean_squared_log_error.py | 3 +- pytorch_lightning/metrics/functional/psnr.py | 8 +- .../metrics/functional/r2score.py | 2 +- pytorch_lightning/metrics/functional/ssim.py | 4 +- pytorch_lightning/metrics/metric.py | 614 +----------------- pytorch_lightning/metrics/utils.py | 315 ++------- .../plugins/training_type/rpc.py | 2 +- .../plugins/training_type/rpc_sequential.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 2 +- .../deprecated_api/test_remove_1-5_metrics.py | 36 + tests/metrics/classification/test_inputs.py | 2 +- .../metrics/functional/test_classification.py | 2 +- tests/metrics/functional/test_reduction.py | 3 +- tests/metrics/test_composition.py | 510 --------------- tests/metrics/test_ddp.py | 71 -- tests/metrics/test_metric.py | 395 ----------- 29 files changed, 187 insertions(+), 1921 deletions(-) create mode 100644 tests/deprecated_api/test_remove_1-5_metrics.py delete mode 100644 tests/metrics/test_composition.py delete mode 100644 tests/metrics/test_ddp.py delete mode 100644 tests/metrics/test_metric.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 085335e4ca090..f5dcb20375137 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -65,7 +65,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945)) -- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505)) +- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), + + [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), + +) ### Removed diff --git a/pl_examples/basic_examples/conv_sequential_example.py b/pl_examples/basic_examples/conv_sequential_example.py index 95ee66e1d5a14..f3d9469144f50 100644 --- a/pl_examples/basic_examples/conv_sequential_example.py +++ b/pl_examples/basic_examples/conv_sequential_example.py @@ -27,11 +27,11 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +from torchmetrics.functional import accuracy import pytorch_lightning as pl from pl_examples import cli_lightning_logo from pytorch_lightning import Trainer -from torchmetrics.functional import accuracy from pytorch_lightning.plugins import RPCSequentialPlugin from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE diff --git a/pytorch_lightning/accelerators/gpu.py b/pytorch_lightning/accelerators/gpu.py index af9ce25f902b3..5c5dc5cc6f531 100644 --- a/pytorch_lightning/accelerators/gpu.py +++ b/pytorch_lightning/accelerators/gpu.py @@ -1,6 +1,6 @@ import logging import os -from typing import TYPE_CHECKING, Any +from typing import Any, TYPE_CHECKING import torch diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 58d3142de72f2..a91150799d5a1 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -15,8 +15,8 @@ import numpy as np import torch +from torchmetrics.utilities.data import select_topk, to_onehot -from pytorch_lightning.metrics.utils import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum diff --git a/pytorch_lightning/metrics/compositional.py b/pytorch_lightning/metrics/compositional.py index df98d16a3ef7e..5961714209d40 100644 --- a/pytorch_lightning/metrics/compositional.py +++ b/pytorch_lightning/metrics/compositional.py @@ -1,14 +1,30 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, Union import torch +from torchmetrics import Metric +from torchmetrics.metric import CompositionalMetric as __CompositionalMetric -from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import rank_zero_warn -class CompositionalMetric(Metric): - """Composition of two metrics with a specific operator - which will be executed upon metric's compute +class CompositionalMetric(__CompositionalMetric): + r""" + This implementation refers to :class:`~torchmetrics.metric.CompositionalMetric`. + .. warning:: This metric is deprecated, use ``torchmetrics.metric.CompositionalMetric``. Will be removed in v1.5.0. """ def __init__( @@ -17,76 +33,8 @@ def __init__( metric_a: Union[Metric, int, float, torch.Tensor], metric_b: Union[Metric, int, float, torch.Tensor, None], ): - """ - - Args: - operator: the operator taking in one (if metric_b is None) - or two arguments. Will be applied to outputs of metric_a.compute() - and (optionally if metric_b is not None) metric_b.compute() - metric_a: first metric whose compute() result is the first argument of operator - metric_b: second metric whose compute() result is the second argument of operator. - For operators taking in only one input, this should be None - """ - super().__init__() - - self.op = operator - - if isinstance(metric_a, torch.Tensor): - self.register_buffer("metric_a", metric_a) - else: - self.metric_a = metric_a - - if isinstance(metric_b, torch.Tensor): - self.register_buffer("metric_b", metric_b) - else: - self.metric_b = metric_b - - def _sync_dist(self, dist_sync_fn=None): - # No syncing required here. syncing will be done in metric_a and metric_b - pass - - def update(self, *args, **kwargs): - if isinstance(self.metric_a, Metric): - self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs)) - - if isinstance(self.metric_b, Metric): - self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs)) - - def compute(self): - - # also some parsing for kwargs? - if isinstance(self.metric_a, Metric): - val_a = self.metric_a.compute() - else: - val_a = self.metric_a - - if isinstance(self.metric_b, Metric): - val_b = self.metric_b.compute() - else: - val_b = self.metric_b - - if val_b is None: - return self.op(val_a) - - return self.op(val_a, val_b) - - def reset(self): - if isinstance(self.metric_a, Metric): - self.metric_a.reset() - - if isinstance(self.metric_b, Metric): - self.metric_b.reset() - - def persistent(self, mode: bool = False): - if isinstance(self.metric_a, Metric): - self.metric_a.persistent(mode=mode) - if isinstance(self.metric_b, Metric): - self.metric_b.persistent(mode=mode) - - def __repr__(self): - repr_str = ( - self.__class__.__name__ - + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)" + rank_zero_warn( + "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." + " It will be removed in v1.5.0", DeprecationWarning ) - - return repr_str + super().__init__(operator=operator, metric_a=metric_a, metric_b=metric_b) diff --git a/pytorch_lightning/metrics/functional/auc.py b/pytorch_lightning/metrics/functional/auc.py index bae404120b48c..cc5c9cf889b7a 100644 --- a/pytorch_lightning/metrics/functional/auc.py +++ b/pytorch_lightning/metrics/functional/auc.py @@ -14,8 +14,7 @@ from typing import Tuple import torch - -from pytorch_lightning.metrics.utils import _stable_1d_sort +from torchmetrics.utilities.data import _stable_1d_sort def _auc_update(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index d989bc503b62c..f145b7c1a5c67 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -15,11 +15,12 @@ from typing import Callable, Optional, Sequence, Tuple import torch +from torchmetrics.utilities import class_reduce, reduce +from torchmetrics.utilities.data import get_num_classes, to_categorical from pytorch_lightning.metrics.functional.auc import auc as __auc from pytorch_lightning.metrics.functional.auroc import auroc as __auroc from pytorch_lightning.metrics.functional.iou import iou as __iou -from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/functional/explained_variance.py b/pytorch_lightning/metrics/functional/explained_variance.py index 617d800c754e3..fa8d43c06c7ef 100644 --- a/pytorch_lightning/metrics/functional/explained_variance.py +++ b/pytorch_lightning/metrics/functional/explained_variance.py @@ -14,8 +14,7 @@ from typing import Sequence, Tuple, Union import torch - -from pytorch_lightning.metrics.utils import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape def _explained_variance_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/pytorch_lightning/metrics/functional/f_beta.py b/pytorch_lightning/metrics/functional/f_beta.py index debb6c8285fc9..5be4786297b65 100644 --- a/pytorch_lightning/metrics/functional/f_beta.py +++ b/pytorch_lightning/metrics/functional/f_beta.py @@ -14,8 +14,8 @@ from typing import Tuple import torch - -from pytorch_lightning.metrics.utils import _input_format_classification_one_hot, class_reduce +from torchmetrics.utilities import class_reduce +from torchmetrics.utilities.checks import _input_format_classification_one_hot def _fbeta_update( diff --git a/pytorch_lightning/metrics/functional/iou.py b/pytorch_lightning/metrics/functional/iou.py index 7b6851b5cebd0..0f8152d314848 100644 --- a/pytorch_lightning/metrics/functional/iou.py +++ b/pytorch_lightning/metrics/functional/iou.py @@ -14,9 +14,10 @@ from typing import Optional import torch +from torchmetrics.utilities import reduce +from torchmetrics.utilities.data import get_num_classes from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update -from pytorch_lightning.metrics.utils import get_num_classes, reduce def _iou_from_confmat( diff --git a/pytorch_lightning/metrics/functional/mean_absolute_error.py b/pytorch_lightning/metrics/functional/mean_absolute_error.py index 671368ba240f9..2bd8f125ecb9e 100644 --- a/pytorch_lightning/metrics/functional/mean_absolute_error.py +++ b/pytorch_lightning/metrics/functional/mean_absolute_error.py @@ -14,8 +14,7 @@ from typing import Tuple import torch - -from pytorch_lightning.metrics.utils import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape def _mean_absolute_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: diff --git a/pytorch_lightning/metrics/functional/mean_relative_error.py b/pytorch_lightning/metrics/functional/mean_relative_error.py index eedaea1a26a4f..bfe5eb6b847d7 100644 --- a/pytorch_lightning/metrics/functional/mean_relative_error.py +++ b/pytorch_lightning/metrics/functional/mean_relative_error.py @@ -14,8 +14,7 @@ from typing import Tuple import torch - -from pytorch_lightning.metrics.utils import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape def _mean_relative_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: diff --git a/pytorch_lightning/metrics/functional/mean_squared_error.py b/pytorch_lightning/metrics/functional/mean_squared_error.py index 2cdd4ea679043..66c0aadef0651 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_error.py @@ -14,8 +14,7 @@ from typing import Tuple import torch - -from pytorch_lightning.metrics.utils import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape def _mean_squared_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: diff --git a/pytorch_lightning/metrics/functional/mean_squared_log_error.py b/pytorch_lightning/metrics/functional/mean_squared_log_error.py index 45c255eb61d78..baec63c7248f2 100644 --- a/pytorch_lightning/metrics/functional/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/functional/mean_squared_log_error.py @@ -14,8 +14,7 @@ from typing import Tuple import torch - -from pytorch_lightning.metrics.utils import _check_same_shape +from torchmetrics.utilities.checks import _check_same_shape def _mean_squared_log_error_update(preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, int]: diff --git a/pytorch_lightning/metrics/functional/psnr.py b/pytorch_lightning/metrics/functional/psnr.py index 2b8757ead9b6e..0b50ea092b7fa 100644 --- a/pytorch_lightning/metrics/functional/psnr.py +++ b/pytorch_lightning/metrics/functional/psnr.py @@ -14,9 +14,9 @@ from typing import Optional, Tuple, Union import torch +from torchmetrics.utilities import reduce -from pytorch_lightning import utilities -from pytorch_lightning.metrics import utils +from pytorch_lightning.utilities import rank_zero_warn def _psnr_compute( @@ -28,7 +28,7 @@ def _psnr_compute( ) -> torch.Tensor: psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs) psnr = psnr_base_e * (10 / torch.log(torch.tensor(base))) - return utils.reduce(psnr, reduction=reduction) + return reduce(psnr, reduction=reduction) def _psnr_update(preds: torch.Tensor, @@ -97,7 +97,7 @@ def psnr( """ if dim is None and reduction != 'elementwise_mean': - utilities.rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') + rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.') if data_range is None: if dim is not None: diff --git a/pytorch_lightning/metrics/functional/r2score.py b/pytorch_lightning/metrics/functional/r2score.py index d551c03106ba0..d3f1090564a88 100644 --- a/pytorch_lightning/metrics/functional/r2score.py +++ b/pytorch_lightning/metrics/functional/r2score.py @@ -14,8 +14,8 @@ from typing import Tuple import torch +from torchmetrics.utilities.checks import _check_same_shape -from pytorch_lightning.metrics.utils import _check_same_shape from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/functional/ssim.py b/pytorch_lightning/metrics/functional/ssim.py index a9d01ea47192e..4899a3ad3be4d 100644 --- a/pytorch_lightning/metrics/functional/ssim.py +++ b/pytorch_lightning/metrics/functional/ssim.py @@ -15,8 +15,8 @@ import torch from torch.nn import functional as F - -from pytorch_lightning.metrics.utils import _check_same_shape, reduce +from torchmetrics.utilities import reduce +from torchmetrics.utilities.checks import _check_same_shape def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device): diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 3ff3039cb99b1..145a13a251250 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -11,52 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import inspect -from abc import ABC, abstractmethod -from collections.abc import Sequence -from copy import deepcopy from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import torch -from torch import nn +from torchmetrics import Metric as __Metric +from torchmetrics import MetricCollection as __MetricCollection -from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum -from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import gather_all_tensors +from pytorch_lightning.utilities.distributed import rank_zero_warn -class Metric(nn.Module, ABC): - """ - Base class for all metrics present in the Metrics API. - - Implements ``add_state()``, ``forward()``, ``reset()`` and a few other things to - handle distributed synchronization and per-step metric computation. - - Override ``update()`` and ``compute()`` functions to implement your own metric. Use - ``add_state()`` to register metric state variables which keep track of state on each - call of ``update()`` and are synchronized across processes when ``compute()`` is called. - - Note: - Metric state variables can either be ``torch.Tensors`` or an empty list which can we used - to store `torch.Tensors``. - - Note: - Different metrics only override ``update()`` and not ``forward()``. A call to ``update()`` - is valid, but it won't return the metric value at the current step. A call to ``forward()`` - automatically calls ``update()`` and also returns the metric value at the current step. +class Metric(__Metric): + r""" + This implementation refers to :class:`~torchmetrics.Metric`. - Args: - compute_on_step: - Forward only calls ``update()`` and returns None if this is set to False. default: True - dist_sync_on_step: - Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. - process_group: - Specify the process group on which synchronization is called. default: None (which selects the entire world) - dist_sync_fn: - Callback that performs the allgather operation on the metric state. When `None`, DDP - will be used to perform the allgather. default: None + .. warning:: This metric is deprecated, use ``torchmetrics.Metric``. Will be removed in v1.5.0. """ def __init__( @@ -66,559 +33,28 @@ def __init__( process_group: Optional[Any] = None, dist_sync_fn: Callable = None, ): - super().__init__() - - self.dist_sync_on_step = dist_sync_on_step - self.compute_on_step = compute_on_step - self.process_group = process_group - self.dist_sync_fn = dist_sync_fn - self._to_sync = True - - self._update_signature = inspect.signature(self.update) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - self._computed = None - self._forward_cache = None - - # initialize state - self._defaults = {} - self._persistent = {} - self._reductions = {} - - def add_state( - self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False - ): - """ - Adds metric state variable. Only used by subclasses. - - Args: - name: The name of the state variable. The variable will then be accessible at ``self.name``. - default: Default value of the state; can either be a ``torch.Tensor`` or an empty list. The state will be - reset to this value when ``self.reset()`` is called. - dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode. - If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``, - and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction - only makes sense if the state is a list, and not a tensor. The user can also pass a custom - function in this parameter. - persistent (Optional): whether the state will be saved as part of the modules ``state_dict``. - Default is ``False``. - - Note: - Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes. - However, there won't be any reduction function applied to the synchronized metric state. - - The metric states would be synced as follows - - - If the metric state is ``torch.Tensor``, the synced value will be a stacked ``torch.Tensor`` across - the process dimension if the metric state was a ``torch.Tensor``. The original ``torch.Tensor`` metric - state retains dimension and hence the synchronized output will be of shape ``(num_process, ...)``. - - - If the metric state is a ``list``, the synced value will be a ``list`` containing the - combined elements from all processes. - - Note: - When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow - the format discussed in the above note. - - Raises: - ValueError: - If ``default`` is not a ``tensor`` or an ``empty list``. - ValueError: - If ``dist_reduce_fx`` is not callable or one of ``"mean"``, ``"sum"``, ``"cat"``, ``None``. - """ - if ( - not isinstance(default, torch.Tensor) and not isinstance(default, list) # noqa: W503 - or (isinstance(default, list) and len(default) != 0) # noqa: W503 - ): - raise ValueError("state variable must be a tensor or any empty list (where you can append tensors)") - - if dist_reduce_fx == "sum": - dist_reduce_fx = dim_zero_sum - elif dist_reduce_fx == "mean": - dist_reduce_fx = dim_zero_mean - elif dist_reduce_fx == "cat": - dist_reduce_fx = dim_zero_cat - elif dist_reduce_fx is not None and not isinstance(dist_reduce_fx, Callable): - raise ValueError("`dist_reduce_fx` must be callable or one of ['mean', 'sum', 'cat', None]") - - setattr(self, name, default) - - self._defaults[name] = deepcopy(default) - self._persistent[name] = persistent - self._reductions[name] = dist_reduce_fx - - @torch.jit.unused - def forward(self, *args, **kwargs): - """ - Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True. - """ - # add current step - with torch.no_grad(): - self.update(*args, **kwargs) - self._forward_cache = None - - if self.compute_on_step: - self._to_sync = self.dist_sync_on_step - - # save context before switch - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # call reset, update, compute, on single batch - self.reset() - self.update(*args, **kwargs) - self._forward_cache = self.compute() - - # restore context - for attr, val in cache.items(): - setattr(self, attr, val) - self._to_sync = True - self._computed = None - - return self._forward_cache - - def _sync_dist(self, dist_sync_fn=gather_all_tensors): - input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} - output_dict = apply_to_collection( - input_dict, - torch.Tensor, - dist_sync_fn, - group=self.process_group, + rank_zero_warn( + "This `Metric` was deprecated since v1.3.0 in favor of `torchmetrics.Metric`." + " It will be removed in v1.5.0", DeprecationWarning + ) + super(Metric, self).__init__( + compute_on_step=compute_on_step, + dist_sync_on_step=dist_sync_on_step, + process_group=process_group, + dist_sync_fn=dist_sync_fn, ) - for attr, reduction_fn in self._reductions.items(): - # pre-processing ops (stack or flatten for inputs) - if isinstance(output_dict[attr][0], torch.Tensor): - output_dict[attr] = torch.stack(output_dict[attr]) - elif isinstance(output_dict[attr][0], list): - output_dict[attr] = _flatten(output_dict[attr]) - - assert isinstance(reduction_fn, (Callable)) or reduction_fn is None - reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr] - setattr(self, attr, reduced) - - def _wrap_update(self, update): - - @functools.wraps(update) - def wrapped_func(*args, **kwargs): - self._computed = None - return update(*args, **kwargs) - - return wrapped_func - - def _wrap_compute(self, compute): - - @functools.wraps(compute) - def wrapped_func(*args, **kwargs): - # return cached value - if self._computed is not None: - return self._computed - - dist_sync_fn = self.dist_sync_fn - if dist_sync_fn is None and torch.distributed.is_available() and torch.distributed.is_initialized(): - # User provided a bool, so we assume DDP if available - dist_sync_fn = gather_all_tensors - - synced = False - if self._to_sync and dist_sync_fn is not None: - # cache prior to syncing - cache = {attr: getattr(self, attr) for attr in self._defaults.keys()} - - # sync - self._sync_dist(dist_sync_fn) - synced = True - - self._computed = compute(*args, **kwargs) - if synced: - # if we synced, restore to cache so that we can continue to accumulate un-synced state - for attr, val in cache.items(): - setattr(self, attr, val) - - return self._computed - - return wrapped_func - - @abstractmethod - def update(self) -> None: # pylint: disable=E0202 - """ - Override this method to update the state variables of your metric class. - """ - pass - - @abstractmethod - def compute(self): # pylint: disable=E0202 - """ - Override this method to compute the final metric value from state variables - synchronized across the distributed backend. - """ - pass - - def reset(self): - """ - This method automatically resets the metric state variables to their default value. - """ - for attr, default in self._defaults.items(): - current_val = getattr(self, attr) - if isinstance(default, torch.Tensor): - setattr(self, attr, deepcopy(default).to(current_val.device)) - else: - setattr(self, attr, deepcopy(default)) - - def clone(self): - """ Make a copy of the metric """ - return deepcopy(self) - - def __getstate__(self): - # ignore update and compute functions for pickling - return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]} - - def __setstate__(self, state): - # manually restore update and compute functions for pickling - self.__dict__.update(state) - self.update = self._wrap_update(self.update) - self.compute = self._wrap_compute(self.compute) - - def _apply(self, fn): - """Overwrite _apply function such that we can also move metric states - to the correct device when `.to`, `.cuda`, etc methods are called - """ - self = super()._apply(fn) - # Also apply fn to metric states - for key in self._defaults.keys(): - current_val = getattr(self, key) - if isinstance(current_val, torch.Tensor): - setattr(self, key, fn(current_val)) - elif isinstance(current_val, Sequence): - setattr(self, key, [fn(cur_v) for cur_v in current_val]) - else: - raise TypeError( - "Expected metric state to be either a torch.Tensor" - f"or a list of torch.Tensor, but encountered {current_val}" - ) - return self - - def persistent(self, mode: bool = False): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for key in self._persistent.keys(): - self._persistent[key] = mode - - def state_dict(self, destination=None, prefix='', keep_vars=False): - destination = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - # Register metric states to be part of the state_dict - for key in self._defaults.keys(): - if self._persistent[key]: - current_val = getattr(self, key) - if not keep_vars: - if torch.is_tensor(current_val): - current_val = current_val.detach() - elif isinstance(current_val, list): - current_val = [cur_v.detach() if torch.is_tensor(cur_v) else cur_v for cur_v in current_val] - destination[prefix + key] = current_val - return destination - - def _filter_kwargs(self, **kwargs): - """ filter kwargs such that they match the update signature of the metric """ - - # filter all parameters based on update signature except those of - # type VAR_POSITIONAL (*args) and VAR_KEYWORD (**kwargs) - _params = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) - filtered_kwargs = { - k: v - for k, v in kwargs.items() if k in self._update_signature.parameters.keys() - and self._update_signature.parameters[k].kind not in _params - } - - # if no kwargs filtered, return al kwargs as default - if not filtered_kwargs: - filtered_kwargs = kwargs - return filtered_kwargs - - def __hash__(self): - hash_vals = [self.__class__.__name__] - - for key in self._defaults.keys(): - val = getattr(self, key) - # Special case: allow list values, so long - # as their elements are hashable - if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): - hash_vals.extend(val) - else: - hash_vals.append(val) - - return hash(tuple(hash_vals)) - - def __add__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.add, self, other) - - def __and__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_and, self, other) - - def __eq__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.eq, self, other) - - def __floordiv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.floor_divide, self, other) - - def __ge__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.ge, self, other) - - def __gt__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.gt, self, other) - - def __le__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.le, self, other) - - def __lt__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.lt, self, other) - - def __matmul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.matmul, self, other) - - def __mod__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.fmod, self, other) - - def __mul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.mul, self, other) - - def __ne__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.ne, self, other) - - def __or__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_or, self, other) - - def __pow__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.pow, self, other) - - def __radd__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.add, other, self) - - def __rand__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - # swap them since bitwise_and only supports that way and it's commutative - return CompositionalMetric(torch.bitwise_and, self, other) - - def __rfloordiv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.floor_divide, other, self) - - def __rmatmul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.matmul, other, self) - - def __rmod__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.fmod, other, self) - - def __rmul__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.mul, other, self) - - def __ror__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_or, other, self) - - def __rpow__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.pow, other, self) - - def __rsub__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.sub, other, self) - - def __rtruediv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.true_divide, other, self) - - def __rxor__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_xor, other, self) - - def __sub__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.sub, self, other) - - def __truediv__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.true_divide, self, other) - - def __xor__(self, other: Any): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_xor, self, other) - - def __abs__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.abs, self, None) - - def __inv__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.bitwise_not, self, None) - - def __invert__(self): - return self.__inv__() - - def __neg__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(_neg, self, None) - - def __pos__(self): - from pytorch_lightning.metrics.compositional import CompositionalMetric - - return CompositionalMetric(torch.abs, self, None) - - -def _neg(tensor: torch.Tensor): - return -torch.abs(tensor) - - -class MetricCollection(nn.ModuleDict): - """ - MetricCollection class can be used to chain metrics that have the same - call pattern into one single class. - - Args: - metrics: One of the following - - * list or tuple: if metrics are passed in as a list, will use the - metrics class name as key for output dict. Therefore, two metrics - of the same class cannot be chained this way. - - * dict: if metrics are passed in as a dict, will use each key in the - dict as key for output dict. Use this format if you want to chain - together multiple of the same metric with different parameters. - - Raises: - ValueError: - If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``. - ValueError: - If two elements in ``metrics`` have the same ``name``. - ValueError: - If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``. - - Example (input as list): - - >>> from pytorch_lightning.metrics import MetricCollection, Accuracy, Precision, Recall - >>> target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2]) - >>> preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2]) - >>> metrics = MetricCollection([Accuracy(), - ... Precision(num_classes=3, average='macro'), - ... Recall(num_classes=3, average='macro')]) - >>> metrics(preds, target) - {'Accuracy': tensor(0.1250), 'Precision': tensor(0.0667), 'Recall': tensor(0.1111)} - - Example (input as dict): - >>> metrics = MetricCollection({'micro_recall': Recall(num_classes=3, average='micro'), - ... 'macro_recall': Recall(num_classes=3, average='macro')}) - >>> metrics(preds, target) - {'micro_recall': tensor(0.1250), 'macro_recall': tensor(0.1111)} +class MetricCollection(__MetricCollection): + r""" + This implementation refers to :class:`~torchmetrics.MetricCollection`. + .. warning:: This metric is deprecated, use ``torchmetrics.MetricCollection``. Will be removed in v1.5.0. """ def __init__(self, metrics: Union[List[Metric], Tuple[Metric], Dict[str, Metric]]): - super().__init__() - if isinstance(metrics, dict): - # Check all values are metrics - for name, metric in metrics.items(): - if not isinstance(metric, Metric): - raise ValueError( - f"Value {metric} belonging to key {name}" - " is not an instance of `pl.metrics.Metric`" - ) - self[name] = metric - elif isinstance(metrics, (tuple, list)): - for metric in metrics: - if not isinstance(metric, Metric): - raise ValueError( - f"Input {metric} to `MetricCollection` is not a instance" - " of `pl.metrics.Metric`" - ) - name = metric.__class__.__name__ - if name in self: - raise ValueError(f"Encountered two metrics both named {name}") - self[name] = metric - else: - raise ValueError("Unknown input to MetricCollection.") - - def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 - """ - Iteratively call forward for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} - - def update(self, *args, **kwargs): # pylint: disable=E0202 - """ - Iteratively call update for each metric. Positional arguments (args) will - be passed to every metric in the collection, while keyword arguments (kwargs) - will be filtered based on the signature of the individual metric. - """ - for _, m in self.items(): - m_kwargs = m._filter_kwargs(**kwargs) - m.update(*args, **m_kwargs) - - def compute(self) -> Dict[str, Any]: - return {k: m.compute() for k, m in self.items()} - - def reset(self): - """ Iteratively call reset for each metric """ - for _, m in self.items(): - m.reset() - - def clone(self): - """ Make a copy of the metric collection """ - return deepcopy(self) - - def persistent(self, mode: bool = True): - """Method for post-init to change if metric states should be saved to - its state_dict - """ - for _, m in self.items(): - m.persistent(mode) + rank_zero_warn( + "This `MetricCollection` was deprecated since v1.3.0 in favor of `torchmetrics.MetricCollection`." + " It will be removed in v1.5.0", DeprecationWarning + ) + super().__init__(metrics=metrics) diff --git a/pytorch_lightning/metrics/utils.py b/pytorch_lightning/metrics/utils.py index f93ad040f1d99..63c6892cb2987 100644 --- a/pytorch_lightning/metrics/utils.py +++ b/pytorch_lightning/metrics/utils.py @@ -11,86 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple +from typing import List, Optional import torch +from torchmetrics.utilities.data import dim_zero_cat as __dim_zero_cat +from torchmetrics.utilities.data import dim_zero_mean as __dim_zero_mean +from torchmetrics.utilities.data import dim_zero_sum as __dim_zero_sum +from torchmetrics.utilities.data import get_num_classes as __get_num_classes +from torchmetrics.utilities.data import select_topk as __select_topk +from torchmetrics.utilities.data import to_categorical as __to_categorical +from torchmetrics.utilities.data import to_onehot as __to_onehot +from torchmetrics.utilities.distributed import class_reduce as __class_reduce +from torchmetrics.utilities.distributed import reduce as __reduce from pytorch_lightning.utilities import rank_zero_warn -METRIC_EPS = 1e-6 - def dim_zero_cat(x): - x = x if isinstance(x, (list, tuple)) else [x] - return torch.cat(x, dim=0) + rank_zero_warn( + "This `dim_zero_cat` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning + ) + return __dim_zero_cat(x) def dim_zero_sum(x): - return torch.sum(x, dim=0) + rank_zero_warn( + "This `dim_zero_sum` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning + ) + return __dim_zero_sum(x) def dim_zero_mean(x): - return torch.mean(x, dim=0) - - -def _flatten(x): - return [item for sublist in x for item in sublist] - - -def _check_same_shape(pred: torch.Tensor, target: torch.Tensor): - """ Check that predictions and target have the same shape, else raise error """ - if pred.shape != target.shape: - raise RuntimeError("Predictions and targets are expected to have the same shape") - - -def _input_format_classification_one_hot( - num_classes: int, - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - multilabel: bool = False -) -> Tuple[torch.Tensor, torch.Tensor]: - """Convert preds and target tensors into one hot spare label tensors - - Args: - num_classes: number of classes - preds: either tensor with labels, tensor with probabilities/logits or - multilabel tensor - target: tensor with ground true labels - threshold: float used for thresholding multilabel input - multilabel: boolean flag indicating if input is multilabel - - Raises: - ValueError: - If ``preds`` and ``target`` don't have the same number of dimensions - or one additional dimension for ``preds``. - - Returns: - preds: one hot tensor of shape [num_classes, -1] with predicted labels - target: one hot tensors of shape [num_classes, -1] with true labels - """ - if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1): - raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds") - - if preds.ndim == target.ndim + 1: - # multi class probabilites - preds = torch.argmax(preds, dim=1) - - if preds.ndim == target.ndim and preds.dtype in (torch.long, torch.int) and num_classes > 1 and not multilabel: - # multi-class - preds = to_onehot(preds, num_classes=num_classes) - target = to_onehot(target, num_classes=num_classes) - - elif preds.ndim == target.ndim and preds.is_floating_point(): - # binary or multilabel probablities - preds = (preds >= threshold).long() - - # transpose class as first dim and reshape - if preds.ndim > 1: - preds = preds.transpose(1, 0) - target = target.transpose(1, 0) - - return preds.reshape(num_classes, -1), target.reshape(num_classes, -1) + rank_zero_warn( + "This `dim_zero_mean` was deprecated since v1.3.0 and it will be removed in v1.5.0", DeprecationWarning + ) + return __dim_zero_mean(x) def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: @@ -122,211 +77,79 @@ def get_group_indexes(idx: torch.Tensor) -> List[torch.Tensor]: return [torch.tensor(x, dtype=torch.int64) for x in indexes.values()] -def to_onehot( - label_tensor: torch.Tensor, - num_classes: Optional[int] = None, -) -> torch.Tensor: - """ - Converts a dense label tensor to one-hot format - - Args: - label_tensor: dense label tensor, with shape [N, d1, d2, ...] - num_classes: number of classes C - - Returns: - A sparse label tensor with shape [N, C, d1, d2, ...] - - Example: - - >>> from pytorch_lightning.metrics.utils import to_onehot - >>> x = torch.tensor([1, 2, 3]) - >>> to_onehot(x) - tensor([[0, 1, 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1]]) +def to_onehot(label_tensor: torch.Tensor, num_classes: Optional[int] = None) -> torch.Tensor: + r""" + .. warning:: This function is deprecated, use ``torchmetrics.utilities.data.to_onehot``. Will be removed in v1.5.0. """ - if num_classes is None: - num_classes = int(label_tensor.max().detach().item() + 1) - - tensor_onehot = torch.zeros( - label_tensor.shape[0], - num_classes, - *label_tensor.shape[1:], - dtype=label_tensor.dtype, - device=label_tensor.device, + rank_zero_warn( + "This `to_onehot` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_onehot`." + " It will be removed in v1.5.0", DeprecationWarning ) - index = label_tensor.long().unsqueeze(1).expand_as(tensor_onehot) - return tensor_onehot.scatter_(1, index, 1.0) + return __to_onehot(label_tensor=label_tensor, num_classes=num_classes) def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch.Tensor: - """ - Convert a probability tensor to binary by selecting top-k highest entries. - - Args: - prob_tensor: dense tensor of shape ``[..., C, ...]``, where ``C`` is in the - position defined by the ``dim`` argument - topk: number of highest entries to turn into 1s - dim: dimension on which to compare entries + r""" + .. warning:: - Returns: - A binary tensor of the same shape as the input tensor of type torch.int32 - - Example: - - >>> from pytorch_lightning.metrics.utils import select_topk - >>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) - >>> select_topk(x, topk=2) - tensor([[0, 1, 1], - [1, 1, 0]], dtype=torch.int32) + This function is deprecated, use ``torchmetrics.utilities.data.select_topk``. Will be removed in v1.5.0. """ - zeros = torch.zeros_like(prob_tensor) - topk_tensor = zeros.scatter(dim, prob_tensor.topk(k=topk, dim=dim).indices, 1.0) - return topk_tensor.int() + rank_zero_warn( + "This `select_topk` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.select_topk`." + " It will be removed in v1.5.0", DeprecationWarning + ) + return __select_topk(prob_tensor=prob_tensor, topk=topk, dim=dim) def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor: - """ - Converts a tensor of probabilities to a dense label tensor - - Args: - tensor: probabilities to get the categorical label [N, d1, d2, ...] - argmax_dim: dimension to apply + r""" + .. warning:: - Return: - A tensor with categorical labels [N, d2, ...] - - Example: - - >>> from pytorch_lightning.metrics.utils import to_categorical - >>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) - >>> to_categorical(x) - tensor([1, 0]) + This function is deprecated, use ``torchmetrics.utilities.data.to_categorical``. Will be removed in v1.5.0. """ - return torch.argmax(tensor, dim=argmax_dim) - + rank_zero_warn( + "This `to_categorical` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.to_categorical`." + " It will be removed in v1.5.0", DeprecationWarning + ) + return __to_categorical(tensor=tensor, argmax_dim=argmax_dim) -def get_num_classes( - pred: torch.Tensor, - target: torch.Tensor, - num_classes: Optional[int] = None, -) -> int: - """ - Calculates the number of classes for a given prediction and target tensor. - Args: - pred: predicted values - target: true labels - num_classes: number of classes if known +def get_num_classes(pred: torch.Tensor, target: torch.Tensor, num_classes: Optional[int] = None) -> int: + r""" + .. warning:: - Return: - An integer that represents the number of classes. + This function is deprecated, use ``torchmetrics.utilities.data.get_num_classes``. Will be removed in v1.5.0. """ - num_target_classes = int(target.max().detach().item() + 1) - num_pred_classes = int(pred.max().detach().item() + 1) - num_all_classes = max(num_target_classes, num_pred_classes) - - if num_classes is None: - num_classes = num_all_classes - elif num_classes != num_all_classes: - rank_zero_warn( - f"You have set {num_classes} number of classes which is" - f" different from predicted ({num_pred_classes}) and" - f" target ({num_target_classes}) number of classes", - RuntimeWarning, - ) - return num_classes + rank_zero_warn( + "This `get_num_classes` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.data.get_num_classes`." + " It will be removed in v1.5.0", DeprecationWarning + ) + return __get_num_classes(pred=pred, target=target, num_classes=num_classes) def reduce(to_reduce: torch.Tensor, reduction: str) -> torch.Tensor: - """ - Reduces a given tensor by a given reduction method - - Args: - to_reduce : the tensor, which shall be reduced - reduction : a string specifying the reduction method ('elementwise_mean', 'none', 'sum') - - Return: - reduced Tensor + r""" + .. warning:: - Raise: - ValueError if an invalid reduction parameter was given + This function is deprecated, use ``torchmetrics.utilities.reduce``. Will be removed in v1.5.0. """ - if reduction == "elementwise_mean": - return torch.mean(to_reduce) - if reduction == "none": - return to_reduce - if reduction == "sum": - return torch.sum(to_reduce) - raise ValueError("Reduction parameter unknown.") + rank_zero_warn( + "This `reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.reduce`." + " It will be removed in v1.5.0", DeprecationWarning + ) + return __reduce(to_reduce=to_reduce, reduction=reduction) def class_reduce( num: torch.Tensor, denom: torch.Tensor, weights: torch.Tensor, class_reduction: str = "none" ) -> torch.Tensor: - """ - Function used to reduce classification metrics of the form `num / denom * weights`. - For example for calculating standard accuracy the num would be number of - true positives per class, denom would be the support per class, and weights - would be a tensor of 1s + r""" + .. warning:: - Args: - num: numerator tensor - denom: denominator tensor - weights: weights for each class - class_reduction: reduction method for multiclass problems - - - ``'micro'``: calculate metrics globally (default) - - ``'macro'``: calculate metrics for each label, and find their unweighted mean. - - ``'weighted'``: calculate metrics for each label, and find their weighted mean. - - ``'none'`` or ``None``: returns calculated metric per class - - Raises: - ValueError: - If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``. + This function is deprecated, use ``torchmetrics.utilities.class_reduce``. Will be removed in v1.5.0. """ - valid_reduction = ("micro", "macro", "weighted", "none", None) - if class_reduction == "micro": - fraction = torch.sum(num) / torch.sum(denom) - else: - fraction = num / denom - - # We need to take care of instances where the denom can be 0 - # for some (or all) classes which will produce nans - fraction[fraction != fraction] = 0 - - if class_reduction == "micro": - return fraction - elif class_reduction == "macro": - return torch.mean(fraction) - elif class_reduction == "weighted": - return torch.sum(fraction * (weights.float() / torch.sum(weights))) - elif class_reduction == "none" or class_reduction is None: - return fraction - - raise ValueError( - f"Reduction parameter {class_reduction} unknown." - f" Choose between one of these: {valid_reduction}" + rank_zero_warn( + "This `class_reduce` was deprecated since v1.3.0 in favor of `torchmetrics.utilities.class_reduce`." + " It will be removed in v1.5.0", DeprecationWarning ) - - -def _stable_1d_sort(x: torch, N: int = 2049): - """ - Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm - if number of elements are larger than 2048. This function pads the tensors, - makes the sort and returns the sorted array (with the padding removed) - See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714 - - Raises: - ValueError: - If dim of ``x`` is greater than 1 since stable sort works with only 1d tensors. - """ - if x.ndim > 1: - raise ValueError('Stable sort only works on 1d tensors') - n = x.numel() - if N - n > 0: - x_max = x.max() - x = torch.cat([x, (x_max + 1) * torch.ones(N - n, dtype=x.dtype, device=x.device)], 0) - x_sort = x.sort() - i = min(N, n) - return x_sort.values[:i], x_sort.indices[:i] + return __class_reduce(num=num, denom=denom, weights=weights, class_reduction=class_reduction) diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index faf528d76b768..3e0f57daef001 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -13,7 +13,7 @@ # limitations under the License. import os from contextlib import suppress -from typing import List, Optional, Callable +from typing import Callable, List, Optional import torch diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 336c16f0f1a03..ba26fc9f58ec5 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,7 +13,7 @@ # limitations under the License import logging import os -from typing import List, Optional, Callable +from typing import Callable, List, Optional import torch import torch.distributed as torch_distrib diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5aa9f1a44276b..b44ba870d96d5 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -15,7 +15,7 @@ from abc import ABC from copy import deepcopy from inspect import signature -from typing import Any, Callable, Dict, List, Type, Optional +from typing import Any, Callable, Dict, List, Optional, Type from pytorch_lightning.callbacks import Callback from pytorch_lightning.core.lightning import LightningModule diff --git a/tests/deprecated_api/test_remove_1-5_metrics.py b/tests/deprecated_api/test_remove_1-5_metrics.py new file mode 100644 index 0000000000000..b2fa4f69f74b9 --- /dev/null +++ b/tests/deprecated_api/test_remove_1-5_metrics.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v1.5.0""" + +import pytest +import torch + +from pytorch_lightning.metrics.utils import get_num_classes, select_topk, to_categorical, to_onehot + + +def test_v1_5_0_metrics_utils(): + x = torch.tensor([1, 2, 3]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(to_onehot(x), torch.Tensor([[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]).to(int)) + + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert get_num_classes(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0])) == 4 + + x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(select_topk(x, topk=2), torch.Tensor([[0, 1, 1], [1, 1, 0]]).to(torch.int32)) + + x = torch.tensor([[0.2, 0.5], [0.9, 0.1]]) + with pytest.deprecated_call(match="It will be removed in v1.5.0"): + assert torch.equal(to_categorical(x), torch.Tensor([1, 0]).to(int)) diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index a78d799b1a07d..2b7be8caa7a0d 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -1,9 +1,9 @@ import pytest import torch from torch import rand, randint +from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType -from pytorch_lightning.metrics.utils import select_topk, to_onehot from tests.metrics.classification.inputs import _input_binary as _bin from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob from tests.metrics.classification.inputs import _input_multiclass as _mc diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 39622c4cd3550..bca50867dcb44 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -1,10 +1,10 @@ import pytest import torch +from torchmetrics.utilities.data import get_num_classes, to_categorical, to_onehot from pytorch_lightning import seed_everything from pytorch_lightning.metrics.functional.classification import dice_score from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve -from pytorch_lightning.metrics.utils import get_num_classes, to_categorical, to_onehot def test_onehot(): diff --git a/tests/metrics/functional/test_reduction.py b/tests/metrics/functional/test_reduction.py index 03a34f6c5a25b..9949c8086a44a 100644 --- a/tests/metrics/functional/test_reduction.py +++ b/tests/metrics/functional/test_reduction.py @@ -1,7 +1,6 @@ import pytest import torch - -from pytorch_lightning.metrics.utils import class_reduce, reduce +from torchmetrics.utilities import class_reduce, reduce def test_reduce(): diff --git a/tests/metrics/test_composition.py b/tests/metrics/test_composition.py deleted file mode 100644 index 7845e86f514ff..0000000000000 --- a/tests/metrics/test_composition.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from operator import neg, pos - -import pytest -import torch - -from pytorch_lightning.metrics.compositional import CompositionalMetric -from pytorch_lightning.metrics.metric import Metric -from tests.helpers.runif import RunIf - - -class DummyMetric(Metric): - - def __init__(self, val_to_return): - super().__init__() - self._num_updates = 0 - self._val_to_return = val_to_return - - def update(self, *args, **kwargs) -> None: - self._num_updates += 1 - - def compute(self): - return torch.tensor(self._val_to_return) - - def reset(self): - self._num_updates = 0 - return super().reset() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, torch.tensor(4.0)), - (torch.tensor(2), torch.tensor(4)), - ], -) -def test_metrics_add(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_add = first_metric + second_operand - final_radd = second_operand + first_metric - - assert isinstance(final_add, CompositionalMetric) - assert isinstance(final_radd, CompositionalMetric) - - assert torch.allclose(expected_result, final_add.compute()) - assert torch.allclose(expected_result, final_radd.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric(3), torch.tensor(2)), (3, torch.tensor(2)), (3, torch.tensor(2)), (torch.tensor(3), torch.tensor(2))], -) -@RunIf(min_torch="1.5.0") -def test_metrics_and(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_and = first_metric & second_operand - final_rand = second_operand & first_metric - - assert isinstance(final_and, CompositionalMetric) - assert isinstance(final_rand, CompositionalMetric) - - assert torch.allclose(expected_result, final_and.compute()) - assert torch.allclose(expected_result, final_rand.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), - ], -) -def test_metrics_eq(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_eq = first_metric == second_operand - - assert isinstance(final_eq, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_eq.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(2)), - (2, torch.tensor(2)), - (2.0, torch.tensor(2.0)), - (torch.tensor(2), torch.tensor(2)), - ], -) -@RunIf(min_torch="1.5.0") -def test_metrics_floordiv(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_floordiv = first_metric // second_operand - - assert isinstance(final_floordiv, CompositionalMetric) - - assert torch.allclose(expected_result, final_floordiv.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), - ], -) -def test_metrics_ge(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_ge = first_metric >= second_operand - - assert isinstance(final_ge, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_ge.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(True)), - (2, torch.tensor(True)), - (2.0, torch.tensor(True)), - (torch.tensor(2), torch.tensor(True)), - ], -) -def test_metrics_gt(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_gt = first_metric > second_operand - - assert isinstance(final_gt, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_gt.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), - ], -) -def test_metrics_le(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_le = first_metric <= second_operand - - assert isinstance(final_le, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_le.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), - ], -) -def test_metrics_lt(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_lt = first_metric < second_operand - - assert isinstance(final_lt, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_lt.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric([2, 2, 2]), torch.tensor(12)), (torch.tensor([2, 2, 2]), torch.tensor(12))], -) -def test_metrics_matmul(second_operand, expected_result): - first_metric = DummyMetric([2, 2, 2]) - - final_matmul = first_metric @ second_operand - - assert isinstance(final_matmul, CompositionalMetric) - - assert torch.allclose(expected_result, final_matmul.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(1)), - (2, torch.tensor(1)), - (2.0, torch.tensor(1)), - (torch.tensor(2), torch.tensor(1)), - ], -) -def test_metrics_mod(second_operand, expected_result): - first_metric = DummyMetric(5) - - final_mod = first_metric % second_operand - - assert isinstance(final_mod, CompositionalMetric) - # prevent Runtime error for PT 1.8 - Long did not match Float - assert torch.allclose(expected_result.to(float), final_mod.compute().to(float)) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(4)), - (2, torch.tensor(4)), - (2.0, torch.tensor(4.0)), - (torch.tensor(2), torch.tensor(4)), - ], -) -def test_metrics_mul(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_mul = first_metric * second_operand - final_rmul = second_operand * first_metric - - assert isinstance(final_mul, CompositionalMetric) - assert isinstance(final_rmul, CompositionalMetric) - - assert torch.allclose(expected_result, final_mul.compute()) - assert torch.allclose(expected_result, final_rmul.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(False)), - (2, torch.tensor(False)), - (2.0, torch.tensor(False)), - (torch.tensor(2), torch.tensor(False)), - ], -) -def test_metrics_ne(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_ne = first_metric != second_operand - - assert isinstance(final_ne, CompositionalMetric) - - # can't use allclose for bool tensors - assert (expected_result == final_ne.compute()).all() - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric([1, 0, 3]), torch.tensor([-1, -2, 3])), (torch.tensor([1, 0, 3]), torch.tensor([-1, -2, 3]))], -) -@RunIf(min_torch="1.5.0") -def test_metrics_or(second_operand, expected_result): - first_metric = DummyMetric([-1, -2, 3]) - - final_or = first_metric | second_operand - final_ror = second_operand | first_metric - - assert isinstance(final_or, CompositionalMetric) - assert isinstance(final_ror, CompositionalMetric) - - assert torch.allclose(expected_result, final_or.compute()) - assert torch.allclose(expected_result, final_ror.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - pytest.param(DummyMetric(2), torch.tensor(4)), - pytest.param(2, torch.tensor(4)), - pytest.param(2.0, torch.tensor(4.0), marks=RunIf(min_torch="1.6.0")), - pytest.param(torch.tensor(2), torch.tensor(4)), - ], -) -def test_metrics_pow(second_operand, expected_result): - first_metric = DummyMetric(2) - - final_pow = first_metric**second_operand - - assert isinstance(final_pow, CompositionalMetric) - - assert torch.allclose(expected_result, final_pow.compute()) - - -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [(5, torch.tensor(2)), (5.0, torch.tensor(2.0)), (torch.tensor(5), torch.tensor(2))], -) -@RunIf(min_torch="1.5.0") -def test_metrics_rfloordiv(first_operand, expected_result): - second_operand = DummyMetric(2) - - final_rfloordiv = first_operand // second_operand - - assert isinstance(final_rfloordiv, CompositionalMetric) - assert torch.allclose(expected_result, final_rfloordiv.compute()) - - -@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor([2, 2, 2]), torch.tensor(12))]) -def test_metrics_rmatmul(first_operand, expected_result): - second_operand = DummyMetric([2, 2, 2]) - - final_rmatmul = first_operand @ second_operand - - assert isinstance(final_rmatmul, CompositionalMetric) - - assert torch.allclose(expected_result, final_rmatmul.compute()) - - -@pytest.mark.parametrize(["first_operand", "expected_result"], [(torch.tensor(2), torch.tensor(2))]) -def test_metrics_rmod(first_operand, expected_result): - second_operand = DummyMetric(5) - - final_rmod = first_operand % second_operand - - assert isinstance(final_rmod, CompositionalMetric) - - assert torch.allclose(expected_result, final_rmod.compute()) - - -@pytest.mark.parametrize( - "first_operand,expected_result", - [ - pytest.param(DummyMetric(2), torch.tensor(4)), - pytest.param(2, torch.tensor(4)), - pytest.param(2.0, torch.tensor(4.0), marks=RunIf(min_torch="1.6.0")), - ], -) -def test_metrics_rpow(first_operand, expected_result): - second_operand = DummyMetric(2) - - final_rpow = first_operand**second_operand - - assert isinstance(final_rpow, CompositionalMetric) - - assert torch.allclose(expected_result, final_rpow.compute()) - - -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [ - (DummyMetric(3), torch.tensor(1)), - (3, torch.tensor(1)), - (3.0, torch.tensor(1.0)), - (torch.tensor(3), torch.tensor(1)), - ], -) -def test_metrics_rsub(first_operand, expected_result): - second_operand = DummyMetric(2) - - final_rsub = first_operand - second_operand - - assert isinstance(final_rsub, CompositionalMetric) - - assert torch.allclose(expected_result, final_rsub.compute()) - - -@pytest.mark.parametrize( - ["first_operand", "expected_result"], - [ - (DummyMetric(6), torch.tensor(2.0)), - (6, torch.tensor(2.0)), - (6.0, torch.tensor(2.0)), - (torch.tensor(6), torch.tensor(2.0)), - ], -) -@RunIf(min_torch="1.5.0") -def test_metrics_rtruediv(first_operand, expected_result): - second_operand = DummyMetric(3) - - final_rtruediv = first_operand / second_operand - - assert isinstance(final_rtruediv, CompositionalMetric) - - assert torch.allclose(expected_result, final_rtruediv.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(2), torch.tensor(1)), - (2, torch.tensor(1)), - (2.0, torch.tensor(1.0)), - (torch.tensor(2), torch.tensor(1)), - ], -) -def test_metrics_sub(second_operand, expected_result): - first_metric = DummyMetric(3) - - final_sub = first_metric - second_operand - - assert isinstance(final_sub, CompositionalMetric) - - assert torch.allclose(expected_result, final_sub.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [ - (DummyMetric(3), torch.tensor(2.0)), - (3, torch.tensor(2.0)), - (3.0, torch.tensor(2.0)), - (torch.tensor(3), torch.tensor(2.0)), - ], -) -@RunIf(min_torch="1.5.0") -def test_metrics_truediv(second_operand, expected_result): - first_metric = DummyMetric(6) - - final_truediv = first_metric / second_operand - - assert isinstance(final_truediv, CompositionalMetric) - - assert torch.allclose(expected_result, final_truediv.compute()) - - -@pytest.mark.parametrize( - ["second_operand", "expected_result"], - [(DummyMetric([1, 0, 3]), torch.tensor([-2, -2, 0])), (torch.tensor([1, 0, 3]), torch.tensor([-2, -2, 0]))], -) -def test_metrics_xor(second_operand, expected_result): - first_metric = DummyMetric([-1, -2, 3]) - - final_xor = first_metric ^ second_operand - final_rxor = second_operand ^ first_metric - - assert isinstance(final_xor, CompositionalMetric) - assert isinstance(final_rxor, CompositionalMetric) - - assert torch.allclose(expected_result, final_xor.compute()) - assert torch.allclose(expected_result, final_rxor.compute()) - - -def test_metrics_abs(): - first_metric = DummyMetric(-1) - - final_abs = abs(first_metric) - - assert isinstance(final_abs, CompositionalMetric) - - assert torch.allclose(torch.tensor(1), final_abs.compute()) - - -def test_metrics_invert(): - first_metric = DummyMetric(1) - - final_inverse = ~first_metric - assert isinstance(final_inverse, CompositionalMetric) - assert torch.allclose(torch.tensor(-2), final_inverse.compute()) - - -def test_metrics_neg(): - first_metric = DummyMetric(1) - - final_neg = neg(first_metric) - assert isinstance(final_neg, CompositionalMetric) - assert torch.allclose(torch.tensor(-1), final_neg.compute()) - - -def test_metrics_pos(): - first_metric = DummyMetric(-1) - - final_pos = pos(first_metric) - assert isinstance(final_pos, CompositionalMetric) - assert torch.allclose(torch.tensor(1), final_pos.compute()) - - -def test_compositional_metrics_update(): - - compos = DummyMetric(5) + DummyMetric(4) - - assert isinstance(compos, CompositionalMetric) - compos.update() - compos.update() - compos.update() - - assert isinstance(compos.metric_a, DummyMetric) - assert isinstance(compos.metric_b, DummyMetric) - - assert compos.metric_a._num_updates == 3 - assert compos.metric_b._num_updates == 3 diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py deleted file mode 100644 index 5120cce0a0425..0000000000000 --- a/tests/metrics/test_ddp.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -import torch - -from pytorch_lightning.metrics import Metric -from tests.helpers.runif import RunIf -from tests.metrics.test_metric import Dummy -from tests.metrics.utils import setup_ddp - -torch.manual_seed(42) - - -def _test_ddp_sum(rank, worldsize): - setup_ddp(rank, worldsize) - dummy = Dummy() - dummy._reductions = {"foo": torch.sum} - dummy.foo = torch.tensor(1) - - dummy._sync_dist() - assert dummy.foo == worldsize - - -def _test_ddp_cat(rank, worldsize): - setup_ddp(rank, worldsize) - dummy = Dummy() - dummy._reductions = {"foo": torch.cat} - dummy.foo = [torch.tensor([1])] - dummy._sync_dist() - assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) - - -def _test_ddp_sum_cat(rank, worldsize): - setup_ddp(rank, worldsize) - dummy = Dummy() - dummy._reductions = {"foo": torch.cat, "bar": torch.sum} - dummy.foo = [torch.tensor([1])] - dummy.bar = torch.tensor(1) - dummy._sync_dist() - assert torch.all(torch.eq(dummy.foo, torch.tensor([1, 1]))) - assert dummy.bar == worldsize - - -@RunIf(skip_windows=True) -@pytest.mark.parametrize("process", [_test_ddp_cat, _test_ddp_sum, _test_ddp_sum_cat]) -def test_ddp(process): - torch.multiprocessing.spawn(process, args=(2, ), nprocs=2) - - -def _test_non_contiguous_tensors(rank, worldsize): - setup_ddp(rank, worldsize) - - class DummyMetric(Metric): - - def __init__(self): - super().__init__() - self.add_state("x", default=[], dist_reduce_fx=None) - - def update(self, x): - self.x.append(x) - - def compute(self): - x = torch.cat(self.x, dim=0) - return x.sum() - - metric = DummyMetric() - metric.update(torch.randn(10, 5)[:, 0]) - - -@RunIf(skip_windows=True) -def test_non_contiguous_tensors(): - """ Test that gather_all operation works for non contiguous tensors """ - torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2, ), nprocs=2) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py deleted file mode 100644 index ad7b4566dc012..0000000000000 --- a/tests/metrics/test_metric.py +++ /dev/null @@ -1,395 +0,0 @@ -import pickle -from collections import OrderedDict -from distutils.version import LooseVersion - -import cloudpickle -import numpy as np -import pytest -import torch -from torch import nn - -from pytorch_lightning.metrics.metric import Metric, MetricCollection -from tests.helpers.runif import RunIf - -torch.manual_seed(42) - - -class Dummy(Metric): - name = "Dummy" - - def __init__(self): - super().__init__() - self.add_state("x", torch.tensor(0.0), dist_reduce_fx=None) - - def update(self): - pass - - def compute(self): - pass - - -class DummyList(Metric): - name = "DummyList" - - def __init__(self): - super().__init__() - self.add_state("x", list(), dist_reduce_fx=None) - - def update(self): - pass - - def compute(self): - pass - - -def test_inherit(): - Dummy() - - -def test_add_state(): - a = Dummy() - - a.add_state("a", torch.tensor(0), "sum") - assert a._reductions["a"](torch.tensor([1, 1])) == 2 - - a.add_state("b", torch.tensor(0), "mean") - assert np.allclose(a._reductions["b"](torch.tensor([1.0, 2.0])).numpy(), 1.5) - - a.add_state("c", torch.tensor(0), "cat") - assert a._reductions["c"]([torch.tensor([1]), torch.tensor([1])]).shape == (2, ) - - with pytest.raises(ValueError): - a.add_state("d1", torch.tensor(0), 'xyz') - - with pytest.raises(ValueError): - a.add_state("d2", torch.tensor(0), 42) - - with pytest.raises(ValueError): - a.add_state("d3", [torch.tensor(0)], 'sum') - - with pytest.raises(ValueError): - a.add_state("d4", 42, 'sum') - - def custom_fx(x): - return -1 - - a.add_state("e", torch.tensor(0), custom_fx) - assert a._reductions["e"](torch.tensor([1, 1])) == -1 - - -def test_add_state_persistent(): - a = Dummy() - - a.add_state("a", torch.tensor(0), "sum", persistent=True) - assert "a" in a.state_dict() - - a.add_state("b", torch.tensor(0), "sum", persistent=False) - - if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): - assert "b" not in a.state_dict() - - -def test_reset(): - - class A(Dummy): - pass - - class B(DummyList): - pass - - a = A() - assert a.x == 0 - a.x = torch.tensor(5) - a.reset() - assert a.x == 0 - - b = B() - assert isinstance(b.x, list) and len(b.x) == 0 - b.x = torch.tensor(5) - b.reset() - assert isinstance(b.x, list) and len(b.x) == 0 - - -def test_update(): - - class A(Dummy): - - def update(self, x): - self.x += x - - a = A() - assert a.x == 0 - assert a._computed is None - a.update(1) - assert a._computed is None - assert a.x == 1 - a.update(2) - assert a.x == 3 - assert a._computed is None - - -def test_compute(): - - class A(Dummy): - - def update(self, x): - self.x += x - - def compute(self): - return self.x - - a = A() - assert 0 == a.compute() - assert 0 == a.x - a.update(1) - assert a._computed is None - assert a.compute() == 1 - assert a._computed == 1 - a.update(2) - assert a._computed is None - assert a.compute() == 3 - assert a._computed == 3 - - # called without update, should return cached value - a._computed = 5 - assert a.compute() == 5 - - -def test_hash(): - - class A(Dummy): - pass - - class B(DummyList): - pass - - a1 = A() - a2 = A() - assert hash(a1) != hash(a2) - - b1 = B() - b2 = B() - assert hash(b1) == hash(b2) - assert isinstance(b1.x, list) and len(b1.x) == 0 - b1.x.append(torch.tensor(5)) - assert isinstance(hash(b1), int) # <- check that nothing crashes - assert isinstance(b1.x, list) and len(b1.x) == 1 - b2.x.append(torch.tensor(5)) - # Sanity: - assert isinstance(b2.x, list) and len(b2.x) == 1 - # Now that they have tensor contents, they should have different hashes: - assert hash(b1) != hash(b2) - - -def test_forward(): - - class A(Dummy): - - def update(self, x): - self.x += x - - def compute(self): - return self.x - - a = A() - assert a(5) == 5 - assert a._forward_cache == 5 - - assert a(8) == 8 - assert a._forward_cache == 8 - - assert a.compute() == 13 - - -class DummyMetric1(Dummy): - - def update(self, x): - self.x += x - - def compute(self): - return self.x - - -class DummyMetric2(Dummy): - - def update(self, y): - self.x -= y - - def compute(self): - return self.x - - -def test_pickle(tmpdir): - # doesn't tests for DDP - a = DummyMetric1() - a.update(1) - - metric_pickled = pickle.dumps(a) - metric_loaded = pickle.loads(metric_pickled) - - assert metric_loaded.compute() == 1 - - metric_loaded.update(5) - assert metric_loaded.compute() == 6 - - metric_pickled = cloudpickle.dumps(a) - metric_loaded = cloudpickle.loads(metric_pickled) - - assert metric_loaded.compute() == 1 - - -def test_state_dict(tmpdir): - """ test that metric states can be removed and added to state dict """ - metric = Dummy() - assert metric.state_dict() == OrderedDict() - metric.persistent(True) - assert metric.state_dict() == OrderedDict(x=0) - metric.persistent(False) - assert metric.state_dict() == OrderedDict() - - -def test_child_metric_state_dict(): - """ test that child metric states will be added to parent state dict """ - - class TestModule(nn.Module): - - def __init__(self): - super().__init__() - self.metric = Dummy() - self.metric.add_state('a', torch.tensor(0), persistent=True) - self.metric.add_state('b', [], persistent=True) - self.metric.register_buffer('c', torch.tensor(0)) - - module = TestModule() - expected_state_dict = { - 'metric.a': torch.tensor(0), - 'metric.b': [], - 'metric.c': torch.tensor(0), - } - assert module.state_dict() == expected_state_dict - - -@RunIf(min_gpus=1) -def test_device_and_dtype_transfer(tmpdir): - metric = DummyMetric1() - assert metric.x.is_cuda is False - assert metric.x.dtype == torch.float32 - - metric = metric.to(device='cuda') - assert metric.x.is_cuda - - metric = metric.double() - assert metric.x.dtype == torch.float64 - - metric = metric.half() - assert metric.x.dtype == torch.float16 - - -def test_metric_collection(tmpdir): - m1 = DummyMetric1() - m2 = DummyMetric2() - - metric_collection = MetricCollection([m1, m2]) - - # Test correct dict structure - assert len(metric_collection) == 2 - assert metric_collection['DummyMetric1'] == m1 - assert metric_collection['DummyMetric2'] == m2 - - # Test correct initialization - for name, metric in metric_collection.items(): - assert metric.x == 0, f'Metric {name} not initialized correctly' - - # Test every metric gets updated - metric_collection.update(5) - for name, metric in metric_collection.items(): - assert metric.x.abs() == 5, f'Metric {name} not updated correctly' - - # Test compute on each metric - metric_collection.update(-5) - metric_vals = metric_collection.compute() - assert len(metric_vals) == 2 - for name, metric_val in metric_vals.items(): - assert metric_val == 0, f'Metric {name}.compute not called correctly' - - # Test that everything is reset - for name, metric in metric_collection.items(): - assert metric.x == 0, f'Metric {name} not reset correctly' - - # Test pickable - metric_pickled = pickle.dumps(metric_collection) - metric_loaded = pickle.loads(metric_pickled) - assert isinstance(metric_loaded, MetricCollection) - - -@RunIf(min_gpus=1) -def test_device_and_dtype_transfer_metriccollection(tmpdir): - m1 = DummyMetric1() - m2 = DummyMetric2() - - metric_collection = MetricCollection([m1, m2]) - for _, metric in metric_collection.items(): - assert metric.x.is_cuda is False - assert metric.x.dtype == torch.float32 - - metric_collection = metric_collection.to(device='cuda') - for _, metric in metric_collection.items(): - assert metric.x.is_cuda - - metric_collection = metric_collection.double() - for _, metric in metric_collection.items(): - assert metric.x.dtype == torch.float64 - - metric_collection = metric_collection.half() - for _, metric in metric_collection.items(): - assert metric.x.dtype == torch.float16 - - -def test_metric_collection_wrong_input(tmpdir): - """ Check that errors are raised on wrong input """ - m1 = DummyMetric1() - - # Not all input are metrics (list) - with pytest.raises(ValueError): - _ = MetricCollection([m1, 5]) - - # Not all input are metrics (dict) - with pytest.raises(ValueError): - _ = MetricCollection({'metric1': m1, 'metric2': 5}) - - # Same metric passed in multiple times - with pytest.raises(ValueError, match='Encountered two metrics both named *.'): - _ = MetricCollection([m1, m1]) - - # Not a list or dict passed in - with pytest.raises(ValueError, match='Unknown input to MetricCollection.'): - _ = MetricCollection(m1) - - -def test_metric_collection_args_kwargs(tmpdir): - """ Check that args and kwargs gets passed correctly in metric collection, - Checks both update and forward method - """ - m1 = DummyMetric1() - m2 = DummyMetric2() - - metric_collection = MetricCollection([m1, m2]) - - # args gets passed to all metrics - metric_collection.update(5) - assert metric_collection['DummyMetric1'].x == 5 - assert metric_collection['DummyMetric2'].x == -5 - metric_collection.reset() - _ = metric_collection(5) - assert metric_collection['DummyMetric1'].x == 5 - assert metric_collection['DummyMetric2'].x == -5 - metric_collection.reset() - - # kwargs gets only passed to metrics that it matches - metric_collection.update(x=10, y=20) - assert metric_collection['DummyMetric1'].x == 10 - assert metric_collection['DummyMetric2'].x == -20 - metric_collection.reset() - _ = metric_collection(x=10, y=20) - assert metric_collection['DummyMetric1'].x == 10 - assert metric_collection['DummyMetric2'].x == -20