Commit
Add documentation to native metrics (#2144)
* add docs

* add docs

* Apply suggestions from code review

* formatting

* add docs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
3 people committed Jun 12, 2020
1 parent c3ad1ca commit 3e979d0
Showing 1 changed file with 224 additions and 13 deletions.
237 changes: 224 additions & 13 deletions pytorch_lightning/metrics/functional/classification.py
@@ -9,7 +9,8 @@

def to_onehot(tensor: torch.Tensor,
n_classes: Optional[int] = None) -> torch.Tensor:
""" Converts a dense label tensor to one-hot format
"""
Converts a dense label tensor to one-hot format
Args:
tensor: dense label tensor, with shape [N, d1, d2, ...]
@@ -29,14 +30,31 @@ def to_onehot(tensor: torch.Tensor,

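A minimal usage sketch (not part of the commit; the import path is assumed from the file path shown above, and the printed output is what the documented behaviour implies):

import torch
from pytorch_lightning.metrics.functional.classification import to_onehot

labels = torch.tensor([0, 2, 1])   # dense labels, shape [N]
onehot = to_onehot(labels, n_classes=3)
# expected shape [N, n_classes]:
# [[1, 0, 0],
#  [0, 0, 1],
#  [0, 1, 0]]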

def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
""" Converts a tensor of probabilities to a dense label tensor """
"""
Converts a tensor of probabilities to a dense label tensor
Args:
tensor: probabilities to get the categorical label [N, d1, d2, ...]
argmax_dim: dimension over which to apply the argmax (default: 1)
Return:
A tensor with categorical labels [N, d2, ...]
"""
return torch.argmax(tensor, dim=argmax_dim)

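A sketch of the inverse direction; since the body is just torch.argmax, the output here is exact:

import torch
from pytorch_lightning.metrics.functional.classification import to_categorical

probs = torch.tensor([[0.1, 0.9],
                      [0.8, 0.2]])
to_categorical(probs)  # argmax over dim 1 -> tensor([1, 0])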

def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int]) -> int:
""" Returns the number of classes for a given prediction and
target tensor
"""
Returns the number of classes for a given prediction and target tensor.
Args:
pred: predicted values
target: true labels
num_classes: number of classes if known (default: None)
Return:
An integer that represents the number of classes.
"""
if num_classes is None:
if pred.ndim > target.ndim:
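A hedged sketch of the inference path hinted at above (probabilities carry one extra class dimension relative to the targets):

import torch
from pytorch_lightning.metrics.functional.classification import get_num_classes

pred = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # [N, C] probabilities
target = torch.tensor([1, 0])                  # [N] dense labels
get_num_classes(pred, target, None)  # presumably 2, read off the class dimension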
@@ -50,8 +68,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
class_index: int, argmax_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
""" Calculates the number of true postive, false postive, true negative
and false negative for a specfic class
"""
Calculates the number of true positive, false positive, true negative
and false negative for a specific class
Args:
pred: prediction tensor
@@ -63,6 +82,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
Return:
Tensors in the following order: True Positive, False Positive, True Negative, False Negative
"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
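A worked example; the counts follow directly from the definition:

import torch
from pytorch_lightning.metrics.functional.classification import stat_scores

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
tp, fp, tn, fn = stat_scores(pred, target, class_index=1)
# for class 1: tp=1 (sample 1), fp=1 (sample 2), tn=2 (samples 0 and 3), fn=0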
@@ -80,20 +102,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
""" Calls the stat_scores function iteratively for all classes, thus
calculating the number of true postive, false postive, true negative
and false negative for each class
"""
Calls the stat_scores function iteratively for all classes, thus
calculating the number of true positive, false positive, true negative
and false negative for each class
Args:
pred: prediction tensor
target: target tensor
num_classes: number of classes if known (default: None)
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over
Return:
Returns tensors for: tp, fp, tn, fn
"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)
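The per-class variant in use (same data as the stat_scores sketch above):

import torch
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
tps, fps, tns, fns = stat_scores_multiple_classes(pred, target)
# each of tps, fps, tns, fns holds one entry per class, here 3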
@@ -116,6 +139,23 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
def accuracy(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean') -> torch.Tensor:
"""
Computes the accuracy classification score
Args:
pred: predicted labels
target: ground truth labels
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction, return the per-class values
- sum: add elements
Return:
A Tensor with the classification score.
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
num_classes=num_classes)

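A usage sketch; the exact scalar depends on how the per-class scores are reduced, so no value is asserted:

import torch
from pytorch_lightning.metrics.functional.classification import accuracy

pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 2, 1])
accuracy(pred, target)                    # scalar tensor in [0, 1]
accuracy(pred, target, reduction='none')  # unreduced per-class accuracies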
@@ -129,6 +169,18 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,

def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
normalize: bool = False) -> torch.Tensor:
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted in group j.
Args:
pred: estimated targets
target: ground truth labels
normalize: if True, normalizes the confusion matrix
Return:
Tensor, confusion matrix C [num_classes, num_classes]
"""
num_classes = get_num_classes(pred, target, None)

d = target.size(-1)
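A small sketch; rows index the true class i and columns the predicted class j, as defined above:

import torch
from pytorch_lightning.metrics.functional.classification import confusion_matrix

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
confusion_matrix(pred, target)
# expected counts (dtype may differ):
# [[1, 0, 0],
#  [0, 1, 0],
#  [0, 1, 1]]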
@@ -149,6 +201,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean'
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes precision and recall scores
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision-recall values (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction, return the per-class values
- sum: add elements
Return:
Tensor with precision and recall
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
target=target,
num_classes=num_classes)
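Sketch of the combined call:

import torch
from pytorch_lightning.metrics.functional.classification import precision_recall

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
prec, rec = precision_recall(pred, target)  # two scalar tensors under the default mean reduction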
@@ -168,20 +237,77 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
def precision(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes the precision score.
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision values (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction, return the per-class values
- sum: add elements
Return:
Tensor with precision.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[0]

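As the body shows, this is element [0] of precision_recall; a one-liner sketch:

import torch
from pytorch_lightning.metrics.functional.classification import precision

precision(torch.tensor([0, 1, 1, 2]), torch.tensor([0, 1, 2, 2]))  # scalar tensor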

def recall(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes the recall score.
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing recall values (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction, return the per-class values
- sum: add elements
Return:
Tensor with recall.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[1]

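Likewise element [1]; passing reduction='none' keeps the per-class values:

import torch
from pytorch_lightning.metrics.functional.classification import recall

recall(torch.tensor([0, 1, 1, 2]), torch.tensor([0, 1, 2, 2]), reduction='none')  # one value per class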

def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes the F-beta score, which is a weighted harmonic mean of precision and recall.
It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.
Args:
pred: estimated probabilities
target: ground-truth labels
beta: determines the weight of recall when combining it with precision.
beta < 1: more weight to precision.
beta > 1: more weight to recall.
beta = 0: only precision.
beta -> inf: only recall.
num_classes: number of classes
reduction: method for reducing F-score (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction, return the per-class values
- sum: add elements.
Return:
Tensor with the F-score: a value between 0 and 1.
"""
prec, rec = precision_recall(pred=pred, target=target,
num_classes=num_classes,
reduction='none')
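A sketch of the beta trade-off described above:

import torch
from pytorch_lightning.metrics.functional.classification import fbeta_score

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 1, 2, 2])
fbeta_score(pred, target, beta=0.5)  # leans towards precision
fbeta_score(pred, target, beta=2.0)  # leans towards recall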
@@ -196,6 +322,23 @@ def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
def f1_score(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean') -> torch.Tensor:
"""
Computes the F1-score (also known as the F-measure).
Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing F1-score (default: takes the mean)
Available reduction methods:
- elementwise_mean: takes the mean
- none: no reduction, return the per-class values
- sum: add elements.
Return:
Tensor containing F1-score
"""
return fbeta_score(pred=pred, target=target, beta=1.,
num_classes=num_classes, reduction=reduction)

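Per the body above, equivalent to fbeta_score with beta=1:

import torch
from pytorch_lightning.metrics.functional.classification import f1_score

f1_score(torch.tensor([0, 1, 1, 2]), torch.tensor([0, 1, 2, 2]))  # scalar tensor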
@@ -251,6 +394,18 @@ def roc(pred: torch.Tensor, target: torch.Tensor,
pos_label: int = 1.) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1)
Return:
[Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
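A binary sketch with the usual toy scores:

import torch
from pytorch_lightning.metrics.functional.classification import roc

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])  # scores for the positive class
target = torch.tensor([0, 0, 1, 1])
fpr, tpr, thresholds = roc(pred, target)  # three 1-D tensors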
@@ -282,6 +437,19 @@ def multiclass_roc(pred: torch.Tensor, target: torch.Tensor,
) -> Tuple[Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]]:
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes (default: None, inferred from the data)
Return:
A tuple holding one ROC triple per class:
false-positive rate (fpr), true-positive rate (tpr), thresholds
"""
num_classes = get_num_classes(pred, target, num_classes)

class_roc_vals = []
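A multiclass sketch; each element of the result should be one (fpr, tpr, thresholds) triple:

import torch
from pytorch_lightning.metrics.functional.classification import multiclass_roc

pred = torch.tensor([[0.85, 0.05, 0.10],
                     [0.10, 0.80, 0.10],
                     [0.20, 0.30, 0.50]])
target = torch.tensor([0, 1, 2])
for class_idx, (fpr, tpr, thresholds) in enumerate(multiclass_roc(pred, target)):
    print(class_idx, fpr, tpr, thresholds)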
@@ -301,6 +469,18 @@ def precision_recall_curve(pred: torch.Tensor,
pos_label: int = 1.) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
"""
Computes precision-recall pairs for different thresholds.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
Return:
[Tensor, Tensor, Tensor]: precision, recall, thresholds
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
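A binary sketch, analogous to roc above:

import torch
from pytorch_lightning.metrics.functional.classification import precision_recall_curve

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
prec, rec, thresholds = precision_recall_curve(pred, target)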
@@ -334,6 +514,18 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
) -> Tuple[Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]]:
"""
Computes precision-recall pairs for different thresholds, given multiclass scores.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes
Return:
A tuple holding one (precision, recall, thresholds) triple per class
"""
num_classes = get_num_classes(pred, target, num_classes)

class_pr_vals = []
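And its per-class counterpart, mirroring multiclass_roc:

import torch
from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve

pred = torch.tensor([[0.85, 0.05, 0.10],
                     [0.10, 0.80, 0.10],
                     [0.20, 0.30, 0.50]])
target = torch.tensor([0, 1, 2])
curves = multiclass_precision_recall_curve(pred, target)  # one (precision, recall, thresholds) triple per class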
@@ -350,6 +542,17 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,


def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
"""
Computes the Area Under the Curve (AUC) using the trapezoidal rule
Args:
x: x-coordinates
y: y-coordinates
reorder: if True, reorder the coordinates so they are increasing
Return:
AUC score (float)
"""
direction = 1.

if reorder:
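The trapezoidal rule is easy to sanity-check on a known curve: the area under y = x on [0, 1] is 0.5.

import torch
from pytorch_lightning.metrics.functional.classification import auc

x = torch.tensor([0.0, 0.5, 1.0])
y = torch.tensor([0.0, 0.5, 1.0])
auc(x, y)  # 0.5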
@@ -400,6 +603,15 @@ def new_func(*args, **kwargs) -> torch.Tensor:
def auroc(pred: torch.Tensor, target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.) -> torch.Tensor:
"""
Computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores.
Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
"""
return roc(pred=pred, target=target, sample_weight=sample_weight,
pos_label=pos_label)

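A usage sketch; assuming standard ROC AUC semantics, 3 of the 4 positive/negative score pairs below are ranked correctly:

import torch
from pytorch_lightning.metrics.functional.classification import auroc

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
auroc(pred, target)  # expected 0.75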
@@ -410,7 +622,6 @@ def average_precision(pred: torch.Tensor, target: torch.Tensor,
precision, recall, _ = precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)

# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve
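A hedged sketch of the step-function integral described in the comment above; for reference, sklearn's average_precision_score reports roughly 0.83 on the same inputs:

import torch
from pytorch_lightning.metrics.functional.classification import average_precision

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
average_precision(pred, target)  # expected ~0.83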
