Add documentation to native metrics #2144

Merged: 5 commits on Jun 12, 2020
Changes from 3 commits
106 changes: 104 additions & 2 deletions pytorch_lightning/metrics/functional/classification.py
@@ -29,14 +29,30 @@ def to_onehot(tensor: torch.Tensor,


def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
""" Converts a tensor of probabilities to a dense label tensor """
""" Converts a tensor of probabilities to a dense label tensor

Args:
tensor: probabilities to get the categorical label [N, d1, d2, ...]
argmax_dim: dimension to apply the argmax transformation over (default: 1)

Return:
A tensor with categorical labels [N, d2, ...]
"""
return torch.argmax(tensor, dim=argmax_dim)
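
A quick doctest-style sketch of the documented behavior (illustrative, not part of this diff; assumes a [N, C] probability input):

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import to_categorical
>>> x = torch.tensor([[0.2, 0.5], [0.9, 0.1]])
>>> to_categorical(x)
tensor([1, 0])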


def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int]) -> int:
""" Returns the number of classes for a given prediction and
target tensor

Args:
pred: predicted values
target: true labels
num_classes: number of classes if known (default: None)

Return:
An integer that represents the number of classes.
"""
if num_classes is None:
if pred.ndim > target.ndim:
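
A hedged usage sketch (the body is collapsed in this hunk; this assumes the count is inferred from the largest label value when num_classes is None):

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import get_num_classes
>>> pred = torch.tensor([0, 2, 1])
>>> target = torch.tensor([0, 1, 1])
>>> get_num_classes(pred, target, None)
3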
@@ -63,6 +79,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over

Return:
Tensors in the following order: True Positive, False Positive, True Negative, False Negative

"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
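
An illustrative sketch of the four return values. The full signature is collapsed in this hunk, so the class_index keyword is an assumption, and exact tensor formatting may differ:

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import stat_scores
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> tp, fp, tn, fn = stat_scores(pred, target, class_index=1)
>>> tp, fp, tn, fn
(tensor(1), tensor(1), tensor(1), tensor(1))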
@@ -94,6 +113,9 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over

Return:
Tensors for: tp, fp, tn, fn (one tensor each, with one entry per class)

"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)
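
A sketch of the per-class outputs (illustrative; exact dtypes and formatting may differ):

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> tps, fps, tns, fns = stat_scores_multiple_classes(pred, target)
>>> tps
tensor([1., 1.])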
@@ -116,6 +138,23 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
def accuracy(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean') -> torch.Tensor:
"""
Computes the accuracy classification score

Args:
pred: predicted labels
target: ground truth labels
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: no reduction, returns the values per class
- sum: adds up the elements

Return:
A Tensor with the classification score.
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
num_classes=num_classes)
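
A minimal doctest-style sketch (illustrative; output formatting may differ). Two of the four predictions below match their targets, so the reduced score is 0.5:

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import accuracy
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> accuracy(pred, target)
tensor(0.5000)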

@@ -129,6 +168,18 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,

def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
normalize: bool = False) -> torch.Tensor:
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.

Args:
pred: estimated targets
target: ground-truth labels
normalize: if True, normalizes the confusion matrix

Return:
A Tensor with the confusion matrix C [num_classes, num_classes]
"""
num_classes = get_num_classes(pred, target, None)

d = target.size(-1)
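
A sketch of the C_{i,j} convention described above (illustrative; dtype may differ). Each true/predicted pair below occurs exactly once:

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import confusion_matrix
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> confusion_matrix(pred, target)
tensor([[1., 1.],
        [1., 1.]])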
@@ -149,6 +200,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean'
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes the precision and recall scores

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing the precision-recall values (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: no reduction, returns the values per class
- sum: adds up the elements

Return:
A tuple of Tensors with precision and recall
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
target=target,
num_classes=num_classes)
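
An illustrative sketch of the returned pair (exact formatting may differ). Each class below has one true positive plus one false positive and one false negative, so both scores reduce to 0.5:

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import precision_recall
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> precision_recall(pred, target)
(tensor(0.5000), tensor(0.5000))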
@@ -168,13 +236,47 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
def precision(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes the precision score.

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing the precision values (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: no reduction, returns the values per class
- sum: adds up the elements

Return:
A Tensor with the precision score.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[0]
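
A matching sketch for the precision-only wrapper (illustrative; same data as above):

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import precision
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> precision(pred, target)
tensor(0.5000)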


def recall(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes the recall score.

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing the recall values (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: no reduction, returns the values per class
- sum: adds up the elements

Return:
A Tensor with the recall score.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[1]
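
And the recall-only wrapper (illustrative; same data as above):

>>> import torch
>>> from pytorch_lightning.metrics.functional.classification import recall
>>> pred = torch.tensor([0, 1, 1, 0])
>>> target = torch.tensor([0, 1, 0, 1])
>>> recall(pred, target)
tensor(0.5000)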
