Skip to content

Commit

Permalink
🧹 Reset adaptive threshold between epochs (#527)
Browse files Browse the repository at this point in the history
refactor adaptive threshold and add call to reset method
  • Loading branch information
djdameln committed Aug 31, 2022
1 parent 702acc1 commit a03e592
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
12 changes: 8 additions & 4 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def test_epoch_end(self, outputs):
self._log_metrics()

def _compute_adaptive_threshold(self, outputs):
self.image_threshold.reset()
self.pixel_threshold.reset()
self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs)
self.image_threshold.compute()
if "mask" in outputs[0].keys() and "anomaly_maps" in outputs[0].keys():
Expand All @@ -134,23 +136,25 @@ def _compute_adaptive_threshold(self, outputs):
self.image_metrics.set_threshold(self.image_threshold.value.item())
self.pixel_metrics.set_threshold(self.pixel_threshold.value.item())

def _collect_outputs(self, image_metric, pixel_metric, outputs):
@staticmethod
def _collect_outputs(image_metric, pixel_metric, outputs):
for output in outputs:
image_metric.cpu()
image_metric.update(output["pred_scores"], output["label"].int())
if "mask" in output.keys() and "anomaly_maps" in output.keys():
pixel_metric.cpu()
pixel_metric.update(output["anomaly_maps"], output["mask"].int())

def _post_process(self, outputs):
@staticmethod
def _post_process(outputs):
"""Compute labels based on model predictions."""
if "pred_scores" not in outputs and "anomaly_maps" in outputs:
outputs["pred_scores"] = (
outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values
)

def _outputs_to_cpu(self, output):
# for output in outputs:
@staticmethod
def _outputs_to_cpu(output):
for key, value in output.items():
if isinstance(value, Tensor):
output[key] = value.cpu()
Expand Down
18 changes: 4 additions & 14 deletions anomalib/utils/metrics/adaptive_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,22 @@
# SPDX-License-Identifier: Apache-2.0

import torch
from torchmetrics import Metric, PrecisionRecallCurve
from torchmetrics import PrecisionRecallCurve


class AdaptiveThreshold(Metric):
class AdaptiveThreshold(PrecisionRecallCurve):
"""Optimal F1 Metric.
Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the
predicted anomaly scores.
"""

def __init__(self, default_value: float = 0.5, **kwargs):
super().__init__(**kwargs)
super().__init__(num_classes=1, **kwargs)

self.precision_recall_curve = PrecisionRecallCurve(num_classes=1)
self.add_state("value", default=torch.tensor(default_value), persistent=True) # pylint: disable=not-callable
self.value = torch.tensor(default_value) # pylint: disable=not-callable

# pylint: disable=arguments-differ
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore
"""Update the precision-recall curve metric."""
self.precision_recall_curve.update(preds, target)

def compute(self) -> torch.Tensor:
"""Compute the threshold that yields the optimal F1 score.
Expand All @@ -39,7 +33,7 @@ def compute(self) -> torch.Tensor:
recall: torch.Tensor
thresholds: torch.Tensor

precision, recall, thresholds = self.precision_recall_curve.compute()
precision, recall, thresholds = super().compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
if thresholds.dim() == 0:
# special case where recall is 1.0 even for the highest threshold.
Expand All @@ -48,7 +42,3 @@ def compute(self) -> torch.Tensor:
else:
self.value = thresholds[torch.argmax(f1_score)]
return self.value

def reset(self) -> None:
"""Reset the metric."""
self.precision_recall_curve.reset()

0 comments on commit a03e592

Please sign in to comment.