Skip to content

Commit

Permalink
document exceptions for metrics/functional (#6273)
Browse files Browse the repository at this point in the history
* document exceptions for metrics/functional

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
  • Loading branch information
4 people authored Mar 15, 2021
1 parent 156847b commit 06756a8
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 0 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/metrics/functional/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ def accuracy(
``preds = preds.flatten()`` and same for ``target``). Note that the ``top_k`` parameter
still applies in both cases, if set.
Raises:
ValueError:
If ``top_k`` parameter is set for ``multi-label`` inputs.
Example:
>>> from pytorch_lightning.metrics.functional import accuracy
Expand Down
8 changes: 8 additions & 0 deletions pytorch_lightning/metrics/functional/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = False) -> torch.Tensor
Return:
Tensor containing AUC score (float)
Raises:
ValueError:
If both ``x`` and ``y`` tensors are not ``1d``.
ValueError:
If both ``x`` and ``y`` don't have the same numnber of elements.
ValueError:
If ``x`` tesnsor is neither increasing or decreasing.
Example:
>>> from pytorch_lightning.metrics.functional import auc
>>> x = torch.tensor([0, 1, 2, 3])
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/metrics/functional/auroc.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,18 @@ def auroc(
range [0, max_fpr]. Should be a float between 0 and 1.
sample_weight: sample weights for each data point
Raises:
ValueError:
If ``max_fpr`` is not a ``float`` in the range ``(0, 1]``.
RuntimeError:
If ``PyTorch version`` is ``below 1.6`` since max_fpr requires `torch.bucketize`
which is not available below 1.6.
ValueError:
If ``max_fpr`` is not set to ``None`` and the mode is ``not binary``
since partial AUC computation is not available in multilabel/multiclass.
ValueError:
If ``average`` is none of ``None``, ``"macro"`` or ``"weighted"``.
Example (binary case):
>>> from pytorch_lightning.metrics.functional import auroc
Expand Down
13 changes: 13 additions & 0 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def stat_scores_multiple_classes(
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.stat_scores`
Raises:
ValueError:
If ``reduction`` is not one of ``"none"``, ``"sum"`` or ``"elementwise_mean"``.
"""

rank_zero_warn(
Expand Down Expand Up @@ -439,6 +442,16 @@ def multiclass_auroc(
Return:
Tensor containing ROCAUC score
Raises:
ValueError:
If ``pred`` don't sum up to ``1`` over classes for ``Multiclass AUROC``.
ValueError:
If number of classes found in ``target`` does not equal the number of
columns in ``pred``.
ValueError:
If number of classes deduced from ``pred`` does not equal the number of
classes passed in ``num_classes``.
Example:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
Expand Down
6 changes: 6 additions & 0 deletions pytorch_lightning/metrics/functional/image_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def image_gradients(img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
Return:
Tuple of (dy, dx) with each gradient of shape ``[N, C, H, W]``
Raises:
TypeError:
If ``img`` is not of the type <torch.Tensor>.
RuntimeError:
If ``img`` is not a 4D tensor.
Example:
>>> from pytorch_lightning.metrics.functional import image_gradients
>>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
Expand Down
36 changes: 36 additions & 0 deletions pytorch_lightning/metrics/functional/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def precision(
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
of classes
Raises:
ValueError:
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
``"samples"``, ``"none"`` or ``None``.
ValueError:
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
ValueError:
If ``average`` is set but ``num_classes`` is not provided.
ValueError:
If ``num_classes`` is set
and ``ignore_index`` is not in the range ``[0, num_classes)``.
Example:
>>> from pytorch_lightning.metrics.functional import precision
Expand Down Expand Up @@ -295,6 +307,18 @@ def recall(
- If ``average in ['none', None]``, the shape will be ``(C,)``, where ``C`` stands for the number
of classes
Raises:
ValueError:
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
``"samples"``, ``"none"`` or ``None``.
ValueError:
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
ValueError:
If ``average`` is set but ``num_classes`` is not provided.
ValueError:
If ``num_classes`` is set
and ``ignore_index`` is not in the range ``[0, num_classes)``.
Example:
>>> from pytorch_lightning.metrics.functional import recall
Expand Down Expand Up @@ -444,6 +468,18 @@ def precision_recall(
- If ``average in ['none', None]``, they are a tensor of shape ``(C, )``, where ``C`` stands for
the number of classes
Raises:
ValueError:
If ``average`` is not one of ``"micro"``, ``"macro"``, ``"weighted"``,
``"samples"``, ``"none"`` or ``None``.
ValueError:
If ``mdmc_average`` is not one of ``None``, ``"samplewise"``, ``"global"``.
ValueError:
If ``average`` is set but ``num_classes`` is not provided.
ValueError:
If ``num_classes`` is set
and ``ignore_index`` is not in the range ``[0, num_classes)``.
Example:
>>> from pytorch_lightning.metrics.functional import precision_recall
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,14 @@ def precision_recall_curve(
>>> thresholds
tensor([1, 2, 3])
Raises:
ValueError:
If ``preds`` and ``target`` don't have the same number of dimensions,
or one additional dimension for ``preds``.
ValueError:
If the number of classes deduced from ``preds`` is not the same as the
``num_classes`` provided.
Example (multiclass case):
>>> from pytorch_lightning.metrics.functional import precision_recall_curve
Expand Down
4 changes: 4 additions & 0 deletions pytorch_lightning/metrics/functional/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def psnr(
Return:
Tensor with PSNR score
Raises:
ValueError:
If ``dim`` is not ``None`` and ``data_range`` is not provided.
Example:
>>> from pytorch_lightning.metrics.functional import psnr
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/metrics/functional/r2score.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,18 @@ def r2score(
* ``'uniform_average'`` scores are uniformly averaged
* ``'variance_weighted'`` scores are weighted by their individual variances
Raises:
ValueError:
If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors.
ValueError:
If ``len(preds)`` is less than ``2``
since at least ``2`` sampels are needed to calculate r2 score.
ValueError:
If ``multioutput`` is not one of ``raw_values``,
``uniform_average`` or ``variance_weighted``.
ValueError:
If ``adjusted`` is not an ``integer`` greater than ``0``.
Example:
>>> from pytorch_lightning.metrics.functional import r2score
Expand Down
12 changes: 12 additions & 0 deletions pytorch_lightning/metrics/functional/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,18 @@ def ssim(
Return:
Tensor with SSIM score
Raises:
TypeError:
If ``preds`` and ``target`` don't have the same data type.
ValueError:
If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
ValueError:
If the length of ``kernel_size`` or ``sigma`` is not ``2``.
ValueError:
If one of the elements of ``kernel_size`` is not an ``odd positive number``.
ValueError:
If one of the elements of ``sigma`` is not a ``positive number``.
Example:
>>> from pytorch_lightning.metrics.functional import ssim
>>> preds = torch.rand([16, 1, 16, 16])
Expand Down
15 changes: 15 additions & 0 deletions pytorch_lightning/metrics/functional/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,21 @@ def stat_scores(
- If ``reduce='macro'``, the shape will be ``(N, C, 5)``
- If ``reduce='samples'``, the shape will be ``(N, X, 5)``
Raises:
ValueError:
If ``reduce`` is none of ``"micro"``, ``"macro"`` or ``"samples"``.
ValueError:
If ``mdmc_reduce`` is none of ``None``, ``"samplewise"``, ``"global"``.
ValueError:
If ``reduce`` is set to ``"macro"`` and ``num_classes`` is not provided.
ValueError:
If ``num_classes`` is set
and ``ignore_index`` is not in the range ``[0, num_classes)``.
ValueError:
If ``ignore_index`` is used with ``binary data``.
ValueError:
If inputs are ``multi-dimensional multi-class`` and ``mdmc_reduce`` is not provided.
Example:
>>> from pytorch_lightning.metrics.functional import stat_scores
Expand Down

0 comments on commit 06756a8

Please sign in to comment.