
Commit

fix tests
cuent authored and justusschock committed May 28, 2020
1 parent aebe9ec commit eac982a
Showing 2 changed files with 30 additions and 22 deletions.
pytorch_lightning/metrics/functional/classification.py (14 changes: 9 additions & 5 deletions)
@@ -194,7 +194,7 @@ def _binary_clf_curve(pred: torch.Tensor, target: torch.Tensor,
     if sample_weight is not None:
         # express fps as a cumsum to ensure fps is increasing even in
         # the presence of floating point errors
-        fps = torch.cumsum((1 - target) * weight)[threshold_idxs]
+        fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
 
     else:
         fps = 1 + threshold_idxs - tps
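
The one-line change above supplies the dim argument that torch.cumsum requires: unlike Python's built-in sum, it has no default dimension, so the old call raised a TypeError at runtime. A minimal sketch of the repaired expression (tensor values are illustrative, not from the patch):

import torch

# one entry per sample, sorted by descending score
target = torch.tensor([1., 0., 1., 0.])  # 1 = positive, 0 = negative
weight = torch.tensor([1., 1., 2., 1.])  # per-sample weights

# cumulative weighted false positives; dim=0 is mandatory for torch.cumsum
fps = torch.cumsum((1 - target) * weight, dim=0)
print(fps)  # tensor([0., 1., 1., 2.])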
@@ -359,13 +359,17 @@ def auroc(pred: torch.Tensor, target: torch.Tensor,
                pos_label=pos_label)
 
 
-@auc_decorator(reorder=True)
 def average_precision(pred: torch.Tensor, target: torch.Tensor,
                       sample_weight: Optional[Sequence] = None,
                       pos_label: int = 1.) -> torch.Tensor:
-    return precision_recall_curve(pred=pred, target=target,
-                                  sample_weight=sample_weight,
-                                  pos_label=pos_label)
+    precision, recall, _ = precision_recall_curve(pred=pred, target=target,
+                                                  sample_weight=sample_weight,
+                                                  pos_label=pos_label)
+
+    # Return the step function integral
+    # The following works because the last entry of precision is
+    # guaranteed to be 1, as returned by precision_recall_curve
+    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
 
 
 def dice_score(pred: torch.Tensor, target: torch.Tensor, bg: bool = False,
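
The rewritten average_precision integrates the precision-recall curve as a step function instead of delegating to auc_decorator. Because recall is non-increasing along the curve and the last precision entry is 1, the negated sum of recall differences weighted by precision equals the area under the steps. A hand-computed sketch (the curve values below correspond to scores [0.1, 0.4, 0.35, 0.8] against targets [0, 0, 1, 1]; they are worked out by hand here, not produced by the patch):

import torch

# precision/recall ordered as precision_recall_curve returns them:
# recall falls from 1 to 0 and the final precision entry is 1
precision = torch.tensor([0.5, 2 / 3, 0.5, 1., 1.])
recall = torch.tensor([1., 1., 0.5, 0.5, 0.])

# each recall drop is weighted by the precision on the left of the step
ap = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
print(ap)  # tensor(0.8333), matching sklearn.metrics.average_precision_score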
tests/metrics/functional/test_classification.py (38 changes: 21 additions & 17 deletions)
@@ -6,6 +6,11 @@
     f1_score, _binary_clf_curve, dice_score, average_precision, auroc, precision_recall_curve, roc
 
 
+@pytest.fixture
+def random():
+    torch.manual_seed(0)
+
+
 def test_onehot():
     test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     expected = torch.tensor([
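
For context, a minimal sketch of how pytest consumes such a seeding fixture (the test below is illustrative, not part of the commit): pytest.mark.usefixtures runs the fixture body before the test, so draws from torch's global RNG inside the test body are reproducible.

import pytest
import torch


@pytest.fixture
def random():
    torch.manual_seed(0)


@pytest.mark.usefixtures("random")
def test_reproducible_draw():
    # the fixture has already seeded the global RNG at this point
    first = torch.randint(0, 10, (3,))
    torch.manual_seed(0)
    assert torch.equal(first, torch.randint(0, 10, (3,)))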
@@ -198,32 +203,32 @@ def test_f1_score(pred, target, exp_score):


 @pytest.mark.parametrize(
-    ['pred', 'target', 'sample_weight', 'pos_label', ], [
+    ['pred', 'target', 'sample_weight', 'pos_label', "exp_shape"], [
         pytest.param(torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100,
                      torch.tensor([int(bool(idx % 2)) for idx in range(100)]),
-                     1., 1.
+                     torch.ones(100), 1., 40
                      ),
         pytest.param(torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100,
                      torch.tensor([int(bool(idx % 2)) for idx in range(100)]),
-                     None, 1.
+                     None, 1., 39
                      )
     ]
 )
-def test_binary_clf_curve(pred, target, sample_weight, pos_label):
+@pytest.mark.usefixtures("random")
+def test_binary_clf_curve(pred, target, sample_weight, pos_label, exp_shape):
     fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
     assert isinstance(tps, torch.Tensor)
     assert isinstance(fps, torch.Tensor)
     assert isinstance(thresh, torch.Tensor)
-    assert tps.shape == (40,)
-    assert fps.shape == (40,)
-    assert thresh.shape == (40,)
+    assert tps.shape == (exp_shape,)
+    assert fps.shape == (exp_shape,)
+    assert thresh.shape == (exp_shape,)
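
For readers unfamiliar with the helper under test: _binary_clf_curve mirrors sklearn's private function of the same name, reporting cumulative true and false positives at each retained threshold as scores are walked in descending order. A simplified sketch of the idea (hand-worked values; the real helper also handles sample_weight and pos_label):

import torch

pred = torch.tensor([0.9, 0.8, 0.8, 0.2])
target = torch.tensor([1, 0, 1, 0])

order = torch.argsort(pred, descending=True)
sorted_target = target[order].float()
tps = torch.cumsum(sorted_target, dim=0)
fps = torch.cumsum(1 - sorted_target, dim=0)
# keeping the cumulative counts at the last index of each distinct score
# gives tps=[1, 2, 2] and fps=[0, 1, 2] at thresholds [0.9, 0.8, 0.2]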


 @pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
     pytest.param(torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1]),
-                 torch.tensor([1 / 3, 0.5, 1., 1.]), torch.tensor([0.5, 0.5, 0.5, 0.]),
-                 torch.tensor([2, 3, 4]))
+                 torch.tensor([0.5, 1 / 3, 0.5, 1., 1.]), torch.tensor([1, 0.5, 0.5, 0.5, 0.]),
+                 torch.tensor([1, 2, 3, 4]))
 ])
 def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
     p, r, t = precision_recall_curve(pred, target)
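
The corrected expectations for test_pr_curve can be verified by hand. For pred = [1, 2, 3, 4] and target = [1, 0, 0, 1], walking the four thresholds in ascending order and appending the final (precision=1, recall=0) point gives exactly the new tensors:

# threshold 1: all predicted positive -> tp=2, fp=2 -> precision=0.5, recall=1
# threshold 2: preds {2, 3, 4}        -> tp=1, fp=2 -> precision=1/3, recall=0.5
# threshold 3: preds {3, 4}           -> tp=1, fp=1 -> precision=0.5, recall=0.5
# threshold 4: pred {4}               -> tp=1, fp=0 -> precision=1.0, recall=0.5
# final point:                           precision=1.0, recall=0.0
# p = [0.5, 1/3, 0.5, 1., 1.], r = [1, 0.5, 0.5, 0.5, 0.], t = [1, 2, 3, 4]

from sklearn.metrics import precision_recall_curve as sk_prc  # cross-check only
p, r, t = sk_prc([1, 0, 0, 1], [1, 2, 3, 4])
print(p, r, t)  # [0.5, 0.333, 0.5, 1., 1.], [1., 0.5, 0.5, 0.5, 0.], [1, 2, 3, 4]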
@@ -236,15 +241,15 @@ def test_pr_curve(pred, target, expected_p, expected_r, expected_t):


 @pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
-    pytest.param(torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([0, 0, 1]),
-                 torch.tensor([0, 1, 1])),
-    pytest.param(torch.tensor([1, 0]), torch.tensor([0, 1]), torch.tensor([0, 1, 1]),
+    pytest.param(torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([0, 1, 1]),
                  torch.tensor([0, 0, 1])),
+    pytest.param(torch.tensor([1, 0]), torch.tensor([0, 1]), torch.tensor([0, 0, 1]),
+                 torch.tensor([0, 1, 1])),
     pytest.param(torch.tensor([1, 1]), torch.tensor([1, 0]), torch.tensor([0, 1]),
                  torch.tensor([0, 1])),
-    pytest.param(torch.tensor([1, 0]), torch.tensor([1, 0]), torch.tensor([0, 0, 1]),
-                 torch.tensor([0, 1, 1])),
-    pytest.param(torch.tensor([0.5, 0.5]), torch.tensor([1, 0]), torch.tensor([0, 1]),
+    pytest.param(torch.tensor([1, 0]), torch.tensor([1, 0]), torch.tensor([0, 1, 1]),
+                 torch.tensor([0, 0, 1])),
+    pytest.param(torch.tensor([0.5, 0.5]), torch.tensor([0, 1]), torch.tensor([0, 1]),
                  torch.tensor([0, 1]))
 ])
 def test_roc_curve(pred, target, expected_tpr, expected_fpr):
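
The swapped tpr/fpr expectations are easiest to sanity-check on the first case, a perfect predictor: pred = [0, 1] with target = [0, 1]. Walking thresholds in descending order and prepending the (0, 0) origin (hand-computed, with an sklearn cross-check that is not part of the patch):

# threshold 1: only the sample scored 1 is positive -> tp=1, fp=0 -> tpr=1, fpr=0
# threshold 0: everything is positive               -> tp=1, fp=1 -> tpr=1, fpr=1
# curve: (fpr, tpr) = (0, 0) -> (0, 1) -> (1, 1)
# i.e. tpr = [0, 1, 1], fpr = [0, 0, 1]: a perfect classifier hugs the top-left
# corner, which is what the fixed parametrization now asserts

from sklearn.metrics import roc_curve  # cross-check only
fpr, tpr, _ = roc_curve([0, 1], [0, 1])
print(fpr, tpr)  # [0. 0. 1.] [0. 1. 1.]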
@@ -268,7 +273,6 @@ def test_auroc(pred, target, expected):
     assert score == expected
 
 
-# TODO: Fix
 def test_average_precision_constant_values():
     # Check the average_precision_score of a constant predictor is
     # the TPR
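
Dropping the TODO is consistent with the new average_precision body: a constant predictor has a single distinct threshold, so the precision-recall curve collapses to two points, (precision = positive rate, recall = 1) and (1, 0), and the step integral reduces to the fraction of positive targets (the "TPR" the comment refers to). A hand-computed sketch of that identity:

import torch

target = torch.tensor([0, 0, 1, 1, 0, 1])  # positive rate = 0.5
# constant scores yield one threshold, so the curve is just
precision = torch.tensor([0.5, 1.])
recall = torch.tensor([1., 0.])
ap = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
print(ap)  # tensor(0.5000), the positive rate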
