Skip to content

Commit

Permalink
Test IoU against sklearn jaccard_score
Browse files Browse the repository at this point in the history
Also add TODO to test our IoU's not_present_score against sklearn's
jaccard_score's zero_division when it beecomes available.
  • Loading branch information
abrahambotros committed Sep 8, 2020
1 parent 172a7c1 commit c0737f6
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from sklearn.metrics import (
accuracy_score as sk_accuracy,
jaccard_score as sk_jaccard_score,
precision_score as sk_precision,
recall_score as sk_recall,
f1_score as sk_f1_score,
Expand Down Expand Up @@ -37,6 +38,7 @@

@pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
pytest.param(sk_accuracy, accuracy, id='accuracy'),
pytest.param(partial(sk_jaccard_score, average='macro'), iou, id='iou'),
pytest.param(partial(sk_precision, average='macro'), precision, id='precision'),
pytest.param(partial(sk_recall, average='macro'), recall, id='recall'),
pytest.param(partial(sk_f1_score, average='macro'), f1_score, id='f1_score'),
Expand Down Expand Up @@ -346,6 +348,9 @@ def test_iou(half_ones, reduction, remove_bg, expected):
assert torch.allclose(iou_val, expected, atol=1e-9)


# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
# `not_present_score`.
@pytest.mark.parametrize(['pred', 'target', 'not_present_score', 'num_classes', 'remove_bg', 'expected'], [
# Note that -1 is used as the not_present_score in almost all tests here to distinguish it from the range of valid
# scores the function can return ([0., 1.] range, inclusive).
Expand Down

0 comments on commit c0737f6

Please sign in to comment.