Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor fixes: Update callbacks to AnomalyModule #208

Merged
merged 3 commits into from
Apr 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions anomalib/utils/callbacks/cdf_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.distributions import LogNormal

from anomalib.models import get_model
from anomalib.models.components import AnomalyModule
from anomalib.post_processing.normalization.cdf import normalize, standardize


Expand All @@ -32,12 +33,12 @@ def __init__(self):
self.image_dist: Optional[LogNormal] = None
self.pixel_dist: Optional[LogNormal] = None

def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Called when the test begins."""
pl_module.image_metrics.F1.threshold = 0.5
pl_module.pixel_metrics.F1.threshold = 0.5

def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: AnomalyModule) -> None:
"""Called when the validation starts after training.

Use the current model to compute the anomaly score distributions
Expand All @@ -49,7 +50,7 @@ def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
def on_validation_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
Expand All @@ -61,7 +62,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
Expand All @@ -74,7 +75,7 @@ def on_test_batch_end(
def on_predict_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Dict,
_batch: Any,
_batch_idx: int,
Expand Down Expand Up @@ -120,7 +121,7 @@ def _standardize_batch(outputs: STEP_OUTPUT, pl_module) -> None:
)

@staticmethod
def _normalize_batch(outputs: STEP_OUTPUT, pl_module: pl.LightningModule) -> None:
def _normalize_batch(outputs: STEP_OUTPUT, pl_module: AnomalyModule) -> None:
outputs["pred_scores"] = normalize(outputs["pred_scores"], pl_module.image_threshold.value)
if "anomaly_maps" in outputs.keys():
outputs["anomaly_maps"] = normalize(outputs["anomaly_maps"], pl_module.pixel_threshold.value)
9 changes: 5 additions & 4 deletions anomalib/utils/callbacks/min_max_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,22 @@
from pytorch_lightning import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT

from anomalib.models.components import AnomalyModule
from anomalib.post_processing.normalization.min_max import normalize


class MinMaxNormalizationCallback(Callback):
"""Callback that normalizes the image-level and pixel-level anomaly scores using min-max normalization."""

def on_test_start(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_start(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Called when the test begins."""
pl_module.image_metrics.F1.threshold = 0.5
pl_module.pixel_metrics.F1.threshold = 0.5

def on_validation_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT,
_batch: Any,
_batch_idx: int,
Expand All @@ -49,7 +50,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: STEP_OUTPUT,
_batch: Any,
_batch_idx: int,
Expand All @@ -61,7 +62,7 @@ def on_test_batch_end(
def on_predict_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Dict,
_batch: Any,
_batch_idx: int,
Expand Down
6 changes: 4 additions & 2 deletions anomalib/utils/callbacks/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# and limitations under the License.

import torch
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Callback

from anomalib.models.components import AnomalyModule


class LoadModelCallback(Callback):
Expand All @@ -24,7 +26,7 @@ class LoadModelCallback(Callback):
def __init__(self, weights_path):
self.weights_path = weights_path

def on_test_start(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
def on_test_start(self, trainer, pl_module: AnomalyModule) -> None: # pylint: disable=W0613
"""Call when the test begins.

Loads the model weights from ``weights_path`` into the PyTorch module.
Expand Down
7 changes: 3 additions & 4 deletions anomalib/utils/callbacks/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# and limitations under the License.

import os
from typing import Tuple, cast
from typing import Tuple

from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Callback

from anomalib.deploy import export_convert
from anomalib.models.components import AnomalyModule
Expand All @@ -39,15 +39,14 @@ def __init__(self, input_size: Tuple[int, int], dirpath: str, filename: str):
self.dirpath = dirpath
self.filename = filename

def on_train_end(self, trainer, pl_module: LightningModule) -> None: # pylint: disable=W0613
def on_train_end(self, trainer, pl_module: AnomalyModule) -> None: # pylint: disable=W0613
"""Call when the train ends.

Converts the model to ``onnx`` format and then calls OpenVINO's model optimizer to get the
``.xml`` and ``.bin`` IR files.
"""
os.makedirs(self.dirpath, exist_ok=True)
onnx_path = os.path.join(self.dirpath, self.filename + ".onnx")
pl_module = cast(AnomalyModule, pl_module)
export_convert(
model=pl_module,
input_size=self.input_size,
Expand Down
6 changes: 3 additions & 3 deletions anomalib/utils/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _add_images(
def on_test_batch_end(
self,
_trainer: pl.Trainer,
pl_module: pl.LightningModule,
pl_module: AnomalyModule,
outputs: Optional[STEP_OUTPUT],
_batch: Any,
_batch_idx: int,
Expand Down Expand Up @@ -149,15 +149,15 @@ def on_test_batch_end(
self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()

def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
def on_test_end(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Sync logs.

Currently only ``AnomalibWandbLogger`` is called from this method. This is because logging as a single batch
ensures that all images appear as part of the same step.

Args:
_trainer (pl.Trainer): Pytorch Lightning trainer (unused)
pl_module (pl.LightningModule): Anomaly module
pl_module (AnomalyModule): Anomaly module
"""
if pl_module.logger is not None and isinstance(pl_module.logger, AnomalibWandbLogger):
pl_module.logger.save()
10 changes: 8 additions & 2 deletions tests/nightly/models/test_model_nightly.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,19 @@ def _test_metrics(self, trainer, config, model, datamodule):
threshold = thresholds[config.model.name][config.dataset.category]
if "optimization" in config.keys() and config.optimization.nncf.apply:
threshold = threshold.nncf
if not (np.isclose(results["image_AUROC"], threshold["image_AUROC"], rtol=0.02) or (results["image_AUROC"] >= threshold["image_AUROC"])):
if not (
np.isclose(results["image_AUROC"], threshold["image_AUROC"], rtol=0.02)
or (results["image_AUROC"] >= threshold["image_AUROC"])
):
raise AssertionError(
f"results['image_AUROC']:{results['image_AUROC']} >= threshold['image_AUROC']:{threshold['image_AUROC']}"
)

if config.dataset.task == "segmentation":
if not (np.isclose(results["pixel_AUROC"] ,threshold["pixel_AUROC"], rtol=0.02) or (results["pixel_AUROC"] >= threshold["pixel_AUROC"])):
if not (
np.isclose(results["pixel_AUROC"], threshold["pixel_AUROC"], rtol=0.02)
or (results["pixel_AUROC"] >= threshold["pixel_AUROC"])
):
raise AssertionError(
f"results['pixel_AUROC']:{results['pixel_AUROC']} >= threshold['pixel_AUROC']:{threshold['pixel_AUROC']}"
)
Expand Down