Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add min-max normalization #53

Merged
merged 15 commits into from
Jan 5, 2022
20 changes: 13 additions & 7 deletions anomalib/core/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytorch_lightning.callbacks import Callback, ModelCheckpoint

from .compress import CompressModelCallback
from .min_max_normalization import MinMaxNormalizationCallback
from .model_loader import LoadModelCallback
from .normalization import AnomalyScoreNormalizationCallback
from .save_to_csv import SaveToCSVCallback
Expand Down Expand Up @@ -51,17 +52,22 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
load_model = LoadModelCallback(os.path.join(config.project.path, config.model.weight_file))
callbacks.append(load_model)

if "normalize_scores" in config.model.keys() and config.model.normalize_scores:
if config.model.name in ["padim", "stfpm"]:
if not config.optimization.nncf.apply:
callbacks.append(AnomalyScoreNormalizationCallback())
if "normalization_method" in config.model.keys() and not config.model.normalization_method == "none":
if config.model.normalization_method == "cdf":
if config.model.name in ["padim", "stfpm"]:
if not config.optimization.nncf.apply:
callbacks.append(AnomalyScoreNormalizationCallback())
djdameln marked this conversation as resolved.
Show resolved Hide resolved
else:
raise NotImplementedError("CDF Score Normalization is currently not compatible with NNCF.")
else:
raise NotImplementedError("Score Normalization is currently not compatible with NNCF.")
raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
elif config.model.normalization_method == "min_max":
callbacks.append(MinMaxNormalizationCallback())
else:
raise NotImplementedError("Score Normalization is currently supported for PADIM and STFPM only.")
raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}")

if not config.project.log_images_to == []:
callbacks.append(VisualizerCallback(inputs_are_normalized=config.model.normalize_scores))
callbacks.append(VisualizerCallback(inputs_are_normalized=not config.model.normalization_method == "none"))

if "optimization" in config.keys():
if config.optimization.nncf.apply:
Expand Down
79 changes: 79 additions & 0 deletions anomalib/core/callbacks/min_max_normalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""Anomaly Score Normalization Callback that uses min-max normalization."""
djdameln marked this conversation as resolved.
Show resolved Hide resolved

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Any, Dict

import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT


class MinMaxNormalizationCallback(Callback):
"""Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""

def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Called when the test begins."""
pl_module.image_metrics.F1.threshold = 0.5
pl_module.pixel_metrics.F1.threshold = 0.5

def on_validation_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: STEP_OUTPUT,
_batch: Any,
_batch_idx: int,
_dataloader_idx: int,
) -> None:
"""Called when the validation batch ends, update the min and max observed values."""
if "anomaly_maps" in outputs.keys():
pl_module.min_max(outputs["anomaly_maps"])
else:
pl_module.min_max(outputs["pred_scores"])

def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: STEP_OUTPUT,
_batch: Any,
_batch_idx: int,
_dataloader_idx: int,
) -> None:
"""Called when the test batch ends, normalizes the predicted scores and anomaly maps."""
self._normalize(outputs, pl_module)

def on_predict_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: Dict,
_batch: Any,
_batch_idx: int,
_dataloader_idx: int,
) -> None:
"""Called when the predict batch ends, normalizes the predicted scores and anomaly maps."""
self._normalize(outputs, pl_module)

def _normalize(self, outputs, pl_module):
stats = pl_module.min_max
outputs["pred_scores"] = (
(outputs["pred_scores"] - pl_module.image_threshold.value) / (stats.max - stats.min)
) + 0.5
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = (
(outputs["anomaly_maps"] - pl_module.pixel_threshold.value) / (stats.max - stats.min)
) + 0.5
3 changes: 1 addition & 2 deletions anomalib/core/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,10 @@ def on_test_batch_end(
assert outputs is not None

if self.inputs_are_normalized:
threshold = 0.5
normalize = False # anomaly maps are already normalized
else:
threshold = pl_module.pixel_threshold.value.item()
normalize = True # raw anomaly maps. Still need to normalize
threshold = pl_module.pixel_metrics.F1.threshold

for (filename, image, true_mask, anomaly_map) in zip(
outputs["image_path"], outputs["image"], outputs["mask"], outputs["anomaly_maps"]
Expand Down
3 changes: 2 additions & 1 deletion anomalib/core/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .adaptive_threshold import AdaptiveThreshold
from .anomaly_score_distribution import AnomalyScoreDistribution
from .auroc import AUROC
from .min_max import MinMax
from .optimal_f1 import OptimalF1

__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution"]
__all__ = ["AUROC", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]
43 changes: 43 additions & 0 deletions anomalib/core/metrics/min_max.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Module that tracks the min and max values of the observations in each batch."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Tuple

import torch
from torch import Tensor
from torchmetrics import Metric


class MinMax(Metric):
"""Track the min and max values of the observations in each batch."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("min", torch.tensor(float("inf")), persistent=True) # pylint: disable=not-callable
self.add_state("max", torch.tensor(float("-inf")), persistent=True) # pylint: disable=not-callable

