diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py
index 8b5b1b6a18f84..8fed3ddf5b9f0 100644
--- a/pytorch_lightning/metrics/functional/classification.py
+++ b/pytorch_lightning/metrics/functional/classification.py
@@ -194,7 +194,7 @@ def _binary_clf_curve(pred: torch.Tensor, target: torch.Tensor,
     if sample_weight is not None:
         # express fps as a cumsum to ensure fps is increasing even in
         # the presence of floating point errors
-        fps = torch.cumsum((1 - target) * weight)[threshold_idxs]
+        fps = torch.cumsum((1 - target) * weight, dim=0)[threshold_idxs]
     else:
         fps = 1 + threshold_idxs - tps
 
@@ -359,13 +359,17 @@ def auroc(pred: torch.Tensor, target: torch.Tensor,
                pos_label=pos_label)
 
 
-@auc_decorator(reorder=True)
 def average_precision(pred: torch.Tensor, target: torch.Tensor,
                       sample_weight: Optional[Sequence] = None,
                       pos_label: int = 1.) -> torch.Tensor:
-    return precision_recall_curve(pred=pred, target=target,
-                                  sample_weight=sample_weight,
-                                  pos_label=pos_label)
+    precision, recall, _ = precision_recall_curve(pred=pred, target=target,
+                                                  sample_weight=sample_weight,
+                                                  pos_label=pos_label)
+
+    # Return the step function integral
+    # The following works because the last entry of precision is
+    # guaranteed to be 1, as returned by precision_recall_curve
+    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
 
 
 def dice_score(pred: torch.Tensor, target: torch.Tensor, bg: bool = False,
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index 48345cd11cca6..b104159e8bfe2 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -6,6 +6,11 @@
     f1_score, _binary_clf_curve, dice_score, average_precision, auroc, precision_recall_curve, roc
 
 
+@pytest.fixture
+def random():
+    torch.manual_seed(0)
+
+
 def test_onehot():
     test_tensor = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
     expected = torch.tensor([
@@ -198,32 +203,32 @@ def test_f1_score(pred, target, exp_score):
 
 @pytest.mark.parametrize(
-    ['pred', 'target', 'sample_weight', 'pos_label', ], [
+    ['pred', 'target', 'sample_weight', 'pos_label', "exp_shape"], [
     pytest.param(torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100,
                  torch.tensor([int(bool(idx % 2)) for idx in range(100)]),
-                 1., 1.
+                 torch.ones(100), 1., 40
                  ),
     pytest.param(torch.randint(low=51, high=99, size=(100,), dtype=torch.float) / 100,
                  torch.tensor([int(bool(idx % 2)) for idx in range(100)]),
-                 None, 1.
+                 None, 1., 39
                  )
     ]
 )
-def test_binary_clf_curve(pred, target, sample_weight, pos_label):
+@pytest.mark.usefixtures("random")
+def test_binary_clf_curve(pred, target, sample_weight, pos_label, exp_shape):
     fps, tps, thresh = _binary_clf_curve(pred, target, sample_weight, pos_label)
 
     assert isinstance(tps, torch.Tensor)
     assert isinstance(fps, torch.Tensor)
     assert isinstance(thresh, torch.Tensor)
-    assert tps.shape == (40,)
-    assert fps.shape == (40,)
-    assert thresh.shape == (40,)
+    assert tps.shape == (exp_shape,)
+    assert fps.shape == (exp_shape,)
+    assert thresh.shape == (exp_shape,)
 
 
 @pytest.mark.parametrize(['pred', 'target', 'expected_p', 'expected_r', 'expected_t'], [
     pytest.param(torch.tensor([1, 2, 3, 4]), torch.tensor([1, 0, 0, 1]),
-                 torch.tensor([1 / 3, 0.5, 1., 1.]), torch.tensor([0.5, 0.5, 0.5, 0.]),
-                 torch.tensor([2, 3, 4]))
-
+                 torch.tensor([0.5, 1 / 3, 0.5, 1., 1.]), torch.tensor([1, 0.5, 0.5, 0.5, 0.]),
+                 torch.tensor([1, 2, 3, 4]))
 ])
 def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
     p, r, t = precision_recall_curve(pred, target)
@@ -236,15 +241,15 @@ def test_pr_curve(pred, target, expected_p, expected_r, expected_t):
 
 
 @pytest.mark.parametrize(['pred', 'target', 'expected_tpr', 'expected_fpr'], [
-    pytest.param(torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([0, 0, 1]),
-                 torch.tensor([0, 1, 1])),
-    pytest.param(torch.tensor([1, 0]), torch.tensor([0, 1]), torch.tensor([0, 1, 1]),
+    pytest.param(torch.tensor([0, 1]), torch.tensor([0, 1]), torch.tensor([0, 1, 1]),
                  torch.tensor([0, 0, 1])),
+    pytest.param(torch.tensor([1, 0]), torch.tensor([0, 1]), torch.tensor([0, 0, 1]),
+                 torch.tensor([0, 1, 1])),
     pytest.param(torch.tensor([1, 1]), torch.tensor([1, 0]), torch.tensor([0, 1]),
                  torch.tensor([0, 1])),
-    pytest.param(torch.tensor([1, 0]), torch.tensor([1, 0]), torch.tensor([0, 0, 1]),
-                 torch.tensor([0, 1, 1])),
-    pytest.param(torch.tensor([0.5, 0.5]), torch.tensor([1, 0]), torch.tensor([0, 1]),
+    pytest.param(torch.tensor([1, 0]), torch.tensor([1, 0]), torch.tensor([0, 1, 1]),
+                 torch.tensor([0, 0, 1])),
+    pytest.param(torch.tensor([0.5, 0.5]), torch.tensor([0, 1]), torch.tensor([0, 1]),
                  torch.tensor([0, 1]))
 ])
 def test_roc_curve(pred, target, expected_tpr, expected_fpr):
@@ -268,7 +273,6 @@ def test_auroc(pred, target, expected):
     assert score == expected
 
 
-# TODO: Fix
 def test_average_precision_constant_values():
     # Check the average_precision_score of a constant predictor is
     # the TPR
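For reference, a minimal standalone sketch of the step-function integral that the rewritten `average_precision` computes, evaluated on the precision/recall values expected by `test_pr_curve` above; the helper name `step_function_integral` is made up for illustration and is not part of the library:

```python
import torch


def step_function_integral(precision: torch.Tensor, recall: torch.Tensor) -> torch.Tensor:
    """Sum the width of each recall step times the precision at its higher-recall end."""
    # recall is returned in decreasing order, so the consecutive differences are
    # negative and the sum is negated to yield a positive average precision.
    return -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])


# Expected precision/recall curve for pred=[1, 2, 3, 4], target=[1, 0, 0, 1]
# (the values asserted in test_pr_curve above).
precision = torch.tensor([0.5, 1 / 3, 0.5, 1., 1.])
recall = torch.tensor([1., 0.5, 0.5, 0.5, 0.])

print(step_function_integral(precision, recall))  # tensor(0.7500)
```

Only the two non-zero recall steps contribute here: (1 − 0.5) · 0.5 + (0.5 − 0) · 1 = 0.75.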
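And a small sanity check mirroring `test_average_precision_constant_values`: for a constant predictor the precision-recall curve collapses to a single threshold with precision equal to the positive-class prevalence and recall equal to 1, followed by the final `(precision=1, recall=0)` point, so the step integral reduces to the fraction of positive targets. The curve values below are assumed for illustration, not taken from the test:

```python
import torch

# Constant predictor over targets that are 25% positive: one real threshold with
# precision == prevalence and recall == 1, plus the final (precision=1, recall=0) point.
precision = torch.tensor([0.25, 1.])
recall = torch.tensor([1., 0.])

ap = -torch.sum((recall[1:] - recall[:-1]) * precision[:-1])
print(ap)  # tensor(0.2500) == fraction of positive targets
```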