diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index 860a4a99dabd8a..57829431a28bfa 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -4,6 +4,7 @@
 import torch
 from sklearn.metrics import (
     accuracy_score as sk_accuracy,
+    jaccard_score as sk_jaccard_score,
     precision_score as sk_precision,
     recall_score as sk_recall,
     f1_score as sk_f1_score,
@@ -37,6 +38,7 @@
 
 @pytest.mark.parametrize(['sklearn_metric', 'torch_metric'], [
     pytest.param(sk_accuracy, accuracy, id='accuracy'),
+    pytest.param(partial(sk_jaccard_score, average='macro'), iou, id='iou'),
     pytest.param(partial(sk_precision, average='macro'), precision, id='precision'),
     pytest.param(partial(sk_recall, average='macro'), recall, id='recall'),
     pytest.param(partial(sk_f1_score, average='macro'), f1_score, id='f1_score'),
@@ -346,6 +348,9 @@ def test_iou(half_ones, reduction, remove_bg, expected):
     assert torch.allclose(iou_val, expected, atol=1e-9)
 
 
+# TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see
+# https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our
+# `not_present_score`.
 @pytest.mark.parametrize(['pred', 'target', 'not_present_score', 'num_classes', 'remove_bg', 'expected'], [
    # Note that -1 is used as the not_present_score in almost all tests here to distinguish it from the range of valid
    # scores the function can return ([0., 1.] range, inclusive).
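For context, the TODO added in the last hunk defers a cross-check of `not_present_score` until the pinned scikit-learn exposes `zero_division` in `jaccard_score` (it did so starting with 0.24). A minimal sketch of what that deferred test might look like is below; it is not part of this diff, the import path and the `num_classes`/`not_present_score` keyword arguments for `iou` are assumptions inferred from the parametrization in this file, and the helper name is hypothetical.

```python
# Hypothetical sketch only -- not part of this diff. Assumes scikit-learn >= 0.24
# (jaccard_score accepts `zero_division`) and that `iou` takes `num_classes` and
# `not_present_score` as suggested by the parametrize block above; the real
# signature and import path may differ.
import torch
from sklearn.metrics import jaccard_score as sk_jaccard_score

from pytorch_lightning.metrics.functional.classification import iou


def check_not_present_score_against_sklearn(not_present_score=0.0):
    # Class 2 is declared (via `labels` / `num_classes`) but occurs in neither
    # `pred` nor `target`, so both libraries must fall back to the given score.
    pred = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    sk_score = sk_jaccard_score(
        target.numpy(),
        pred.numpy(),
        labels=[0, 1, 2],
        average='macro',
        zero_division=not_present_score,
    )
    pt_score = iou(pred, target, num_classes=3, not_present_score=not_present_score)

    assert torch.allclose(
        torch.tensor(sk_score, dtype=torch.float), pt_score.float(), atol=1e-6
    )
```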