Skip to content

Commit

Permalink
Prune deprecated metrics for 1.3 (Lightning-AI#6161)
Browse files Browse the repository at this point in the history
* prune deprecated metrics for 1.3

* isort / yapf
  • Loading branch information
Borda authored and ananthsub committed Feb 24, 2021
1 parent 930a269 commit 2a999cd
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 356 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Added support to checkpoint after training batches in `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


### Changed

Expand All @@ -25,6 +27,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))


- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
* from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
* from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`


### Fixed

- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
Expand Down Expand Up @@ -94,7 +101,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))
- Added DeepSpeed integration ([#5954](https://github.com/PyTorchLightning/pytorch-lightning/pull/5954),
[#6042](https://github.com/PyTorchLightning/pytorch-lightning/pull/6042))
- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))

### Changed

Expand Down
1 change: 0 additions & 1 deletion pytorch_lightning/metrics/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
multiclass_auroc,
stat_scores_multiple_classes,
to_categorical,
to_onehot,
)
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
Expand Down
257 changes: 10 additions & 247 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,70 +18,11 @@

from pytorch_lightning.metrics.functional.auc import auc as __auc
from pytorch_lightning.metrics.functional.auroc import auroc as __auroc
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.iou import iou as __iou
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve as __prc
from pytorch_lightning.metrics.functional.roc import roc as __roc
from pytorch_lightning.metrics.utils import class_reduce
from pytorch_lightning.metrics.utils import get_num_classes as __gnc
from pytorch_lightning.metrics.utils import reduce
from pytorch_lightning.metrics.utils import to_categorical as __tc
from pytorch_lightning.metrics.utils import to_onehot as __to
from pytorch_lightning.metrics.utils import class_reduce, get_num_classes, reduce, to_categorical
from pytorch_lightning.utilities import rank_zero_warn


def to_onehot(
tensor: torch.Tensor,
num_classes: Optional[int] = None,
) -> torch.Tensor:
"""
Converts a dense label tensor to one-hot format
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_onehot`
"""
rank_zero_warn(
"This `to_onehot` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import to_onehot`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __to(tensor, num_classes)


def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
"""
Converts a tensor of probabilities to a dense label tensor
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.to_categorical`
"""
rank_zero_warn(
"This `to_categorical` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import to_categorical`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __tc(tensor)


def get_num_classes(
pred: torch.Tensor,
target: torch.Tensor,
num_classes: Optional[int] = None,
) -> int:
"""
Calculates the number of classes for a given prediction and target tensor.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.utils.get_num_classes`
"""
rank_zero_warn(
"This `get_num_classes` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.utils import get_num_classes`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __gnc(pred, target, num_classes)


def stat_scores(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -122,6 +63,7 @@ def stat_scores(
return tp, fp, tn, fn, sup


# todo: remove in 1.4
def stat_scores_multiple_classes(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -210,6 +152,7 @@ def _confmat_normalize(cm):
return cm


# todo: remove in 1.4
def precision_recall(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -268,6 +211,7 @@ def precision_recall(
return precision, recall


# todo: remove in 1.4
def precision(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -311,6 +255,7 @@ def precision(
return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[0]


# todo: remove in 1.4
def recall(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -353,128 +298,7 @@ def recall(
return precision_recall(pred=pred, target=target, num_classes=num_classes, class_reduction=class_reduction)[1]


# todo: remove in 1.3
def roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __roc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def _roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes classifier is binary.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
Example:
>>> x = torch.tensor([0, 1, 2, 3])
>>> y = torch.tensor([0, 1, 1, 1])
>>> fpr, tpr, thresholds = _roc(x, y)
>>> fpr
tensor([0., 0., 0., 0., 1.])
>>> tpr
tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000])
>>> thresholds
tensor([4, 3, 2, 1, 0])
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
fps, tps, thresholds = _binary_clf_curve(pred, target, sample_weights=sample_weight, pos_label=pos_label)

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, false positive value should be meaningless")

fpr = fps / fps[-1]

if tps[-1] <= 0:
raise ValueError("No positive samples in targets, true positive value should be meaningless")

tpr = tps / tps[-1]

return fpr, tpr, thresholds


# TODO: deprecated in favor of general ROC in pytorch_lightning/metrics/functional/roc.py
def multiclass_roc(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.roc.roc`
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes (default: None, computes automatically from data)
Return:
returns roc for each class.
Number of classes, false-positive rate (fpr), true-positive rate (tpr), thresholds
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> multiclass_roc(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
rank_zero_warn(
"This `multiclass_roc` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.roc import roc`."
" It will be removed in v1.3.0", DeprecationWarning
)
num_classes = get_num_classes(pred, target, num_classes)

class_roc_vals = []
for c in range(num_classes):
pred_c = pred[:, c]

class_roc_vals.append(_roc(pred=pred_c, target=target, sample_weight=sample_weight, pos_label=c))

return tuple(class_roc_vals)


# todo: remove in 1.4
def auc(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -508,6 +332,7 @@ def auc(
return __auc(x, y)


# todo: remove in 1.4
def auc_decorator() -> Callable:
rank_zero_warn("This `auc_decorator` was deprecated in v1.2.0." " It will be removed in v1.4.0", DeprecationWarning)

Expand All @@ -524,6 +349,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
return wrapper


# todo: remove in 1.4
def multiclass_auc_decorator() -> Callable:
rank_zero_warn(
"This `multiclass_auc_decorator` was deprecated in v1.2.0."
Expand All @@ -546,6 +372,7 @@ def new_func(*args, **kwargs) -> torch.Tensor:
return wrapper


# todo: remove in 1.4
def auroc(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -588,6 +415,7 @@ def auroc(
)


# todo: remove in 1.4
def multiclass_auroc(
pred: torch.Tensor,
target: torch.Tensor,
Expand Down Expand Up @@ -767,68 +595,3 @@ def iou(
num_classes=num_classes,
reduction=reduction
)


# todo: remove in 1.3
def precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Computes precision-recall pairs for different thresholds.
.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __prc(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)


# todo: remove in 1.3
def multiclass_precision_recall_curve(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
):
"""
Computes precision-recall pairs for different thresholds given a multiclass scores.
.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
" `from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve`."
" It will be removed in v1.3.0", DeprecationWarning
)
if num_classes is None:
num_classes = get_num_classes(pred, target, num_classes)
return __prc(preds=pred, target=target, sample_weights=sample_weight, num_classes=num_classes)


# todo: remove in 1.3
def average_precision(
pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.,
):
"""
Compute average precision from prediction scores.
.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.1.0 in favor of"
" `pytorch_lightning.metrics.functional.average_precision import average_precision`."
" It will be removed in v1.3.0", DeprecationWarning
)
return __ap(preds=pred, target=target, sample_weights=sample_weight, pos_label=pos_label)
3 changes: 1 addition & 2 deletions pytorch_lightning/metrics/functional/iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import torch

from pytorch_lightning.metrics.functional.confusion_matrix import _confusion_matrix_update
from pytorch_lightning.metrics.functional.reduction import reduce
from pytorch_lightning.metrics.utils import get_num_classes
from pytorch_lightning.metrics.utils import get_num_classes, reduce


def _iou_from_confmat(
Expand Down
Loading

0 comments on commit 2a999cd

Please sign in to comment.