diff --git a/CHANGELOG.md b/CHANGELOG.md index f5dcb20375137..c17f913cc960d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), + + [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547), ) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py deleted file mode 100644 index a91150799d5a1..0000000000000 --- a/pytorch_lightning/metrics/classification/helpers.py +++ /dev/null @@ -1,535 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Tuple - -import numpy as np -import torch -from torchmetrics.utilities.data import select_topk, to_onehot - -from pytorch_lightning.utilities import LightningEnum - - -class DataType(LightningEnum): - """ - Enum to represent data type - """ - - BINARY = "binary" - MULTILABEL = "multi-label" - MULTICLASS = "multi-class" - MULTIDIM_MULTICLASS = "multi-dim multi-class" - - -class AverageMethod(LightningEnum): - """ - Enum to represent average method - """ - - MICRO = "micro" - MACRO = "macro" - WEIGHTED = "weighted" - NONE = "none" - SAMPLES = "samples" - - -class MDMCAverageMethod(LightningEnum): - """ - Enum to represent multi-dim multi-class average method - """ - - GLOBAL = "global" - SAMPLEWISE = "samplewise" - - -def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): - """ - Perform basic validation of inputs that does not require deducing any information - of the type of inputs. - """ - - if target.is_floating_point(): - raise ValueError("The `target` has to be an integer tensor.") - if target.min() < 0: - raise ValueError("The `target` has to be a non-negative tensor.") - - preds_float = preds.is_floating_point() - if not preds_float and preds.min() < 0: - raise ValueError("If `preds` are integers, they have to be non-negative.") - - if not preds.shape[0] == target.shape[0]: - raise ValueError("The `preds` and `target` should have the same first dimension.") - - if preds_float and (preds.min() < 0 or preds.max() > 1): - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") - - if not 0 < threshold < 1: - raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") - - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") - - -def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: - """ - This checks that the shape and type of inputs are consistent with - each other and fall into one of the allowed input types (see the - documentation of docstring of ``_input_format_classification``). It does - not check for consistency of number of classes, other functions take - care of that. - - It returns the name of the case in which the inputs fall, and the implied - number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for - multi-label data). - """ - - preds_float = preds.is_floating_point() - - if preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "The `preds` and `target` should have the same shape,", - f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", - ) - if preds_float and target.max() > 1: - raise ValueError( - "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." - ) - - # Get the case - if preds.ndim == 1 and preds_float: - case = DataType.BINARY - elif preds.ndim == 1 and not preds_float: - case = DataType.MULTICLASS - elif preds.ndim > 1 and preds_float: - case = DataType.MULTILABEL - else: - case = DataType.MULTIDIM_MULTICLASS - - implied_classes = preds[0].numel() - - elif preds.ndim == target.ndim + 1: - if not preds_float: - raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "If `preds` have one dimension more than `target`, the shape of `preds` should be" - " (N, C, ...), and the shape of `target` should be (N, ...)." - ) - - implied_classes = preds.shape[1] - - if preds.ndim == 2: - case = DataType.MULTICLASS - else: - case = DataType.MULTIDIM_MULTICLASS - else: - raise ValueError( - "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" - " and `preds` should be (N, C, ...)." - ) - - return case, implied_classes - - -def _check_num_classes_binary(num_classes: int, is_multiclass: bool): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for binary data. - """ - - if num_classes > 2: - raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - if num_classes == 2 and not is_multiclass: - raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." - " Set it to True if you want to transform binary data to multi-class format." - ) - if num_classes == 1 and is_multiclass: - raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either set `is_multiclass=None`(default) or set `num_classes=2`" - " to transform binary data to multi-class format." - ) - - -def _check_num_classes_mc( - preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int -): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for (multi-dimensional) multi-class data. - """ - - if num_classes == 1 and is_multiclass is not False: - raise ValueError( - "You have set `num_classes=1`, but predictions are integers." - " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." - ) - if num_classes > 1: - if is_multiclass is False: - if implied_classes != num_classes: - raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " - " (from shape of inputs) does not match `num_classes`. If you are trying to" - " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" - " should be either None or the product of the size of extra dimensions (...)." - " See Input Types in Metrics documentation." - ) - if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`.") - if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") - if preds.shape != target.shape and num_classes != implied_classes: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") - - -def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for multi-label data. - """ - - if is_multiclass and num_classes != 2: - raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." - " If you are trying to transform multi-label data to 2 class multi-dimensional" - " multi-class, you should set `num_classes` to either 2 or None." - ) - if not is_multiclass and num_classes != implied_classes: - raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") - - -def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): - if case == DataType.BINARY: - raise ValueError("You can not use `top_k` parameter with binary data.") - if not isinstance(top_k, int) or top_k <= 0: - raise ValueError("The `top_k` has to be an integer larger than 0.") - if not preds_float: - raise ValueError("You have set `top_k`, but you do not have probability predictions.") - if is_multiclass is False: - raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") - if case == DataType.MULTILABEL and is_multiclass: - raise ValueError( - "If you want to transform multi-label data to 2 class multi-dimensional" - "multi-class data using `is_multiclass=True`, you can not use `top_k`." - ) - if top_k >= implied_classes: - raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") - - -def _check_classification_inputs( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float, - num_classes: Optional[int], - is_multiclass: bool, - top_k: Optional[int], -) -> str: - """Performs error checking on inputs for classification. - - This ensures that preds and target take one of the shape/type combinations that are - specified in ``_input_format_classification`` docstring. It also checks the cases of - over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class - cases) that there are only up to 2 distinct labels. - - In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. - - When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, - multi-label, ...), and that, if availible, the implied number of classes in the ``C`` - dimension is consistent with it (as well as that max label in target is smaller than it). - - When ``num_classes`` is not specified in these cases, consistency of the highest target - value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - - If ``top_k`` is set (not None) for inputs that do not have probability predictions (and - are not binary), an error is raised. Similarly if ``top_k`` is set to a number that - is higher than or equal to the ``C`` dimension of ``preds``, an error is raised. - - Preds and target tensors are expected to be squeezed already - all dimensions should be - greater than 1, except perhaps the first one (``N``). - - Args: - preds: Tensor with predictions (labels or probabilities) - target: Tensor with ground truth labels, always integers (labels) - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - num_classes: - Number of classes. If not explicitly set, the number of classes will be infered - either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` - tensor, where applicable. - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. The default value (``None``) will be - interepreted as 1 for these inputs. If this parameter is set for multi-label inputs, - it will take precedence over threshold. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - - Return: - case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' - """ - - # Baisc validation (that does not need case/type information) - _basic_input_validation(preds, target, threshold, is_multiclass) - - # Check that shape/types fall into one of the cases - case, implied_classes = _check_shape_and_type_consistency(preds, target) - - # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point(): - if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): - raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") - - # Check consistency with the `C` dimension in case of multi-class data - if preds.shape != target.shape: - if is_multiclass is False and implied_classes != 2: - raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," - " based on the C dimension of `preds`." - ) - if target.max() >= implied_classes: - raise ValueError( - "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`." - ) - - # Check that num_classes is consistent - if num_classes: - if case == DataType.BINARY: - _check_num_classes_binary(num_classes, is_multiclass) - elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS): - _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) - elif case.MULTILABEL: - _check_num_classes_ml(num_classes, is_multiclass, implied_classes) - - # Check that top_k is consistent - if top_k is not None: - _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) - - return case - - -def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - top_k: Optional[int] = None, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor, str]: - """Convert preds and target tensors into common format. - - Preds and targets are supposed to fall into one of these categories (and are - validated to make sure this is the case): - - * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) - * Both preds and target are of shape ``(N,)``, and target is binary, while preds - are a float (binary) - * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and - is integer (multi-class) - * preds and target are of shape ``(N, ...)``, target is binary and preds is a float - (multi-label) - * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)`` - and is integer (multi-dimensional multi-class) - * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional - multi-class) - - To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. - - The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` - of ``(N, C, X)``, the details for each case are described below. The function also returns - a ``case`` string, which describes which of the above cases the inputs belonged to - regardless - of whether this was "overridden" by other settings (like ``is_multiclass``). - - In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed - into a binary tensor (elements become 1 if the probability is greater than or equal to - ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds - become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to - preds first. - - In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets - by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original - shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be - returned as ``(N,1)`` tensor. - - In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with - preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening - all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as - ``(N, 2, C)``, by an equivalent transformation as in the binary case. - - In multi-dimensional multi-class case, normally both target and preds are returned as - ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and - ``C``. The transformations performed here are equivalent to the multi-class case. However, if - ``is_multiclass=False`` (and there are up to two classes), then the data is returned as - ``(N, X)`` binary tensors (multi-label). - - Note that where a one-hot transformation needs to be performed and the number of classes - is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be - equal to ``num_classes``, if it is given, or the maximum label value in preds and - target. - - Args: - preds: Tensor with predictions (labels or probabilities) - target: Tensor with ground truth labels, always integers (labels) - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - num_classes: - Number of classes. If not explicitly set, the number of classes will be infered - either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` - tensor, where applicable. - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interepreted as 1 for these inputs. - - Should be left unset (``None``) for all other types of inputs. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - - Returns: - preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - target: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or - ``'multi-dim multi-class'`` - """ - # Remove excess dimensions - if preds.shape[0] == 1: - preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) - else: - preds, target = preds.squeeze(), target.squeeze() - - # Convert half precision tensors to full precision, as not all ops are supported - # for example, min() is not supported - if preds.dtype == torch.float16: - preds = preds.float() - - case = _check_classification_inputs( - preds, - target, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) - - if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: - preds = (preds >= threshold).int() - num_classes = num_classes if not is_multiclass else 2 - - if case == DataType.MULTILABEL and top_k: - preds = select_topk(preds, top_k) - - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass: - if preds.is_floating_point(): - num_classes = preds.shape[1] - preds = select_topk(preds, top_k or 1) - else: - num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, max(2, num_classes)) - - target = to_onehot(target, max(2, num_classes)) - - if is_multiclass is False: - preds, target = preds[:, 1, ...], target[:, 1, ...] - - if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass: - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = preds.reshape(preds.shape[0], preds.shape[1], -1) - else: - target = target.reshape(target.shape[0], -1) - preds = preds.reshape(preds.shape[0], -1) - - # Some operatins above create an extra dimension for MC/binary case - this removes it - if preds.ndim > 2: - preds, target = preds.squeeze(-1), target.squeeze(-1) - - return preds.int(), target.int(), case - - -def _reduce_stat_scores( - numerator: torch.Tensor, - denominator: torch.Tensor, - weights: Optional[torch.Tensor], - average: str, - mdmc_average: Optional[str], - zero_division: int = 0, -) -> torch.Tensor: - """ - Reduces scores of type ``numerator/denominator`` or - ``weights * (numerator/denominator)``, if ``average='weighted'``. - - Args: - numerator: A tensor with numerator numbers. - denominator: A tensor with denominator numbers. If a denominator is - negative, the class will be ignored (if averaging), or its score - will be returned as ``nan`` (if ``average=None``). - If the denominator is zero, then ``zero_division`` score will be - used for those elements. - weights: - A tensor of weights to be used if ``average='weighted'``. - average: - The method to average the scores. Should be one of ``'micro'``, ``'macro'``, - ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior - corresponds to `sklearn averaging methods `__. - mdmc_average: - The method to average the scores if inputs were multi-dimensional multi-class (MDMC). - Should be either ``'global'`` or ``'samplewise'``. If inputs were not - multi-dimensional multi-class, it should be ``None`` (default). - zero_division: - The value to use for the score if denominator equals zero. - """ - numerator, denominator = numerator.float(), denominator.float() - zero_div_mask = denominator == 0 - ignore_mask = denominator < 0 - - if weights is None: - weights = torch.ones_like(denominator) - else: - weights = weights.float() - - numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) - denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) - weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) - - if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): - weights = weights / weights.sum(dim=-1, keepdim=True) - - scores = weights * (numerator / denominator) - - # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' - scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) - - if mdmc_average == MDMCAverageMethod.SAMPLEWISE: - scores = scores.mean(dim=0) - ignore_mask = ignore_mask.sum(dim=0).bool() - - if average in (AverageMethod.NONE, None): - scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) - else: - scores = scores.sum() - - return scores diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 35a5fa85a9763..53a47611cd49a 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -14,8 +14,8 @@ from typing import Optional, Tuple import torch - -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType def _accuracy_update( diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index 2017e6c5277cb..e772b5050260c 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -15,8 +15,9 @@ from typing import Optional, Sequence, Tuple import torch +from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType from pytorch_lightning.metrics.functional.auc import auc from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.utilities import LightningEnum diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 58947f2cb19ed..e77fc4224d25e 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -14,8 +14,9 @@ from typing import Optional import torch +from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index d288a87fc3aaf..3254dcbf8badb 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -14,8 +14,7 @@ from typing import Tuple, Union import torch - -from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from torchmetrics.classification.checks import _input_format_classification def _hamming_distance_update( diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 09632e216560b..b6d26237cf287 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -14,8 +14,8 @@ from typing import Optional import torch +from torchmetrics.classification.stat_scores import _reduce_stat_scores -from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 108cdf7a5b88a..fb1849d3805b2 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -14,8 +14,7 @@ from typing import Optional, Tuple import torch - -from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from torchmetrics.classification.checks import _input_format_classification def _del_column(tensor: torch.Tensor, index: int): diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index f4209f40d002e..1f1c41c6eb2f0 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -23,6 +23,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable: Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. """ + @wraps(fn) def insert_env_defaults(self, *args, **kwargs): cls = self.__class__ # get the class diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index bed60aa88388f..63a4870ed422e 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -4,9 +4,10 @@ import pytest import torch from sklearn.metrics import accuracy_score as sk_accuracy +from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType from pytorch_lightning.metrics import Accuracy -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType from pytorch_lightning.metrics.functional import accuracy from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_hamming_distance.py b/tests/metrics/classification/test_hamming_distance.py index c57072c033c8c..a4db9c7f339b2 100644 --- a/tests/metrics/classification/test_hamming_distance.py +++ b/tests/metrics/classification/test_hamming_distance.py @@ -1,9 +1,9 @@ import pytest import torch from sklearn.metrics import hamming_loss as sk_hamming_loss +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import HammingDistance -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import hamming_distance from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 2b7be8caa7a0d..f07a9c2821f56 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -1,9 +1,10 @@ import pytest import torch from torch import rand, randint +from torchmetrics.classification.checks import _input_format_classification from torchmetrics.utilities.data import select_topk, to_onehot +from torchmetrics.utilities.enums import DataType -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType from tests.metrics.classification.inputs import _input_binary as _bin from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob from tests.metrics.classification.inputs import _input_multiclass as _mc diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index a9bf39044174a..f13c1ebe26d3e 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,9 +5,9 @@ import pytest import torch from sklearn.metrics import precision_score, recall_score +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import Metric, Precision, Recall -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import precision, precision_recall, recall from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index 659765931c433..6ccb5abed6711 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -5,9 +5,9 @@ import pytest import torch from sklearn.metrics import multilabel_confusion_matrix +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import StatScores -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import stat_scores from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 5150ea1a304f4..44510eb16184d 100644 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -277,6 +277,7 @@ def test_lr_find_with_bs_scale(tmpdir): """ Test that lr_find runs with batch_size_scaling """ class BoringModelTune(BoringModel): + def __init__(self, learning_rate=0.1, batch_size=2): super().__init__() self.save_hyperparameters()