Skip to content

Commit

Permalink
fix non-adaptive thresholding bug (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
djdameln authored Mar 22, 2022
1 parent 834d45a commit b66e5e3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
10 changes: 6 additions & 4 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ def __init__(self, params: Union[DictConfig, ListConfig]):
self.model: nn.Module

# metrics
auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
f1_score = F1(num_classes=1, compute_on_step=False)
self.image_metrics = MetricCollection([auroc, f1_score], prefix="image_").cpu()
self.pixel_metrics = self.image_metrics.clone(prefix="pixel_").cpu()
image_auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
image_f1 = F1(num_classes=1, compute_on_step=False, threshold=self.hparams.model.threshold.image_default)
pixel_auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
pixel_f1 = F1(num_classes=1, compute_on_step=False, threshold=self.hparams.model.threshold.pixel_default)
self.image_metrics = MetricCollection([image_auroc, image_f1], prefix="image_").cpu()
self.pixel_metrics = MetricCollection([pixel_auroc, pixel_f1], prefix="pixel_").cpu()

def forward(self, batch): # pylint: disable=arguments-differ
"""Forward-pass input tensor to the module.
Expand Down
32 changes: 32 additions & 0 deletions tests/pre_merge/utils/metrics/test_adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import random

import pytest
import torch
from pytorch_lightning import Trainer

from anomalib.config import get_configurable_parameters
from anomalib.data import get_datamodule
from anomalib.models import get_model
from anomalib.utils.callbacks import get_callbacks
from anomalib.utils.metrics import AdaptiveThreshold


Expand All @@ -35,3 +41,29 @@ def test_adaptive_threshold(labels, preds, target_threshold):
threshold_value = adaptive_threshold.compute()

assert threshold_value == target_threshold


def test_non_adaptive_threshold():
"""
Test if the non-adaptive threshold gets used in the F1 score computation when
adaptive thresholding is disabled and no normalization is used.
"""
config = get_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")

config.model.normalization_method = "none"
config.model.threshold.adaptive = False
config.trainer.fast_dev_run = True

image_threshold = random.random()
pixel_threshold = random.random()
config.model.threshold.image_default = image_threshold
config.model.threshold.pixel_default = pixel_threshold

model = get_model(config)
datamodule = get_datamodule(config)
callbacks = get_callbacks(config)

trainer = Trainer(**config.trainer, callbacks=callbacks)
trainer.fit(model=model, datamodule=datamodule)
assert trainer.model.image_metrics.F1.threshold == image_threshold
assert trainer.model.pixel_metrics.F1.threshold == pixel_threshold

0 comments on commit b66e5e3

Please sign in to comment.