From 7007b124c083f6af66f9f86bef37252ada2ed0f3 Mon Sep 17 00:00:00 2001 From: Abe Botros Date: Thu, 3 Sep 2020 12:00:34 -0700 Subject: [PATCH] IoU: remove_bg -> ignore_index Fixes #2736 - Rename IoU metric argument from `remove_bg` -> `ignore_index`. - Accept an optional int class index to ignore, instead of a bool and instead of always assuming the background class has index 0. - If given, ignore the class index when computing the IoU output, regardless of reduction method. --- CHANGELOG.md | 2 + pytorch_lightning/metrics/classification.py | 14 ++-- .../metrics/functional/classification.py | 33 ++++---- .../metrics/functional/test_classification.py | 78 ++++++++++++------- tests/metrics/test_classification.py | 6 +- 5 files changed, 83 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11b18c77d7d566..d9d40bb6f90b76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed IoU score behavior for classes not present in target or pred ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098)) +- Changed IoU `remove_bg` bool to `ignore_index` optional int ([#3098](https://github.com/PyTorchLightning/pytorch-lightning/pull/3098)) + ### Deprecated diff --git a/pytorch_lightning/metrics/classification.py b/pytorch_lightning/metrics/classification.py index bc7c9b45e0eb6a..004df29f20e452 100644 --- a/pytorch_lightning/metrics/classification.py +++ b/pytorch_lightning/metrics/classification.py @@ -797,20 +797,20 @@ class IoU(TensorMetric): def __init__( self, + ignore_index: Optional[int] = None, not_present_score: float = 1.0, num_classes: Optional[int] = None, - remove_bg: bool = False, reduction: str = 'elementwise_mean' ): """ Args: + ignore_index: optional int specifying a target class to ignore. If given, this class index does not + contribute to the returned score, regardless of reduction method. Has no effect if given an int that is + not in the range [0, num_classes-1], where num_classes is either given or derived from pred and target. + By default, no index is ignored, and all classes are used. not_present_score: score to use for a class, if no instance of that class was present in either pred or target num_classes: Optionally specify the number of classes - remove_bg: Flag to state whether a background class has been included - within input parameters. If true, will remove background class. If - false, return IoU over all classes. - Assumes that background is '0' class in input tensor reduction: a method to reduce metric score over labels (default: takes the mean) Available reduction methods: @@ -819,9 +819,9 @@ def __init__( - sum: add elements """ super().__init__(name='iou') + self.ignore_index = ignore_index self.not_present_score = not_present_score self.num_classes = num_classes - self.remove_bg = remove_bg self.reduction = reduction def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, @@ -832,8 +832,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor, return iou( pred=y_pred, target=y_true, + ignore_index=self.ignore_index, not_present_score=self.not_present_score, num_classes=self.num_classes, - remove_bg=self.remove_bg, reduction=self.reduction, ) diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index 6463f4ad9999ca..19b3929e498cd9 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -963,9 +963,9 @@ def dice_score( def iou( pred: torch.Tensor, target: torch.Tensor, + ignore_index: Optional[int] = None, not_present_score: float = 1.0, num_classes: Optional[int] = None, - remove_bg: bool = False, reduction: str = 'elementwise_mean', ) -> torch.Tensor: """ @@ -974,12 +974,12 @@ def iou( Args: pred: Tensor containing predictions target: Tensor containing targets + ignore_index: optional int specifying a target class to ignore. If given, this class index does not contribute + to the returned score, regardless of reduction method. Has no effect if given an int that is not in the + range [0, num_classes-1], where num_classes is either given or derived from pred and target. By default, no + index is ignored, and all classes are used. not_present_score: score to use for a class, if no instance of that class was present in either pred or target num_classes: Optionally specify the number of classes - remove_bg: Flag to state whether a background class has been included - within input parameters. If true, will remove background class. If - false, return IoU over all classes - Assumes that background is '0' class in input tensor reduction: a method to reduce metric score over labels (default: takes the mean) Available reduction methods: @@ -1002,15 +1002,15 @@ def iou( """ num_classes = get_num_classes(pred=pred, target=target, num_classes=num_classes) - # Determine minimum class index we will be evaluating. If using the background, then this is 0; otherwise, if - # removing background, use 1. - min_class_idx = 1 if remove_bg else 0 - tps, fps, tns, fns, sups = stat_scores_multiple_classes(pred, target, num_classes) - scores = torch.zeros(num_classes - min_class_idx, device=pred.device, dtype=torch.float32) + scores = torch.zeros(num_classes, device=pred.device, dtype=torch.float32) + + for class_idx in range(num_classes): + # Skip this class if its index is being ignored. + if class_idx == ignore_index: + continue - for class_idx in range(min_class_idx, num_classes): tp = tps[class_idx] fp = fps[class_idx] fn = fns[class_idx] @@ -1019,11 +1019,18 @@ def iou( # If this class is not present in either the target (no support) or the pred (no true or false positives), then # use the not_present_score for this class. if sup + tp + fp == 0: - scores[class_idx - min_class_idx] = not_present_score + scores[class_idx] = not_present_score continue denom = tp + fp + fn score = tp.to(torch.float) / denom - scores[class_idx - min_class_idx] = score + scores[class_idx] = score + + # Remove the ignored class index from the scores. + if ignore_index is not None and ignore_index >= 0 and ignore_index < num_classes: + scores = torch.cat([ + scores[:ignore_index], + scores[ignore_index + 1:], + ]) return reduce(scores, reduction=reduction) diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py index 57829431a28bfa..5ae7655ea9c185 100644 --- a/tests/metrics/functional/test_classification.py +++ b/tests/metrics/functional/test_classification.py @@ -326,15 +326,15 @@ def test_dice_score(pred, target, expected): assert score == expected -@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [ - pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])), - pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])), - pytest.param(False, 'none', True, torch.Tensor([1, 1])), - pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])), - pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])), - pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])), +@pytest.mark.parametrize(['half_ones', 'reduction', 'ignore_index', 'expected'], [ + pytest.param(False, 'none', None, torch.Tensor([1, 1, 1])), + pytest.param(False, 'elementwise_mean', None, torch.Tensor([1])), + pytest.param(False, 'none', 0, torch.Tensor([1, 1])), + pytest.param(True, 'none', None, torch.Tensor([0.5, 0.5, 0.5])), + pytest.param(True, 'elementwise_mean', None, torch.Tensor([0.5])), + pytest.param(True, 'none', 0, torch.Tensor([0.5, 0.5])), ]) -def test_iou(half_ones, reduction, remove_bg, expected): +def test_iou(half_ones, reduction, ignore_index, expected): pred = (torch.arange(120) % 3).view(-1, 1) target = (torch.arange(120) % 3).view(-1, 1) if half_ones: @@ -342,7 +342,7 @@ def test_iou(half_ones, reduction, remove_bg, expected): iou_val = iou( pred=pred, target=target, - remove_bg=remove_bg, + ignore_index=ignore_index, reduction=reduction, ) assert torch.allclose(iou_val, expected, atol=1e-9) @@ -351,46 +351,70 @@ def test_iou(half_ones, reduction, remove_bg, expected): # TODO: When the jaccard_score of the sklearn version we use accepts `zero_division` (see # https://github.com/scikit-learn/scikit-learn/pull/17866), consider adding a test here against our # `not_present_score`. -@pytest.mark.parametrize(['pred', 'target', 'not_present_score', 'num_classes', 'remove_bg', 'expected'], [ +@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'not_present_score', 'num_classes', 'expected'], [ # Note that -1 is used as the not_present_score in almost all tests here to distinguish it from the range of valid # scores the function can return ([0., 1.] range, inclusive). # 2 classes, class 0 is correct everywhere, class 1 is not present. - pytest.param([0], [0], -1., 2, False, [1., -1.]), - pytest.param([0, 0], [0, 0], -1., 2, False, [1., -1.]), + pytest.param([0], [0], None, -1., 2, [1., -1.]), + pytest.param([0, 0], [0, 0], None, -1., 2, [1., -1.]), # not_present_score not applied if only class 0 is present and it's the only class. - pytest.param([0], [0], -1., 1, False, [1.]), + pytest.param([0], [0], None, -1., 1, [1.]), # 2 classes, class 1 is correct everywhere, class 0 is not present. - pytest.param([1], [1], -1., 2, False, [-1., 1.]), - pytest.param([1, 1], [1, 1], -1., 2, False, [-1., 1.]), - # When background removed, class 0 does not get a score (not even the not_present_score). - pytest.param([1], [1], -1., 2, True, [1.0]), + pytest.param([1], [1], None, -1., 2, [-1., 1.]), + pytest.param([1, 1], [1, 1], None, -1., 2, [-1., 1.]), + # When 0 index ignored, class 0 does not get a score (not even the not_present_score). + pytest.param([1], [1], 0, -1., 2, [1.0]), # 3 classes. Only 0 and 2 are present, and are perfectly predicted. 1 should get not_present_score. - pytest.param([0, 2], [0, 2], -1., 3, False, [1., -1., 1.]), - pytest.param([2, 0], [2, 0], -1., 3, False, [1., -1., 1.]), + pytest.param([0, 2], [0, 2], None, -1., 3, [1., -1., 1.]), + pytest.param([2, 0], [2, 0], None, -1., 3, [1., -1., 1.]), # 3 classes. Only 0 and 1 are present, and are perfectly predicted. 2 should get not_present_score. - pytest.param([0, 1], [0, 1], -1., 3, False, [1., 1., -1.]), - pytest.param([1, 0], [1, 0], -1., 3, False, [1., 1., -1.]), + pytest.param([0, 1], [0, 1], None, -1., 3, [1., 1., -1.]), + pytest.param([1, 0], [1, 0], None, -1., 3, [1., 1., -1.]), # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in pred but not target; should not get not_present_score), class # 2 is not present. - pytest.param([0, 1], [0, 0], -1., 3, False, [0.5, 0., -1.]), + pytest.param([0, 1], [0, 0], None, -1., 3, [0.5, 0., -1.]), # 3 classes, class 0 is 0.5 IoU, class 1 is 0 IoU (in target but not pred; should not get not_present_score), class # 2 is not present. - pytest.param([0, 0], [0, 1], -1., 3, False, [0.5, 0., -1.]), + pytest.param([0, 0], [0, 1], None, -1., 3, [0.5, 0., -1.]), # Sanity checks with not_present_score of 1.0. - pytest.param([0, 2], [0, 2], 1.0, 3, False, [1., 1., 1.]), - pytest.param([0, 2], [0, 2], 1.0, 3, True, [1., 1.]), + pytest.param([0, 2], [0, 2], None, 1.0, 3, [1., 1., 1.]), + pytest.param([0, 2], [0, 2], 0, 1.0, 3, [1., 1.]), ]) -def test_iou_not_present_score(pred, target, not_present_score, num_classes, remove_bg, expected): +def test_iou_not_present_score(pred, target, ignore_index, not_present_score, num_classes, expected): iou_val = iou( pred=torch.tensor(pred), target=torch.tensor(target), + ignore_index=ignore_index, not_present_score=not_present_score, num_classes=num_classes, - remove_bg=remove_bg, reduction='none', ) assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) +@pytest.mark.parametrize(['pred', 'target', 'ignore_index', 'num_classes', 'reduction', 'expected'], [ + # Ignoring an index outside of [0, num_classes-1] should have no effect. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], None, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], -1, 3, 'none', [1, 1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 255, 3, 'none', [1, 1 / 2, 2 / 3]), + # Ignoring a valid index drops only that index from the result. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'none', [1 / 2, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 1, 3, 'none', [1, 2 / 3]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 2, 3, 'none', [1, 1 / 2]), + # When reducing to mean or sum, the ignored index does not contribute to the output. + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'elementwise_mean', [7 / 12]), + pytest.param([0, 1, 1, 2, 2], [0, 1, 2, 2, 2], 0, 3, 'sum', [7 / 6]), +]) +def test_iou_ignore_index(pred, target, ignore_index, num_classes, reduction, expected): + iou_val = iou( + pred=torch.tensor(pred), + target=torch.tensor(target), + ignore_index=ignore_index, + num_classes=num_classes, + reduction=reduction, + ) + assert torch.allclose(iou_val, torch.tensor(expected).to(iou_val)) + + # example data taken from # https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py diff --git a/tests/metrics/test_classification.py b/tests/metrics/test_classification.py index c5a2f8915c5842..7137c6f882cf02 100644 --- a/tests/metrics/test_classification.py +++ b/tests/metrics/test_classification.py @@ -207,9 +207,9 @@ def test_dice_coefficient(include_background): assert isinstance(dice, torch.Tensor) -@pytest.mark.parametrize('remove_bg', [True, False]) -def test_iou(remove_bg): - iou = IoU(remove_bg=remove_bg) +@pytest.mark.parametrize('ignore_index', [0, 1, None]) +def test_iou(ignore_index): + iou = IoU(ignore_index=ignore_index) assert iou.name == 'iou' score = iou(torch.randint(0, 1, (10, 25, 25)),