diff --git a/anomalib/models/components/base/anomaly_module.py b/anomalib/models/components/base/anomaly_module.py index e3ff975693..170bb4aba6 100644 --- a/anomalib/models/components/base/anomaly_module.py +++ b/anomalib/models/components/base/anomaly_module.py @@ -124,6 +124,8 @@ def test_epoch_end(self, outputs): self._log_metrics() def _compute_adaptive_threshold(self, outputs): + self.image_threshold.reset() + self.pixel_threshold.reset() self._collect_outputs(self.image_threshold, self.pixel_threshold, outputs) self.image_threshold.compute() if "mask" in outputs[0].keys() and "anomaly_maps" in outputs[0].keys(): @@ -134,7 +136,8 @@ def _compute_adaptive_threshold(self, outputs): self.image_metrics.set_threshold(self.image_threshold.value.item()) self.pixel_metrics.set_threshold(self.pixel_threshold.value.item()) - def _collect_outputs(self, image_metric, pixel_metric, outputs): + @staticmethod + def _collect_outputs(image_metric, pixel_metric, outputs): for output in outputs: image_metric.cpu() image_metric.update(output["pred_scores"], output["label"].int()) @@ -142,15 +145,16 @@ def _collect_outputs(self, image_metric, pixel_metric, outputs): pixel_metric.cpu() pixel_metric.update(output["anomaly_maps"], output["mask"].int()) - def _post_process(self, outputs): + @staticmethod + def _post_process(outputs): """Compute labels based on model predictions.""" if "pred_scores" not in outputs and "anomaly_maps" in outputs: outputs["pred_scores"] = ( outputs["anomaly_maps"].reshape(outputs["anomaly_maps"].shape[0], -1).max(dim=1).values ) - def _outputs_to_cpu(self, output): - # for output in outputs: + @staticmethod + def _outputs_to_cpu(output): for key, value in output.items(): if isinstance(value, Tensor): output[key] = value.cpu() diff --git a/anomalib/utils/metrics/adaptive_threshold.py b/anomalib/utils/metrics/adaptive_threshold.py index 7df33a43b3..fd112433f1 100644 --- a/anomalib/utils/metrics/adaptive_threshold.py +++ b/anomalib/utils/metrics/adaptive_threshold.py @@ -4,10 +4,10 @@ # SPDX-License-Identifier: Apache-2.0 import torch -from torchmetrics import Metric, PrecisionRecallCurve +from torchmetrics import PrecisionRecallCurve -class AdaptiveThreshold(Metric): +class AdaptiveThreshold(PrecisionRecallCurve): """Optimal F1 Metric. Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the @@ -15,17 +15,11 @@ class AdaptiveThreshold(Metric): """ def __init__(self, default_value: float = 0.5, **kwargs): - super().__init__(**kwargs) + super().__init__(num_classes=1, **kwargs) - self.precision_recall_curve = PrecisionRecallCurve(num_classes=1) self.add_state("value", default=torch.tensor(default_value), persistent=True) # pylint: disable=not-callable self.value = torch.tensor(default_value) # pylint: disable=not-callable - # pylint: disable=arguments-differ - def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore - """Update the precision-recall curve metric.""" - self.precision_recall_curve.update(preds, target) - def compute(self) -> torch.Tensor: """Compute the threshold that yields the optimal F1 score. @@ -39,7 +33,7 @@ def compute(self) -> torch.Tensor: recall: torch.Tensor thresholds: torch.Tensor - precision, recall, thresholds = self.precision_recall_curve.compute() + precision, recall, thresholds = super().compute() f1_score = (2 * precision * recall) / (precision + recall + 1e-10) if thresholds.dim() == 0: # special case where recall is 1.0 even for the highest threshold. @@ -48,7 +42,3 @@ def compute(self) -> torch.Tensor: else: self.value = thresholds[torch.argmax(f1_score)] return self.value - - def reset(self) -> None: - """Reset the metric.""" - self.precision_recall_curve.reset()