Skip to content

Commit

Permalink
Fix IndexError in adaptive threshold computation (#146)
Browse files Browse the repository at this point in the history
* address scalar thresholds issue

* add tests for adaptive threshold metric
  • Loading branch information
djdameln committed Mar 14, 2022
1 parent cebc3a4 commit 5f3ee27
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
7 changes: 6 additions & 1 deletion anomalib/utils/metrics/adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,10 @@ def compute(self) -> torch.Tensor:

precision, recall, thresholds = self.precision_recall_curve.compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
self.value = thresholds[torch.argmax(f1_score)]
if thresholds.dim() == 0:
# special case where recall is 1.0 even for the highest threshold.
# In this case 'thresholds' will be scalar.
self.value = thresholds
else:
self.value = thresholds[torch.argmax(f1_score)]
return self.value
13 changes: 13 additions & 0 deletions tests/pre_merge/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
37 changes: 37 additions & 0 deletions tests/pre_merge/utils/metrics/test_adaptive_threshold.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Tests for the adaptive threshold metric."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import pytest
import torch

from anomalib.utils.metrics import AdaptiveThreshold


@pytest.mark.parametrize(
["labels", "preds", "target_threshold"],
[
(torch.Tensor([0, 0, 0, 1, 1]), torch.Tensor([2.3, 1.6, 2.6, 7.9, 3.3]), 3.3), # standard case
(torch.Tensor([1, 0, 0, 0]), torch.Tensor([4, 3, 2, 1]), 4), # 100% recall for all thresholds
],
)
def test_adaptive_threshold(labels, preds, target_threshold):
"""Test if the adaptive threshold computation returns the desired value."""

adaptive_threshold = AdaptiveThreshold(default_value=0.5)
adaptive_threshold.update(preds, labels)
threshold_value = adaptive_threshold.compute()

assert threshold_value == target_threshold

0 comments on commit 5f3ee27

Please sign in to comment.