Add documentation to native metrics #2144

Merged · 5 commits · Jun 12, 2020
237 changes: 224 additions & 13 deletions pytorch_lightning/metrics/functional/classification.py
@@ -9,7 +9,8 @@

def to_onehot(tensor: torch.Tensor,
n_classes: Optional[int] = None) -> torch.Tensor:
""" Converts a dense label tensor to one-hot format
"""
Converts a dense label tensor to one-hot format

Args:
tensor: dense label tensor, with shape [N, d1, d2, ...]
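
Not part of the diff: a quick usage sketch for to_onehot, assuming only the signature and module path shown above.

import torch
from pytorch_lightning.metrics.functional.classification import to_onehot

labels = torch.tensor([0, 2, 1])         # dense labels, shape [N]
onehot = to_onehot(labels, n_classes=3)  # one-hot tensor, shape [N, 3]
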
@@ -29,14 +30,31 @@ def to_onehot(tensor: torch.Tensor,


def to_categorical(tensor: torch.Tensor, argmax_dim: int = 1) -> torch.Tensor:
""" Converts a tensor of probabilities to a dense label tensor """
"""
Converts a tensor of probabilities to a dense label tensor

Args:
tensor: probabilities to get the categorical label [N, d1, d2, ...]
argmax_dim: dimension to apply the argmax over (default: 1)

Return:
A tensor with categorical labels [N, d2, ...]
"""
return torch.argmax(tensor, dim=argmax_dim)
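
A quick sketch (not part of the diff); since this is a plain argmax, the output here is certain:

import torch
from pytorch_lightning.metrics.functional.classification import to_categorical

probs = torch.tensor([[0.1, 0.9],
                      [0.8, 0.2]])  # probabilities, shape [N, C]
to_categorical(probs)               # tensor([1, 0]): argmax along dim 1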


def get_num_classes(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int]) -> int:
""" Returns the number of classes for a given prediction and
target tensor
"""
Returns the number of classes for a given prediction and target tensor.

Args:
pred: predicted values
target: true labels
num_classes: number of classes if known (default: None)

Return:
An integer that represents the number of classes.
"""
if num_classes is None:
if pred.ndim > target.ndim:
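
A usage sketch for get_num_classes (not part of the diff). The body is truncated above; judging from the visible branch, the count is inferred from the data when num_classes is None:

import torch
from pytorch_lightning.metrics.functional.classification import get_num_classes

pred = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 1, 2])
get_num_classes(pred, target, None)  # presumably 3, inferred from the labels
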
@@ -50,8 +68,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
class_index: int, argmax_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
""" Calculates the number of true postive, false postive, true negative
and false negative for a specfic class
"""
Calculates the number of true positive, false positive, true negative
and false negative for a specific class

Args:
pred: prediction tensor
@@ -63,6 +82,9 @@ def stat_scores(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over

Return:
Tensors in the following order: True Positive, False Positive, True Negative, False Negative

"""
if pred.ndim == target.ndim + 1:
pred = to_categorical(pred, argmax_dim=argmax_dim)
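
A hand-checkable sketch (not part of the diff) for stat_scores on class 1:

import torch
from pytorch_lightning.metrics.functional.classification import stat_scores

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
tp, fp, tn, fn = stat_scores(pred, target, class_index=1)
# class 1: tp=1 (index 2), fp=1 (index 1), tn=2 (indices 0 and 3), fn=0
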
@@ -80,20 +102,21 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
argmax_dim: int = 1
) -> Tuple[torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor]:
""" Calls the stat_scores function iteratively for all classes, thus
calculating the number of true postive, false postive, true negative
and false negative for each class
"""
Calls the stat_scores function iteratively for all classes, thus
calculating the number of true positive, false positive, true negative
and false negative for each class

Args:
pred: prediction tensor
target: target tensor
num_classes: number of classes if known (default: None)
argmax_dim: if pred is a tensor of probabilities, this indicates the
axis the argmax transformation will be applied over

Return:
Tensors in the following order: tp, fp, tn, fn

"""
num_classes = get_num_classes(pred=pred, target=target,
num_classes=num_classes)
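
A quick sketch (not part of the diff): same data as above, but per-class statistics in one call:

import torch
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
tps, fps, tns, fns = stat_scores_multiple_classes(pred, target)
# each result holds one entry per class; tps per class should be [1, 1, 1] here
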
@@ -116,6 +139,23 @@ def stat_scores_multiple_classes(pred: torch.Tensor, target: torch.Tensor,
def accuracy(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean') -> torch.Tensor:
"""
Computes the accuracy classification score

Args:
pred: predicted labels
target: ground truth labels
num_classes: number of classes
reduction: a method for reducing accuracies over labels (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: pass array
- sum: add elements

Return:
A Tensor with the classification score.
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred, target=target,
num_classes=num_classes)
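
A usage sketch (not part of the diff). The exact value depends on how the reduction described above combines per-class scores, so the comment stays loose:

import torch
from pytorch_lightning.metrics.functional.classification import accuracy

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
acc = accuracy(pred, target)  # scalar tensor; 3 of the 4 predictions match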

@@ -129,6 +169,18 @@ def accuracy(pred: torch.Tensor, target: torch.Tensor,

def confusion_matrix(pred: torch.Tensor, target: torch.Tensor,
normalize: bool = False) -> torch.Tensor:
"""
Computes the confusion matrix C where each entry C_{i,j} is the number of observations
in group i that were predicted to be in group j.

Args:
pred: estimated targets
target: ground truth labels
normalize: whether to normalize the confusion matrix

Return:
Tensor, confusion matrix C [num_classes, num_classes]
"""
num_classes = get_num_classes(pred, target, None)

d = target.size(-1)
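
A hand-checkable sketch (not part of the diff), using the row convention from the docstring:

import torch
from pytorch_lightning.metrics.functional.classification import confusion_matrix

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
cm = confusion_matrix(pred, target)  # shape [3, 3]
# row i = true class, column j = predicted class; cm[0, 1] should be 1,
# because the target-0 sample at index 1 was predicted as class 1
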
@@ -149,6 +201,23 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean'
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Computes precision and recall

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision-recall values (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: pass array
- sum: add elements

Return:
Tensor with precision and recall
"""
tps, fps, tns, fns = stat_scores_multiple_classes(pred=pred,
target=target,
num_classes=num_classes)
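
A quick sketch (not part of the diff). Dense labels work here too, since stat_scores argmaxes only when pred has an extra dimension:

import torch
from pytorch_lightning.metrics.functional.classification import precision_recall

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
prec, rec = precision_recall(pred, target)  # two scalar tensors under the default reduction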
@@ -168,20 +237,77 @@ def precision_recall(pred: torch.Tensor, target: torch.Tensor,
def precision(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes precision score.

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing precision values (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: pass array
- sum: add elements

Return:
Tensor with precision.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[0]
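
A minimal sketch (not part of the diff):

import torch
from pytorch_lightning.metrics.functional.classification import precision

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
precision(pred, target)  # scalar tensor: per-class precision combined by the default mean reduction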


def recall(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes recall score.

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing recall values (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: pass array
- sum: add elements

Return:
Tensor with recall.
"""
return precision_recall(pred=pred, target=target,
num_classes=num_classes, reduction=reduction)[1]
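
A minimal sketch (not part of the diff), mirroring the precision example above:

import torch
from pytorch_lightning.metrics.functional.classification import recall

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
recall(pred, target)  # scalar tensor: per-class recall combined by the default mean reduction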


def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
num_classes: Optional[int] = None,
reduction: str = 'elementwise_mean') -> torch.Tensor:
"""
Computes the F-beta score which is a weighted harmonic mean of precision and recall.
It ranges between 0 and 1, where 1 is perfect and 0 is the worst value.

Args:
pred: estimated probabilities
target: ground-truth labels
beta: weights recall when combining the score.
beta < 1: more weight to precision.
beta > 1: more weight to recall.
beta = 0: only precision.
beta -> inf: only recall.
num_classes: number of classes
reduction: method for reducing F-score (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: pass array
- sum: add elements.

Return:
Tensor with the value of the F-score, a value between 0 and 1.
"""
prec, rec = precision_recall(pred=pred, target=target,
num_classes=num_classes,
reduction='none')
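
A usage sketch (not part of the diff):

import torch
from pytorch_lightning.metrics.functional.classification import fbeta_score

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
fbeta_score(pred, target, beta=0.5)  # beta < 1 weights precision more heavily than recall
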
@@ -196,6 +322,23 @@ def fbeta_score(pred: torch.Tensor, target: torch.Tensor, beta: float,
def f1_score(pred: torch.Tensor, target: torch.Tensor,
num_classes: Optional[int] = None,
reduction='elementwise_mean') -> torch.Tensor:
"""
Computes the F1-score (a.k.a. F-measure).

Args:
pred: estimated probabilities
target: ground-truth labels
num_classes: number of classes
reduction: method for reducing F1-score (default: takes the mean)
Available reduction methods:

- elementwise_mean: takes the mean
- none: pass array
- sum: add elements.

Return:
Tensor containing F1-score
"""
return fbeta_score(pred=pred, target=target, beta=1.,
num_classes=num_classes, reduction=reduction)
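
A minimal sketch (not part of the diff); per the code above, this is simply fbeta_score with beta=1:

import torch
from pytorch_lightning.metrics.functional.classification import f1_score

pred = torch.tensor([0, 1, 1, 2])
target = torch.tensor([0, 0, 1, 2])
f1_score(pred, target)  # scalar tensor: harmonic mean of precision and recall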

@@ -251,6 +394,18 @@ def roc(pred: torch.Tensor, target: torch.Tensor,
pos_label: int = 1.) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
"""
Computes the Receiver Operating Characteristic (ROC). It assumes the classifier is binary.

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1)

Return:
[Tensor, Tensor, Tensor]: false-positive rate (fpr), true-positive rate (tpr), thresholds
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
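
A usage sketch (not part of the diff), with scores for the positive class:

import torch
from pytorch_lightning.metrics.functional.classification import roc

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])  # predicted probability of class 1
target = torch.tensor([0, 0, 1, 1])
fpr, tpr, thresholds = roc(pred, target)    # three 1-D tensors of equal length
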
@@ -282,6 +437,19 @@ def multiclass_roc(pred: torch.Tensor, target: torch.Tensor,
) -> Tuple[Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]]:
"""
Computes the Receiver Operating Characteristic (ROC) for multiclass predictors.

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes (default: None, computed automatically from the data)

Return:
the ROC for each class, as a tuple of num_classes triples, each holding
false-positive rate (fpr), true-positive rate (tpr), thresholds
"""
num_classes = get_num_classes(pred, target, num_classes)

class_roc_vals = []
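
A usage sketch (not part of the diff), with one probability column per class:

import torch
from pytorch_lightning.metrics.functional.classification import multiclass_roc

pred = torch.tensor([[0.8, 0.1, 0.1],
                     [0.2, 0.7, 0.1],
                     [0.1, 0.2, 0.7]])  # probabilities, shape [N, C]
target = torch.tensor([0, 1, 2])
per_class = multiclass_roc(pred, target)
fpr0, tpr0, thresholds0 = per_class[0]   # ROC of class 0
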
@@ -301,6 +469,18 @@ def precision_recall_curve(pred: torch.Tensor,
pos_label: int = 1.) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
"""
Computes precision-recall pairs for different thresholds.

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)

Return:
[Tensor, Tensor, Tensor]: precision, recall, thresholds
"""
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)
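
A usage sketch (not part of the diff), using the same binary scores as the roc example:

import torch
from pytorch_lightning.metrics.functional.classification import precision_recall_curve

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
precision_vals, recall_vals, thresholds = precision_recall_curve(pred, target)
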
@@ -334,6 +514,18 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
) -> Tuple[Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]]:
"""
Computes precision-recall pairs for different thresholds given multiclass scores.

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
num_classes: number of classes

Return:
a tuple of num_classes triples, each holding precision, recall, thresholds for one class
"""
num_classes = get_num_classes(pred, target, num_classes)

class_pr_vals = []
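
A usage sketch (not part of the diff), mirroring the multiclass_roc example above:

import torch
from pytorch_lightning.metrics.functional.classification import multiclass_precision_recall_curve

pred = torch.tensor([[0.8, 0.1, 0.1],
                     [0.2, 0.7, 0.1],
                     [0.1, 0.2, 0.7]])
target = torch.tensor([0, 1, 2])
per_class = multiclass_precision_recall_curve(pred, target)
prec0, rec0, thresholds0 = per_class[0]  # curve of class 0
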
@@ -350,6 +542,17 @@ def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,


def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
"""
Computes Area Under the Curve (AUC) using the trapezoidal rule

Args:
x: x-coordinates
y: y-coordinates
reorder: reorder the coordinates so they are increasing.

Return:
AUC score (float)
"""
direction = 1.

if reorder:
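
A hand-checkable sketch (not part of the diff): the area under y = x on [0, 1] is 0.5 by the trapezoidal rule:

import torch
from pytorch_lightning.metrics.functional.classification import auc

x = torch.tensor([0.0, 0.5, 1.0])
y = torch.tensor([0.0, 0.5, 1.0])
auc(x, y)  # should be 0.5
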
@@ -400,6 +603,15 @@ def new_func(*args, **kwargs) -> torch.Tensor:
def auroc(pred: torch.Tensor, target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.) -> torch.Tensor:
"""
Computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC) from prediction scores

Args:
pred: estimated probabilities
target: ground-truth labels
sample_weight: sample weights
pos_label: the label for the positive class (default: 1.)
"""
return roc(pred=pred, target=target, sample_weight=sample_weight,
pos_label=pos_label)
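
A usage sketch (not part of the diff); for these scores only one positive/negative pair is ranked incorrectly, so the result is expected to be about 0.75:

import torch
from pytorch_lightning.metrics.functional.classification import auroc

pred = torch.tensor([0.1, 0.4, 0.35, 0.8])
target = torch.tensor([0, 0, 1, 1])
auroc(pred, target)  # scalar tensor, expected ~0.75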

@@ -410,7 +622,6 @@ def average_precision(pred: torch.Tensor, target: torch.Tensor,
precision, recall, _ = precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)

# Return the step function integral
# The following works because the last entry of precision is
# guaranteed to be 1, as returned by precision_recall_curve