self.min = torch.tensor(float("inf")) # pylint: disable=not-callable
self.max = torch.tensor(float("-inf")) # pylint: disable=not-callable

# pylint: disable=arguments-differ
def update(self, predictions: Tensor) -> None: # type: ignore
"""Update the min and max values."""
self.max = torch.max(self.max, torch.max(predictions))
self.min = torch.min(self.min, torch.min(predictions))

def compute(self) -> Tuple[Tensor, Tensor]:
"""Return min and max values."""
return self.min, self.max
12 changes: 9 additions & 3 deletions anomalib/core/model/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
from torch import Tensor, nn
from torchmetrics import F1, MetricCollection

from anomalib.core.metrics import AUROC, AdaptiveThreshold, AnomalyScoreDistribution
from anomalib.core.metrics import (
AUROC,
AdaptiveThreshold,
AnomalyScoreDistribution,
MinMax,
)


class AnomalyModule(pl.LightningModule):
Expand All @@ -47,6 +52,7 @@ def __init__(self, params: Union[DictConfig, ListConfig]):
self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default)

self.training_distribution = AnomalyScoreDistribution()
self.min_max = MinMax()
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

self.model: nn.Module

Expand Down Expand Up @@ -141,8 +147,8 @@ def _compute_adaptive_threshold(self, outputs):
else:
self.pixel_threshold.value = self.image_threshold.value

self.image_metrics.F1.threshold = self.image_threshold.value
self.pixel_metrics.F1.threshold = self.pixel_threshold.value
self.image_metrics.F1.threshold = self.image_threshold.value.item()
self.pixel_metrics.F1.threshold = self.pixel_threshold.value.item()

def _collect_outputs(self, image_metric, pixel_metric, outputs):
for output in outputs:
Expand Down
5 changes: 5 additions & 0 deletions anomalib/core/model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,11 @@ def post_process(self, predictions: np.ndarray, meta_data: Optional[Dict] = None
anomaly_map = predictions.squeeze()
pred_score = anomaly_map.reshape(-1).max()

# min max normalization
if "min" in meta_data and "max" in meta_data:
anomaly_map = ((anomaly_map - meta_data["pixel_threshold"]) / (meta_data["max"] - meta_data["min"])) + 0.5
pred_score = ((pred_score - meta_data["image_threshold"]) / (meta_data["max"] - meta_data["min"])) + 0.5

# standardize pixel scores
if "pixel_mean" in meta_data.keys() and "pixel_std" in meta_data.keys():
anomaly_map = np.log(anomaly_map)
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ model:
confidence_threshold: 0.5
pre_processing: scale
n_components: 16
normalize_scores: false # currently not supported for this model
normalization_method: min_max # options: [null, min_max, cdf]
threshold:
image_default: 0
adaptive: true
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/dfm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ model:
pca_level: 0.97
score_type: fre # nll: for Gaussian modeling, fre: pca feature reconstruction error
project_path: ./results
normalize_scores: false # currently not supported for this model
normalization_method: min_max # options: [null, min_max, cdf]
threshold:
image_default: 0
adaptive: true
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/padim/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ model:
- layer2
- layer3
metric: auc
normalize_scores: true
normalization_method: min_max # options: [none, min_max, cdf]
threshold:
image_default: 3
pixel_default: 3
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ model:
num_neighbors: 9
metric: auc
weight_file: weights/model.ckpt
normalize_scores: false # currently not supported for this model
normalization_method: min_max # options: [null, min_max, cdf]
threshold:
image_default: 0
pixel_default: 0
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/stfpm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ model:
patience: 3
metric: pixel_AUROC
mode: max
normalize_scores: false
normalization_method: min_max # options: [null, min_max, cdf]
threshold:
image_default: 0
pixel_default: 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@ def test_normalizer():
config.dataset.path = get_dataset_path(config.dataset.path)
config.model.threshold.adaptive = True

# run with normalization
config.model.normalize_scores = True
# run without normalization
config.model.normalization_method = "none"
seed_everything(42)
results_with_normalization = run_train_test(config)
results_without_normalization = run_train_test(config)

# run with cdf normalization
config.model.normalization_method = "cdf"
seed_everything(42)
results_with_cdf_normalization = run_train_test(config)

# run without normalization
config.model.normalize_scores = False
config.model.normalization_method = "min_max"
seed_everything(42)
results_without_normalization = run_train_test(config)
results_with_minmax_normalization = run_train_test(config)

# performance should be the same
for metric in ["image_AUROC", "image_F1"]:
assert results_without_normalization[0][metric] == results_with_normalization[0][metric]
assert results_without_normalization[0][metric] == results_with_cdf_normalization[0][metric]
assert results_without_normalization[0][metric] == results_with_minmax_normalization[0][metric]