Skip to content

Commit

Permalink
fix roc score metric
Browse files Browse the repository at this point in the history
  • Loading branch information
cuent authored and justusschock committed May 28, 2020
1 parent 9513e3d commit aebe9ec
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
10 changes: 5 additions & 5 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ def roc(pred: torch.Tensor, target: torch.Tensor,
pos_label: int = 1.) -> Tuple[torch.Tensor,
torch.Tensor,
torch.Tensor]:
tps, fps, thresholds = _binary_clf_curve(pred=pred, target=target,
fps, tps, thresholds = _binary_clf_curve(pred=pred, target=target,
sample_weight=sample_weight,
pos_label=pos_label)

# Add an extra threshold position
# to make sure that the curve starts at (0, 0)
tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
thresholds = torch.cat([thresholds[0][None], thresholds])
thresholds = torch.cat([thresholds[0][None] + 1, thresholds])

if fps[-1] <= 0:
raise ValueError("No negative samples in targets, "
Expand Down Expand Up @@ -267,10 +267,10 @@ def precision_recall_curve(pred: torch.Tensor,
# stop when full recall attained
# and reverse the outputs so recall is decreasing
last_ind = torch.where(tps == tps[-1])[0][0]
sl = slice(0, last_ind.item())
sl = slice(0, last_ind.item() + 1)

# need to call reversed explicitly, since including that to slice would
# introduce negative strides thet are not yet supported in pytorch
# introduce negative strides that are not yet supported in pytorch
precision = torch.cat([reversed(precision[sl]),
torch.ones(1, dtype=precision.dtype,
device=precision.device)])
Expand Down Expand Up @@ -315,7 +315,7 @@ def auc(x: torch.Tensor, y: torch.Tensor, reorder: bool = True):
else:
dx = x[1:] - x[:-1]
if (dx < 0).any():
if (dx , 0).all():
if (dx, 0).all():
direction = -1.
else:
raise ValueError("Reordering is not turned on, and "
Expand Down
3 changes: 1 addition & 2 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ def test_roc_curve(pred, target, expected_tpr, expected_fpr):
pytest.param(torch.tensor([1, 0]), torch.tensor([1, 0]), 1.),
pytest.param(torch.tensor([0.5, 0.5]), torch.tensor([1, 0]), 0.5)
])
# TODO: FIx
def test_auroc(pred, target, expected):
score = auroc(pred, target).item()
assert score == expected
Expand All @@ -290,4 +289,4 @@ def test_dice_score():
dice_score()

# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py

0 comments on commit aebe9ec

Please sign in to comment.