Skip to content

Commit

Permalink
add more metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock authored and Borda committed Apr 17, 2020
1 parent 46f064b commit 6908e03
Showing 1 changed file with 135 additions and 16 deletions.
151 changes: 135 additions & 16 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import Sequence
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable

import torch

Expand Down Expand Up @@ -152,7 +152,6 @@ def f1_score(pred: torch.Tensor, target: torch.Tensor,
num_classes=num_classes, reduction=reduction)


# adapted from https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/_ranking.py
def _binary_clf_curve(pred: torch.Tensor, target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.) -> Tuple[torch.Tensor,
Expand Down Expand Up @@ -247,25 +246,145 @@ def multiclass_roc(pred: torch.Tensor, target: torch.Tensor,
return tuple(class_roc_vals)


# TODO:
def precision_recall_curve():
pass
def precision_recall_curve(pred: torch.Tensor,
target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)

precision = tps / (tps + fps)
recall = tps / tps[-1]

# stop when full recall attained
# and reverse the outputs so recall is decreasing
last_ind = torch.where(tps == tps[-1])[0][0]
sl = slice(0, last_ind.item())

# need to call reversed explicitly, since including that to slice would
# introduce negative strides thet are not yet supported in pytorch
precision = torch.cat([reversed(precision[sl]),
torch.ones(1, dtype=precision.dtype,
device=precision.device)])

recall = torch.cat([reversed(recall[sl]),
torch.zeros(1, dtype=recall.dtype,
device=recall.device)])

thresholds = reversed(thresholds[sl])

return precision, recall, thresholds


def multiclass_precision_recall_curve(pred: torch.Tensor, target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
num_classes: Optional[int] = None,
) -> Tuple[Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]]:
num_classes = get_num_classes(pred, target, num_classes)

class_pr_vals = []

for c in range(num_classes):
pred_c = pred[:, c]

class_pr_vals.append(precision_recall_curve(
pred=pred_c,
target=target,
sample_weight=sample_weight, pos_label=c))

# TODO:
def multilabel_precision_recall_curve():
pass
return tuple(class_pr_vals)


# TODO:
def auc():
pass
def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
direction = 1.

if reorder:
# can't use lexsort here since it is not implemented for torch
order = torch.argsort(x)
x, y = x[order], y[order]
else:
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx < 0).all():
direction = -1.
else:
raise ValueError("Reordering is not turned on, and "
"the x array is not increasing: %s" % x)

return direction * torch.trapz(y, x)


def auc_decorator(reorder: bool = False) -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
def new_func(*args, **kwargs) -> torch.Tensor:
x, y = func_to_decorate(*args, **kwargs)[:2]

return auc(x, y, reorder=reorder)

return new_func

return wrapper


# TODO:
def auroc():
pass
def multiclass_auc_decorator(reorder: bool = False) -> Callable:
def wrapper(func_to_decorate: Callable) -> Callable:
def new_func(*args, **kwargs) -> torch.Tensor:
results = []
for class_result in func_to_decorate(*args, **kwargs):
x, y = class_result[:2]
results.append(auc(x, y, reorder=reorder))

return torch.cat(results)

return new_func

return wrapper


@auc_decorator(reorder=False)
def auroc(pred: torch.Tensor, target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.) -> torch.Tensor:
return roc(pred=pred, target=target, sample_weight=sample_weight,
pos_label=pos_label)


@auc_decorator(reorder=False)
def average_precision(pred: torch.Tensor, target: torch.Tensor,
sample_weight: Optional[Sequence] = None,
pos_label: int = 1.) -> torch.Tensor:
return precision_recall_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)


def dice_score(pred: torch.Tensor, target: torch.Tensor, bg: bool = False,
nan_score: float = 0.0, no_fg_score: float = 0.0,
reduction: str = 'elementwise_mean'):
n_classes = pred.shape[1]
bg = (1 - int(bool(bg)))
scores = torch.zeros(n_classes - bg, device=pred.device, dtype=pred.dtype)
for i in range(bg, n_classes):
if not (target == i).any():
# no foreground class
scores[i] += no_fg_score
continue

tp, fp, tn, fn = stat_scores(pred=pred[:, i], target=target,
class_index=i)

denom = (2 * tp + fp + fn).to(torch.float)

if torch.isclose(denom, torch.zeros_like(denom)).any():
# nan result
score_cls = nan_score
else:
score_cls = (2 * tp).to(torch.float) / denom

def dice_coefficient():
pass
scores[i] += score_cls
return reduce(scores, reduction=reduction)

0 comments on commit 6908e03

Please sign in to comment.