From aea15e8bf48c88ef104948c426195acbed142252 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:49:20 +0100 Subject: [PATCH 01/13] _basic_input_validation --- .../metrics/classification/helpers.py | 32 +------------------ 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index a91150799d5a1..71890f764bf34 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -15,6 +15,7 @@ import numpy as np import torch +from torchmetrics.classification.checks import _basic_input_validation from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -52,37 +53,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _basic_input_validation(preds: torch.Tensor, target: torch.Tensor, threshold: float, is_multiclass: bool): - """ - Perform basic validation of inputs that does not require deducing any information - of the type of inputs. - """ - - if target.is_floating_point(): - raise ValueError("The `target` has to be an integer tensor.") - if target.min() < 0: - raise ValueError("The `target` has to be a non-negative tensor.") - - preds_float = preds.is_floating_point() - if not preds_float and preds.min() < 0: - raise ValueError("If `preds` are integers, they have to be non-negative.") - - if not preds.shape[0] == target.shape[0]: - raise ValueError("The `preds` and `target` should have the same first dimension.") - - if preds_float and (preds.min() < 0 or preds.max() > 1): - raise ValueError("The `preds` should be probabilities, but values were detected outside of [0,1] range.") - - if not 0 < threshold < 1: - raise ValueError(f"The `threshold` should be a float in the (0,1) interval, got {threshold}") - - if is_multiclass is False and target.max() > 1: - raise ValueError("If you set `is_multiclass=False`, then `target` should not exceed 1.") - - if is_multiclass is False and not preds_float and preds.max() > 1: - raise ValueError("If you set `is_multiclass=False` and `preds` are integers, then `preds` should not exceed 1.") - - def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: """ This checks that the shape and type of inputs are consistent with From 2f9138c2b718f6b7fe40e79f8b5bbe4943c27a18 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:50:01 +0100 Subject: [PATCH 02/13] _check_shape_and_type_consistency --- .../metrics/classification/helpers.py | 64 +------------------ 1 file changed, 1 insertion(+), 63 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 71890f764bf34..888b611bf01f1 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -15,7 +15,7 @@ import numpy as np import torch -from torchmetrics.classification.checks import _basic_input_validation +from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -53,68 +53,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _check_shape_and_type_consistency(preds: torch.Tensor, target: torch.Tensor) -> Tuple[str, int]: - """ - This checks that the shape and type of inputs are consistent with - each other and fall into one of the allowed input types (see the - documentation of docstring of ``_input_format_classification``). It does - not check for consistency of number of classes, other functions take - care of that. - - It returns the name of the case in which the inputs fall, and the implied - number of classes (from the ``C`` dim for multi-class data, or extra dim(s) for - multi-label data). - """ - - preds_float = preds.is_floating_point() - - if preds.ndim == target.ndim: - if preds.shape != target.shape: - raise ValueError( - "The `preds` and `target` should have the same shape,", - f" got `preds` with shape={preds.shape} and `target` with shape={target.shape}.", - ) - if preds_float and target.max() > 1: - raise ValueError( - "If `preds` and `target` are of shape (N, ...) and `preds` are floats, `target` should be binary." - ) - - # Get the case - if preds.ndim == 1 and preds_float: - case = DataType.BINARY - elif preds.ndim == 1 and not preds_float: - case = DataType.MULTICLASS - elif preds.ndim > 1 and preds_float: - case = DataType.MULTILABEL - else: - case = DataType.MULTIDIM_MULTICLASS - - implied_classes = preds[0].numel() - - elif preds.ndim == target.ndim + 1: - if not preds_float: - raise ValueError("If `preds` have one dimension more than `target`, `preds` should be a float tensor.") - if preds.shape[2:] != target.shape[1:]: - raise ValueError( - "If `preds` have one dimension more than `target`, the shape of `preds` should be" - " (N, C, ...), and the shape of `target` should be (N, ...)." - ) - - implied_classes = preds.shape[1] - - if preds.ndim == 2: - case = DataType.MULTICLASS - else: - case = DataType.MULTIDIM_MULTICLASS - else: - raise ValueError( - "Either `preds` and `target` both should have the (same) shape (N, ...), or `target` should be (N, ...)" - " and `preds` should be (N, C, ...)." - ) - - return case, implied_classes - - def _check_num_classes_binary(num_classes: int, is_multiclass: bool): """ This checks that the consistency of `num_classes` with the data From 301491d8c5594e6a9425633e97cfa0482cb0ad73 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:50:48 +0100 Subject: [PATCH 03/13] _check_num_classes_binary --- .../metrics/classification/helpers.py | 24 ++----------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 888b611bf01f1..bd1496c965ff4 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -15,7 +15,8 @@ import numpy as np import torch -from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency +from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency, \ + _check_num_classes_binary from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -53,27 +54,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _check_num_classes_binary(num_classes: int, is_multiclass: bool): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for binary data. - """ - - if num_classes > 2: - raise ValueError("Your data is binary, but `num_classes` is larger than 2.") - if num_classes == 2 and not is_multiclass: - raise ValueError( - "Your data is binary and `num_classes=2`, but `is_multiclass` is not True." - " Set it to True if you want to transform binary data to multi-class format." - ) - if num_classes == 1 and is_multiclass: - raise ValueError( - "You have binary data and have set `is_multiclass=True`, but `num_classes` is 1." - " Either set `is_multiclass=None`(default) or set `num_classes=2`" - " to transform binary data to multi-class format." - ) - - def _check_num_classes_mc( preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int ): From ece1aa2dea2000c84849e8cedcaa0614dc1d5e5e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:51:18 +0100 Subject: [PATCH 04/13] _check_num_classes_mc --- .../metrics/classification/helpers.py | 34 +------------------ 1 file changed, 1 insertion(+), 33 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index bd1496c965ff4..81bc513348870 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -16,7 +16,7 @@ import numpy as np import torch from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency, \ - _check_num_classes_binary + _check_num_classes_binary, _check_num_classes_mc from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -54,38 +54,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _check_num_classes_mc( - preds: torch.Tensor, target: torch.Tensor, num_classes: int, is_multiclass: bool, implied_classes: int -): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for (multi-dimensional) multi-class data. - """ - - if num_classes == 1 and is_multiclass is not False: - raise ValueError( - "You have set `num_classes=1`, but predictions are integers." - " If you want to convert (multi-dimensional) multi-class data with 2 classes" - " to binary/multi-label, set `is_multiclass=False`." - ) - if num_classes > 1: - if is_multiclass is False: - if implied_classes != num_classes: - raise ValueError( - "You have set `is_multiclass=False`, but the implied number of classes " - " (from shape of inputs) does not match `num_classes`. If you are trying to" - " transform multi-dim multi-class data with 2 classes to multi-label, `num_classes`" - " should be either None or the product of the size of extra dimensions (...)." - " See Input Types in Metrics documentation." - ) - if num_classes <= target.max(): - raise ValueError("The highest label in `target` should be smaller than `num_classes`.") - if num_classes <= preds.max(): - raise ValueError("The highest label in `preds` should be smaller than `num_classes`.") - if preds.shape != target.shape and num_classes != implied_classes: - raise ValueError("The size of C dimension of `preds` does not match `num_classes`.") - - def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): """ This checks that the consistency of `num_classes` with the data From b4b2c2bbc1d67e1db1eeeb380e6bd8d590d59ba5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:51:39 +0100 Subject: [PATCH 05/13] _check_num_classes_ml --- .../metrics/classification/helpers.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 81bc513348870..0de507dc46b8d 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -16,7 +16,7 @@ import numpy as np import torch from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency, \ - _check_num_classes_binary, _check_num_classes_mc + _check_num_classes_binary, _check_num_classes_mc, _check_num_classes_ml from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -54,22 +54,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _check_num_classes_ml(num_classes: int, is_multiclass: bool, implied_classes: int): - """ - This checks that the consistency of `num_classes` with the data - and `is_multiclass` param for multi-label data. - """ - - if is_multiclass and num_classes != 2: - raise ValueError( - "Your have set `is_multiclass=True`, but `num_classes` is not equal to 2." - " If you are trying to transform multi-label data to 2 class multi-dimensional" - " multi-class, you should set `num_classes` to either 2 or None." - ) - if not is_multiclass and num_classes != implied_classes: - raise ValueError("The implied number of classes (from shape of inputs) does not match num_classes.") - - def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): if case == DataType.BINARY: raise ValueError("You can not use `top_k` parameter with binary data.") From e0fe1e146318e7d6f83fecc1818b65a0ec881378 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:52:09 +0100 Subject: [PATCH 06/13] _check_top_k --- .../metrics/classification/helpers.py | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 0de507dc46b8d..6bc51cd43db9f 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -16,7 +16,7 @@ import numpy as np import torch from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency, \ - _check_num_classes_binary, _check_num_classes_mc, _check_num_classes_ml + _check_num_classes_binary, _check_num_classes_mc, _check_num_classes_ml, _check_top_k from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -54,24 +54,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _check_top_k(top_k: int, case: str, implied_classes: int, is_multiclass: Optional[bool], preds_float: bool): - if case == DataType.BINARY: - raise ValueError("You can not use `top_k` parameter with binary data.") - if not isinstance(top_k, int) or top_k <= 0: - raise ValueError("The `top_k` has to be an integer larger than 0.") - if not preds_float: - raise ValueError("You have set `top_k`, but you do not have probability predictions.") - if is_multiclass is False: - raise ValueError("If you set `is_multiclass=False`, you can not set `top_k`.") - if case == DataType.MULTILABEL and is_multiclass: - raise ValueError( - "If you want to transform multi-label data to 2 class multi-dimensional" - "multi-class data using `is_multiclass=True`, you can not use `top_k`." - ) - if top_k >= implied_classes: - raise ValueError("The `top_k` has to be strictly smaller than the `C` dimension of `preds`.") - - def _check_classification_inputs( preds: torch.Tensor, target: torch.Tensor, From 76b340a3001d667a8eae140523893c5d7c27279a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:53:01 +0100 Subject: [PATCH 07/13] _check_classification_inputs --- .../metrics/classification/helpers.py | 100 +----------------- 1 file changed, 1 insertion(+), 99 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 6bc51cd43db9f..5d28940edc0a8 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -15,8 +15,7 @@ import numpy as np import torch -from torchmetrics.classification.checks import _basic_input_validation, _check_shape_and_type_consistency, \ - _check_num_classes_binary, _check_num_classes_mc, _check_num_classes_ml, _check_top_k +from torchmetrics.classification.checks import _check_classification_inputs from torchmetrics.utilities.data import select_topk, to_onehot from pytorch_lightning.utilities import LightningEnum @@ -54,103 +53,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _check_classification_inputs( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float, - num_classes: Optional[int], - is_multiclass: bool, - top_k: Optional[int], -) -> str: - """Performs error checking on inputs for classification. - - This ensures that preds and target take one of the shape/type combinations that are - specified in ``_input_format_classification`` docstring. It also checks the cases of - over-rides with ``is_multiclass`` by checking (for multi-class and multi-dim multi-class - cases) that there are only up to 2 distinct labels. - - In case where preds are floats (probabilities), it is checked whether they are in [0,1] interval. - - When ``num_classes`` is given, it is checked that it is consitent with input cases (binary, - multi-label, ...), and that, if availible, the implied number of classes in the ``C`` - dimension is consistent with it (as well as that max label in target is smaller than it). - - When ``num_classes`` is not specified in these cases, consistency of the highest target - value against ``C`` dimension is checked for (multi-dimensional) multi-class cases. - - If ``top_k`` is set (not None) for inputs that do not have probability predictions (and - are not binary), an error is raised. Similarly if ``top_k`` is set to a number that - is higher than or equal to the ``C`` dimension of ``preds``, an error is raised. - - Preds and target tensors are expected to be squeezed already - all dimensions should be - greater than 1, except perhaps the first one (``N``). - - Args: - preds: Tensor with predictions (labels or probabilities) - target: Tensor with ground truth labels, always integers (labels) - threshold: - Threshold probability value for transforming probability predictions to binary - (0,1) predictions, in the case of binary or multi-label inputs. - num_classes: - Number of classes. If not explicitly set, the number of classes will be infered - either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` - tensor, where applicable. - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for inputs with probability predictions. The default value (``None``) will be - interepreted as 1 for these inputs. If this parameter is set for multi-label inputs, - it will take precedence over threshold. - - Should be left unset (``None``) for inputs with label predictions. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - - Return: - case: The case the inputs fall in, one of 'binary', 'multi-class', 'multi-label' or - 'multi-dim multi-class' - """ - - # Baisc validation (that does not need case/type information) - _basic_input_validation(preds, target, threshold, is_multiclass) - - # Check that shape/types fall into one of the cases - case, implied_classes = _check_shape_and_type_consistency(preds, target) - - # For (multi-dim) multi-class case with prob preds, check that preds sum up to 1 - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and preds.is_floating_point(): - if not torch.isclose(preds.sum(dim=1), torch.ones_like(preds.sum(dim=1))).all(): - raise ValueError("Probabilities in `preds` must sum up to 1 accross the `C` dimension.") - - # Check consistency with the `C` dimension in case of multi-class data - if preds.shape != target.shape: - if is_multiclass is False and implied_classes != 2: - raise ValueError( - "You have set `is_multiclass=False`, but have more than 2 classes in your data," - " based on the C dimension of `preds`." - ) - if target.max() >= implied_classes: - raise ValueError( - "The highest label in `target` should be smaller than the size of the `C` dimension of `preds`." - ) - - # Check that num_classes is consistent - if num_classes: - if case == DataType.BINARY: - _check_num_classes_binary(num_classes, is_multiclass) - elif case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS): - _check_num_classes_mc(preds, target, num_classes, is_multiclass, implied_classes) - elif case.MULTILABEL: - _check_num_classes_ml(num_classes, is_multiclass, implied_classes) - - # Check that top_k is consistent - if top_k is not None: - _check_top_k(top_k, case, implied_classes, is_multiclass, preds.is_floating_point()) - - return case - - def _input_format_classification( preds: torch.Tensor, target: torch.Tensor, From 00f3b5bfbc1af9e75a3012f07ae1633b55d257a8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:55:53 +0100 Subject: [PATCH 08/13] _input_format_classification --- .../metrics/classification/helpers.py | 140 ------------------ .../metrics/functional/accuracy.py | 3 +- pytorch_lightning/metrics/functional/auroc.py | 3 +- .../metrics/functional/confusion_matrix.py | 3 +- .../metrics/functional/hamming_distance.py | 3 +- .../metrics/functional/stat_scores.py | 3 +- tests/metrics/classification/test_accuracy.py | 3 +- .../classification/test_hamming_distance.py | 2 +- tests/metrics/classification/test_inputs.py | 3 +- .../classification/test_precision_recall.py | 2 +- .../classification/test_stat_scores.py | 2 +- 11 files changed, 15 insertions(+), 152 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 5d28940edc0a8..27ab8980c4c74 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -53,146 +53,6 @@ class MDMCAverageMethod(LightningEnum): SAMPLEWISE = "samplewise" -def _input_format_classification( - preds: torch.Tensor, - target: torch.Tensor, - threshold: float = 0.5, - top_k: Optional[int] = None, - num_classes: Optional[int] = None, - is_multiclass: Optional[bool] = None, -) -> Tuple[torch.Tensor, torch.Tensor, str]: - """Convert preds and target tensors into common format. - - Preds and targets are supposed to fall into one of these categories (and are - validated to make sure this is the case): - - * Both preds and target are of shape ``(N,)``, and both are integers (multi-class) - * Both preds and target are of shape ``(N,)``, and target is binary, while preds - are a float (binary) - * preds are of shape ``(N, C)`` and are floats, and target is of shape ``(N,)`` and - is integer (multi-class) - * preds and target are of shape ``(N, ...)``, target is binary and preds is a float - (multi-label) - * preds are of shape ``(N, C, ...)`` and are floats, target is of shape ``(N, ...)`` - and is integer (multi-dimensional multi-class) - * preds and target are of shape ``(N, ...)`` both are integers (multi-dimensional - multi-class) - - To avoid ambiguities, all dimensions of size 1, except the first one, are squeezed out. - - The returned output tensors will be binary tensors of the same shape, either ``(N, C)`` - of ``(N, C, X)``, the details for each case are described below. The function also returns - a ``case`` string, which describes which of the above cases the inputs belonged to - regardless - of whether this was "overridden" by other settings (like ``is_multiclass``). - - In binary case, targets are normally returned as ``(N,1)`` tensor, while preds are transformed - into a binary tensor (elements become 1 if the probability is greater than or equal to - ``threshold`` or 0 otherwise). If ``is_multiclass=True``, then then both targets are preds - become ``(N, 2)`` tensors by a one-hot transformation; with the thresholding being applied to - preds first. - - In multi-class case, normally both preds and targets become ``(N, C)`` binary tensors; targets - by a one-hot transformation and preds by selecting ``top_k`` largest entries (if their original - shape was ``(N,C)``). However, if ``is_multiclass=False``, then targets and preds will be - returned as ``(N,1)`` tensor. - - In multi-label case, normally targets and preds are returned as ``(N, C)`` binary tensors, with - preds being binarized as in the binary case. Here the ``C`` dimension is obtained by flattening - all dimensions after the first one. However if ``is_multiclass=True``, then both are returned as - ``(N, 2, C)``, by an equivalent transformation as in the binary case. - - In multi-dimensional multi-class case, normally both target and preds are returned as - ``(N, C, X)`` tensors, with ``X`` resulting from flattening of all dimensions except ``N`` and - ``C``. The transformations performed here are equivalent to the multi-class case. However, if - ``is_multiclass=False`` (and there are up to two classes), then the data is returned as - ``(N, X)`` binary tensors (multi-label). - - Note that where a one-hot transformation needs to be performed and the number of classes - is not implicitly given by a ``C`` dimension, the new ``C`` dimension will either be - equal to ``num_classes``, if it is given, or the maximum label value in preds and - target. - - Args: - preds: Tensor with predictions (labels or probabilities) - target: Tensor with ground truth labels, always integers (labels) - threshold: - Threshold probability value for transforming probability predictions to binary - (0 or 1) predictions, in the case of binary or multi-label inputs. - num_classes: - Number of classes. If not explicitly set, the number of classes will be infered - either from the shape of inputs, or the maximum label in the ``target`` and ``preds`` - tensor, where applicable. - top_k: - Number of highest probability entries for each sample to convert to 1s - relevant - only for (multi-dimensional) multi-class inputs with probability predictions. The - default value (``None``) will be interepreted as 1 for these inputs. - - Should be left unset (``None``) for all other types of inputs. - is_multiclass: - Used only in certain special cases, where you want to treat inputs as a different type - than what they appear to be. - - - Returns: - preds: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - target: binary tensor of shape ``(N, C)`` or ``(N, C, X)`` - case: The case the inputs fall in, one of ``'binary'``, ``'multi-class'``, ``'multi-label'`` or - ``'multi-dim multi-class'`` - """ - # Remove excess dimensions - if preds.shape[0] == 1: - preds, target = preds.squeeze().unsqueeze(0), target.squeeze().unsqueeze(0) - else: - preds, target = preds.squeeze(), target.squeeze() - - # Convert half precision tensors to full precision, as not all ops are supported - # for example, min() is not supported - if preds.dtype == torch.float16: - preds = preds.float() - - case = _check_classification_inputs( - preds, - target, - threshold=threshold, - num_classes=num_classes, - is_multiclass=is_multiclass, - top_k=top_k, - ) - - if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: - preds = (preds >= threshold).int() - num_classes = num_classes if not is_multiclass else 2 - - if case == DataType.MULTILABEL and top_k: - preds = select_topk(preds, top_k) - - if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or is_multiclass: - if preds.is_floating_point(): - num_classes = preds.shape[1] - preds = select_topk(preds, top_k or 1) - else: - num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 - preds = to_onehot(preds, max(2, num_classes)) - - target = to_onehot(target, max(2, num_classes)) - - if is_multiclass is False: - preds, target = preds[:, 1, ...], target[:, 1, ...] - - if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and is_multiclass is not False) or is_multiclass: - target = target.reshape(target.shape[0], target.shape[1], -1) - preds = preds.reshape(preds.shape[0], preds.shape[1], -1) - else: - target = target.reshape(target.shape[0], -1) - preds = preds.reshape(preds.shape[0], -1) - - # Some operatins above create an extra dimension for MC/binary case - this removes it - if preds.ndim > 2: - preds, target = preds.squeeze(-1), target.squeeze(-1) - - return preds.int(), target.int(), case - - def _reduce_stat_scores( numerator: torch.Tensor, denominator: torch.Tensor, diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index 35a5fa85a9763..dea8b5590174c 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -14,8 +14,9 @@ from typing import Optional, Tuple import torch +from torchmetrics.classification.checks import _input_format_classification -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.classification.helpers import DataType def _accuracy_update( diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index 2017e6c5277cb..c496d86b1bbca 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -15,8 +15,9 @@ from typing import Optional, Sequence, Tuple import torch +from torchmetrics.classification.checks import _input_format_classification -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.classification.helpers import DataType from pytorch_lightning.metrics.functional.auc import auc from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.utilities import LightningEnum diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 58947f2cb19ed..5753d3e00c218 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -14,8 +14,9 @@ from typing import Optional import torch +from torchmetrics.classification.checks import _input_format_classification -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.classification.helpers import DataType from pytorch_lightning.utilities import rank_zero_warn diff --git a/pytorch_lightning/metrics/functional/hamming_distance.py b/pytorch_lightning/metrics/functional/hamming_distance.py index d288a87fc3aaf..3254dcbf8badb 100644 --- a/pytorch_lightning/metrics/functional/hamming_distance.py +++ b/pytorch_lightning/metrics/functional/hamming_distance.py @@ -14,8 +14,7 @@ from typing import Tuple, Union import torch - -from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from torchmetrics.classification.checks import _input_format_classification def _hamming_distance_update( diff --git a/pytorch_lightning/metrics/functional/stat_scores.py b/pytorch_lightning/metrics/functional/stat_scores.py index 108cdf7a5b88a..fb1849d3805b2 100644 --- a/pytorch_lightning/metrics/functional/stat_scores.py +++ b/pytorch_lightning/metrics/functional/stat_scores.py @@ -14,8 +14,7 @@ from typing import Optional, Tuple import torch - -from pytorch_lightning.metrics.classification.helpers import _input_format_classification +from torchmetrics.classification.checks import _input_format_classification def _del_column(tensor: torch.Tensor, index: int): diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index bed60aa88388f..63491c619d6ba 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -4,9 +4,10 @@ import pytest import torch from sklearn.metrics import accuracy_score as sk_accuracy +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import Accuracy -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.classification.helpers import DataType from pytorch_lightning.metrics.functional import accuracy from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_hamming_distance.py b/tests/metrics/classification/test_hamming_distance.py index c57072c033c8c..a4db9c7f339b2 100644 --- a/tests/metrics/classification/test_hamming_distance.py +++ b/tests/metrics/classification/test_hamming_distance.py @@ -1,9 +1,9 @@ import pytest import torch from sklearn.metrics import hamming_loss as sk_hamming_loss +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import HammingDistance -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import hamming_distance from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index 2b7be8caa7a0d..e143379259590 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -1,9 +1,10 @@ import pytest import torch from torch import rand, randint +from torchmetrics.classification.checks import _input_format_classification from torchmetrics.utilities.data import select_topk, to_onehot -from pytorch_lightning.metrics.classification.helpers import _input_format_classification, DataType +from pytorch_lightning.metrics.classification.helpers import DataType from tests.metrics.classification.inputs import _input_binary as _bin from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob from tests.metrics.classification.inputs import _input_multiclass as _mc diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py index a9bf39044174a..f13c1ebe26d3e 100644 --- a/tests/metrics/classification/test_precision_recall.py +++ b/tests/metrics/classification/test_precision_recall.py @@ -5,9 +5,9 @@ import pytest import torch from sklearn.metrics import precision_score, recall_score +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import Metric, Precision, Recall -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import precision, precision_recall, recall from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_stat_scores.py b/tests/metrics/classification/test_stat_scores.py index 659765931c433..6ccb5abed6711 100644 --- a/tests/metrics/classification/test_stat_scores.py +++ b/tests/metrics/classification/test_stat_scores.py @@ -5,9 +5,9 @@ import pytest import torch from sklearn.metrics import multilabel_confusion_matrix +from torchmetrics.classification.checks import _input_format_classification from pytorch_lightning.metrics import StatScores -from pytorch_lightning.metrics.classification.helpers import _input_format_classification from pytorch_lightning.metrics.functional import stat_scores from tests.metrics.classification.inputs import _input_binary, _input_binary_prob, _input_multiclass from tests.metrics.classification.inputs import _input_multiclass_prob as _input_mccls_prob From facccd5faba590ec54247b839c4c4b144f994b6b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:56:47 +0100 Subject: [PATCH 09/13] _reduce_stat_scores --- .../metrics/classification/helpers.py | 68 +------------------ .../metrics/functional/precision_recall.py | 2 +- 2 files changed, 2 insertions(+), 68 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index 27ab8980c4c74..c196df8b0bc3b 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -50,70 +50,4 @@ class MDMCAverageMethod(LightningEnum): """ GLOBAL = "global" - SAMPLEWISE = "samplewise" - - -def _reduce_stat_scores( - numerator: torch.Tensor, - denominator: torch.Tensor, - weights: Optional[torch.Tensor], - average: str, - mdmc_average: Optional[str], - zero_division: int = 0, -) -> torch.Tensor: - """ - Reduces scores of type ``numerator/denominator`` or - ``weights * (numerator/denominator)``, if ``average='weighted'``. - - Args: - numerator: A tensor with numerator numbers. - denominator: A tensor with denominator numbers. If a denominator is - negative, the class will be ignored (if averaging), or its score - will be returned as ``nan`` (if ``average=None``). - If the denominator is zero, then ``zero_division`` score will be - used for those elements. - weights: - A tensor of weights to be used if ``average='weighted'``. - average: - The method to average the scores. Should be one of ``'micro'``, ``'macro'``, - ``'weighted'``, ``'none'``, ``None`` or ``'samples'``. The behavior - corresponds to `sklearn averaging methods `__. - mdmc_average: - The method to average the scores if inputs were multi-dimensional multi-class (MDMC). - Should be either ``'global'`` or ``'samplewise'``. If inputs were not - multi-dimensional multi-class, it should be ``None`` (default). - zero_division: - The value to use for the score if denominator equals zero. - """ - numerator, denominator = numerator.float(), denominator.float() - zero_div_mask = denominator == 0 - ignore_mask = denominator < 0 - - if weights is None: - weights = torch.ones_like(denominator) - else: - weights = weights.float() - - numerator = torch.where(zero_div_mask, torch.tensor(float(zero_division), device=numerator.device), numerator) - denominator = torch.where(zero_div_mask | ignore_mask, torch.tensor(1.0, device=denominator.device), denominator) - weights = torch.where(ignore_mask, torch.tensor(0.0, device=weights.device), weights) - - if average not in (AverageMethod.MICRO, AverageMethod.NONE, None): - weights = weights / weights.sum(dim=-1, keepdim=True) - - scores = weights * (numerator / denominator) - - # This is in case where sum(weights) = 0, which happens if we ignore the only present class with average='weighted' - scores = torch.where(torch.isnan(scores), torch.tensor(float(zero_division), device=scores.device), scores) - - if mdmc_average == MDMCAverageMethod.SAMPLEWISE: - scores = scores.mean(dim=0) - ignore_mask = ignore_mask.sum(dim=0).bool() - - if average in (AverageMethod.NONE, None): - scores = torch.where(ignore_mask, torch.tensor(np.nan, device=scores.device), scores) - else: - scores = scores.sum() - - return scores + SAMPLEWISE = "samplewise" \ No newline at end of file diff --git a/pytorch_lightning/metrics/functional/precision_recall.py b/pytorch_lightning/metrics/functional/precision_recall.py index 09632e216560b..b6d26237cf287 100644 --- a/pytorch_lightning/metrics/functional/precision_recall.py +++ b/pytorch_lightning/metrics/functional/precision_recall.py @@ -14,8 +14,8 @@ from typing import Optional import torch +from torchmetrics.classification.stat_scores import _reduce_stat_scores -from pytorch_lightning.metrics.classification.helpers import _reduce_stat_scores from pytorch_lightning.metrics.functional.stat_scores import _stat_scores_update from pytorch_lightning.utilities import rank_zero_warn From b192262090600c876355e4ff7a82f3bc75ed3bcf Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:58:15 +0100 Subject: [PATCH 10/13] DataType --- pytorch_lightning/metrics/classification/helpers.py | 11 ----------- pytorch_lightning/metrics/functional/accuracy.py | 3 +-- pytorch_lightning/metrics/functional/auroc.py | 2 +- .../metrics/functional/confusion_matrix.py | 2 +- tests/metrics/classification/test_accuracy.py | 2 +- tests/metrics/classification/test_inputs.py | 2 +- 6 files changed, 5 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py index c196df8b0bc3b..c9a308d8b32d0 100644 --- a/pytorch_lightning/metrics/classification/helpers.py +++ b/pytorch_lightning/metrics/classification/helpers.py @@ -21,17 +21,6 @@ from pytorch_lightning.utilities import LightningEnum -class DataType(LightningEnum): - """ - Enum to represent data type - """ - - BINARY = "binary" - MULTILABEL = "multi-label" - MULTICLASS = "multi-class" - MULTIDIM_MULTICLASS = "multi-dim multi-class" - - class AverageMethod(LightningEnum): """ Enum to represent average method diff --git a/pytorch_lightning/metrics/functional/accuracy.py b/pytorch_lightning/metrics/functional/accuracy.py index dea8b5590174c..53a47611cd49a 100644 --- a/pytorch_lightning/metrics/functional/accuracy.py +++ b/pytorch_lightning/metrics/functional/accuracy.py @@ -15,8 +15,7 @@ import torch from torchmetrics.classification.checks import _input_format_classification - -from pytorch_lightning.metrics.classification.helpers import DataType +from torchmetrics.utilities.enums import DataType def _accuracy_update( diff --git a/pytorch_lightning/metrics/functional/auroc.py b/pytorch_lightning/metrics/functional/auroc.py index c496d86b1bbca..e772b5050260c 100644 --- a/pytorch_lightning/metrics/functional/auroc.py +++ b/pytorch_lightning/metrics/functional/auroc.py @@ -16,8 +16,8 @@ import torch from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType -from pytorch_lightning.metrics.classification.helpers import DataType from pytorch_lightning.metrics.functional.auc import auc from pytorch_lightning.metrics.functional.roc import roc from pytorch_lightning.utilities import LightningEnum diff --git a/pytorch_lightning/metrics/functional/confusion_matrix.py b/pytorch_lightning/metrics/functional/confusion_matrix.py index 5753d3e00c218..e77fc4224d25e 100644 --- a/pytorch_lightning/metrics/functional/confusion_matrix.py +++ b/pytorch_lightning/metrics/functional/confusion_matrix.py @@ -15,8 +15,8 @@ import torch from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType -from pytorch_lightning.metrics.classification.helpers import DataType from pytorch_lightning.utilities import rank_zero_warn diff --git a/tests/metrics/classification/test_accuracy.py b/tests/metrics/classification/test_accuracy.py index 63491c619d6ba..63a4870ed422e 100644 --- a/tests/metrics/classification/test_accuracy.py +++ b/tests/metrics/classification/test_accuracy.py @@ -5,9 +5,9 @@ import torch from sklearn.metrics import accuracy_score as sk_accuracy from torchmetrics.classification.checks import _input_format_classification +from torchmetrics.utilities.enums import DataType from pytorch_lightning.metrics import Accuracy -from pytorch_lightning.metrics.classification.helpers import DataType from pytorch_lightning.metrics.functional import accuracy from tests.metrics.classification.inputs import _input_binary, _input_binary_prob from tests.metrics.classification.inputs import _input_multiclass as _input_mcls diff --git a/tests/metrics/classification/test_inputs.py b/tests/metrics/classification/test_inputs.py index e143379259590..f07a9c2821f56 100644 --- a/tests/metrics/classification/test_inputs.py +++ b/tests/metrics/classification/test_inputs.py @@ -3,8 +3,8 @@ from torch import rand, randint from torchmetrics.classification.checks import _input_format_classification from torchmetrics.utilities.data import select_topk, to_onehot +from torchmetrics.utilities.enums import DataType -from pytorch_lightning.metrics.classification.helpers import DataType from tests.metrics.classification.inputs import _input_binary as _bin from tests.metrics.classification.inputs import _input_binary_prob as _bin_prob from tests.metrics.classification.inputs import _input_multiclass as _mc From b5ee173925cde07cbd2cf6627a69be23262c5333 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 11:59:28 +0100 Subject: [PATCH 11/13] rest --- .../metrics/classification/helpers.py | 42 ------------------- 1 file changed, 42 deletions(-) delete mode 100644 pytorch_lightning/metrics/classification/helpers.py diff --git a/pytorch_lightning/metrics/classification/helpers.py b/pytorch_lightning/metrics/classification/helpers.py deleted file mode 100644 index c9a308d8b32d0..0000000000000 --- a/pytorch_lightning/metrics/classification/helpers.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Optional, Tuple - -import numpy as np -import torch -from torchmetrics.classification.checks import _check_classification_inputs -from torchmetrics.utilities.data import select_topk, to_onehot - -from pytorch_lightning.utilities import LightningEnum - - -class AverageMethod(LightningEnum): - """ - Enum to represent average method - """ - - MICRO = "micro" - MACRO = "macro" - WEIGHTED = "weighted" - NONE = "none" - SAMPLES = "samples" - - -class MDMCAverageMethod(LightningEnum): - """ - Enum to represent multi-dim multi-class average method - """ - - GLOBAL = "global" - SAMPLEWISE = "samplewise" \ No newline at end of file From 19dd0810ed5208240ec25ab359b71e2cba50dc72 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 12:01:03 +0100 Subject: [PATCH 12/13] flake8 --- pytorch_lightning/trainer/connectors/env_vars_connector.py | 1 + tests/trainer/test_lr_finder.py | 1 + 2 files changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index f4209f40d002e..1f1c41c6eb2f0 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -23,6 +23,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable: Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. """ + @wraps(fn) def insert_env_defaults(self, *args, **kwargs): cls = self.__class__ # get the class diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 5150ea1a304f4..44510eb16184d 100644 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -277,6 +277,7 @@ def test_lr_find_with_bs_scale(tmpdir): """ Test that lr_find runs with batch_size_scaling """ class BoringModelTune(BoringModel): + def __init__(self, learning_rate=0.1, batch_size=2): super().__init__() self.save_hyperparameters() From d02740cbf0d1b11427506efc59fe834828a7464a Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 12:03:16 +0100 Subject: [PATCH 13/13] chlog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f5dcb20375137..c17f913cc960d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,6 +68,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), + + [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547), )