🏗 Refactor AnomalyModule and LightningModules to explicitly define class arguments. (#315)

* Created metrics callback to access the metric params in the module

* Added metrics callback to get_callback function

* 🗑 Removed hparams from visualizer callback

* 🗑 Removed hparams from visualizer callback

* ☑️ Check if image and pixel metrics exist in anomaly module in normalization callbacks

* ➕ Add tiling callback to setup the tiler without configuring the model.

* ➕ Add log_images_to flag to VisualizerCallback

* 🗑 Removed hparams from AnomalyModule and PadimLightning

* 🛠 Fix dummy module to be initialized with explicit arguments

* ➕ Refactor dfm model

* Make thresholding parameters optional in AnomalyModule

* 🗑  Remove pixel threshold from DFM

* Create default image and pixel threshold attributes in AnomalyModule

* 🗑 Remove hparams from DFM model and ➕ explicit arguments.

* 🗑 Remove hparams from Patchcore model and ➕ explicit arguments.

* 🗑 Remove hparams from Stfpm model and ➕ explicit arguments.

* 🗑 Remove hparams from GANomaly model and ➕ explicit arguments.

* 🗑 Remove hparams from Cflow model and ➕ explicit arguments.

* Removed normalization and fixed the tests

* 📦 wrapped PadimLightning model from Padim

* 📦 wrapped PatchcoreLightning model from Patchcore

* 📦 wrapped StfpmLightning model from Stfpm

* 📦 wrapped CflowLightning model from Cflow

* 📦 wrapped CflowLightning model from Cflow

* 📦 wrapped DfkdeLightning model from Dfkde

* 📦 wrapped Dfm and Ganomaly models

* ⏪ Revert DummyModule and visualizer tests

* Fix visualizer tests

* Fix normalization tests

* Add metrics to lightning module

* 🗑 Remove metrics callback

* 🛠️ Fix metrics callback and tests

* 🛠️ Fix metrics callback and tests

* 🛠 Revert on_fit_start to setup in MetricCallback to properly assign thresholds in test

* ➕ Added docstring to clarify to-be-deprecated methods

* 🏷 Renamed MetricCallback to MetricConfigurationCallback

* 🏷 Renamed TilerCallback to TilerConfigurationCallback

* 🏷 Renamed MetricCallback to MetricsConfigurationCallback

* 🏷 Renamed tiler.py to tiler_configuration.py

* ⏪ Removed instance check in cdf normalization

* 🚚 Move threshold params to metric configuration callback (#328)

Co-authored-by: Dick Ameln <dick.ameln@intel.com>
samet-akcay and djdameln committed May 24, 2022
1 parent 17b62f8 commit c7d5232
Showing 35 changed files with 796 additions and 310 deletions.
5 changes: 3 additions & 2 deletions anomalib/config/config.py
@@ -166,7 +166,8 @@ def get_configurable_parameters(
     config = update_nncf_config(config)
 
     # thresholding
-    if "pixel_default" not in config.model.threshold.keys():
-        config.model.threshold.pixel_default = config.model.threshold.image_default
+    if "metrics" in config.keys():
+        if "pixel_default" not in config.metrics.threshold.keys():
+            config.metrics.threshold.pixel_default = config.metrics.threshold.image_default
 
     return config
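
The effect of this hunk: threshold defaults now live under metrics.threshold rather than model.threshold, and the pixel fallback only fires when the config actually has a metrics section. A minimal sketch of the new behavior, assuming an OmegaConf DictConfig with illustrative values:

from omegaconf import OmegaConf

# Sketch only: a config whose metrics.threshold lacks pixel_default.
config = OmegaConf.create(
    {"metrics": {"threshold": {"image_default": 0.5, "adaptive": True}}}
)

# The fallback copies image_default into pixel_default when it is missing.
if "metrics" in config.keys():
    if "pixel_default" not in config.metrics.threshold.keys():
        config.metrics.threshold.pixel_default = config.metrics.threshold.image_default

print(config.metrics.threshold.pixel_default)  # -> 0.5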
29 changes: 6 additions & 23 deletions anomalib/models/__init__.py
@@ -23,10 +23,6 @@
 
 from anomalib.models.components import AnomalyModule
 
-# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF.
-# Can't not wrap `spatial_softmax2d` if use import_module.
-from anomalib.models.padim.lightning_model import PadimLightning  # noqa: F401
-
 
 def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
     """Load model from the configuration file.
@@ -37,10 +33,6 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
         `anomalib.models.<model_name>.model.<Model_name>Lightning`
         `anomalib.models.stfpm.model.StfpmLightning`
-    and for OpenVINO
-        `anomalib.models.<model-name>.model.<Model_name>OpenVINO`
-        `anomalib.models.stfpm.model.StfpmOpenVINO`
 
     Args:
         config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
@@ -50,24 +42,15 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
     Returns:
         AnomalyModule: Anomaly Model
     """
-    openvino_model_list: List[str] = ["stfpm"]
-    torch_model_list: List[str] = ["padim", "stfpm", "dfkde", "dfm", "patchcore", "cflow", "ganomaly"]
+    model_list: List[str] = ["cflow", "dfkde", "dfm", "ganomaly", "padim", "patchcore", "stfpm"]
     model: AnomalyModule
 
-    if "openvino" in config.keys() and config.openvino:
-        if config.model.name in openvino_model_list:
-            module = import_module(f"anomalib.models.{config.model.name}.model")
-            model = getattr(module, f"{config.model.name.capitalize()}OpenVINO")
-        else:
-            raise ValueError(f"Unknown model {config.model.name} for OpenVINO model!")
-    else:
-        if config.model.name in torch_model_list:
-            module = import_module(f"anomalib.models.{config.model.name}")
-            model = getattr(module, f"{config.model.name.capitalize()}Lightning")
-        else:
-            raise ValueError(f"Unknown model {config.model.name}!")
-
-    model = model(config)
+    if config.model.name in model_list:
+        module = import_module(f"anomalib.models.{config.model.name}")
+        model = getattr(module, f"{config.model.name.capitalize()}Lightning")(config)
+    else:
+        raise ValueError(f"Unknown model {config.model.name}!")
 
     if "init_weights" in config.keys() and config.init_weights:
         model.load_state_dict(load(os.path.join(config.project.path, config.init_weights))["state_dict"], strict=False)
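
get_model now resolves every Lightning model through a single naming convention instead of separate OpenVINO/Torch branches. The pattern in isolation, as a sketch (the model name is illustrative, and the ValueError guard from the real function is omitted):

from importlib import import_module

def load_lightning_class(name: str):
    """Resolve <Name>Lightning from anomalib.models.<name> by convention."""
    module = import_module(f"anomalib.models.{name}")
    return getattr(module, f"{name.capitalize()}Lightning")

# e.g. load_lightning_class("padim") returns anomalib.models.padim.PadimLightning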
8 changes: 4 additions & 4 deletions anomalib/models/cflow/config.yaml
@@ -33,10 +33,6 @@ model:
     metric: pixel_AUROC
     mode: max
   normalization_method: min_max # options: [null, min_max, cdf]
-  threshold:
-    image_default: 0
-    pixel_default: 0
-    adaptive: true
 
 metrics:
   image:
@@ -45,6 +41,10 @@ metrics:
   pixel:
     - F1Score
     - AUROC
+  threshold:
+    image_default: 0
+    pixel_default: 0
+    adaptive: true
 
 project:
   seed: 0
121 changes: 91 additions & 30 deletions anomalib/models/cflow/lightning_model.py
@@ -18,11 +18,14 @@
 # and limitations under the License.
 
 import logging
+from typing import List, Tuple, Union
 
 import einops
 import torch
 import torch.nn.functional as F
+from omegaconf import DictConfig, ListConfig
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.cli import MODEL_REGISTRY
 from torch import optim
 
 from anomalib.models.cflow.torch_model import CflowModel
@@ -31,45 +34,42 @@
 
 logger = logging.getLogger(__name__)
 
-__all__ = ["CflowLightning"]
+__all__ = ["Cflow", "CflowLightning"]
 
 
-class CflowLightning(AnomalyModule):
+@MODEL_REGISTRY
+class Cflow(AnomalyModule):
     """PL Lightning Module for the CFLOW algorithm."""
 
-    def __init__(self, hparams):
-        super().__init__(hparams)
+    def __init__(
+        self,
+        input_size: Tuple[int, int],
+        backbone: str,
+        layers: List[str],
+        fiber_batch_size: int = 64,
+        decoder: str = "freia-cflow",
+        condition_vector: int = 128,
+        coupling_blocks: int = 8,
+        clamp_alpha: float = 1.9,
+        permute_soft: bool = False,
+    ):
+        super().__init__()
         logger.info("Initializing Cflow Lightning model.")
 
-        self.model: CflowModel = CflowModel(hparams)
+        self.model: CflowModel = CflowModel(
+            input_size=input_size,
+            backbone=backbone,
+            layers=layers,
+            fiber_batch_size=fiber_batch_size,
+            decoder=decoder,
+            condition_vector=condition_vector,
+            coupling_blocks=coupling_blocks,
+            clamp_alpha=clamp_alpha,
+            permute_soft=permute_soft,
+        )
        self.loss_val = 0
         self.automatic_optimization = False
 
-    def configure_callbacks(self):
-        """Configure model-specific callbacks."""
-        early_stopping = EarlyStopping(
-            monitor=self.hparams.model.early_stopping.metric,
-            patience=self.hparams.model.early_stopping.patience,
-            mode=self.hparams.model.early_stopping.mode,
-        )
-        return [early_stopping]
-
-    def configure_optimizers(self) -> torch.optim.Optimizer:
-        """Configures optimizers for each decoder.
-
-        Returns:
-            Optimizer: Adam optimizer for each decoder
-        """
-        decoders_parameters = []
-        for decoder_idx in range(len(self.model.pool_layers)):
-            decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))
-
-        optimizer = optim.Adam(
-            params=decoders_parameters,
-            lr=self.hparams.model.lr,
-        )
-        return optimizer
-
     def training_step(self, batch, _):  # pylint: disable=arguments-differ
         """Training Step of CFLOW.
@@ -159,3 +159,64 @@ def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         batch["anomaly_maps"] = self.model(batch["image"])
 
         return batch
+
+
+class CflowLightning(Cflow):
+    """PL Lightning Module for the CFLOW algorithm.
+
+    Args:
+        hparams (Union[DictConfig, ListConfig]): Model params
+    """
+
+    def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
+        super().__init__(
+            input_size=hparams.model.input_size,
+            backbone=hparams.model.backbone,
+            layers=hparams.model.layers,
+            fiber_batch_size=hparams.dataset.fiber_batch_size,
+            decoder=hparams.model.decoder,
+            condition_vector=hparams.model.condition_vector,
+            coupling_blocks=hparams.model.coupling_blocks,
+            clamp_alpha=hparams.model.clamp_alpha,
+            permute_soft=hparams.model.soft_permutation,
+        )
+        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
+        self.save_hyperparameters(hparams)
+
+    def configure_callbacks(self):
+        """Configure model-specific callbacks.
+
+        Note:
+            This method is used for the existing CLI.
+            When PL CLI is introduced, configure callback method will be
+            deprecated, and callbacks will be configured from either
+            config.yaml file or from CLI.
+        """
+        early_stopping = EarlyStopping(
+            monitor=self.hparams.model.early_stopping.metric,
+            patience=self.hparams.model.early_stopping.patience,
+            mode=self.hparams.model.early_stopping.mode,
+        )
+        return [early_stopping]
+
+    def configure_optimizers(self) -> torch.optim.Optimizer:
+        """Configures optimizers for each decoder.
+
+        Note:
+            This method is used for the existing CLI.
+            When PL CLI is introduced, configure optimizers method will be
+            deprecated, and optimizers will be configured from either
+            config.yaml file or from CLI.
+
+        Returns:
+            Optimizer: Adam optimizer for each decoder
+        """
+        decoders_parameters = []
+        for decoder_idx in range(len(self.model.pool_layers)):
+            decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))
+
+        optimizer = optim.Adam(
+            params=decoders_parameters,
+            lr=self.hparams.model.lr,
+        )
+        return optimizer
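
With this split, the algorithm class can be constructed directly from keyword arguments, while CflowLightning stays a thin config adapter for the current CLI. A sketch of the explicit path (the argument values below are illustrative, not mandated by the commit):

from anomalib.models.cflow.lightning_model import Cflow

# Build the model without any DictConfig; defaults cover the flow-specific
# knobs (fiber_batch_size, coupling_blocks, clamp_alpha, ...).
model = Cflow(
    input_size=(256, 256),
    backbone="wide_resnet50_2",
    layers=["layer2", "layer3", "layer4"],
)

The config-driven path is unchanged for CLI users: CflowLightning(hparams) unpacks hparams.model.* and hparams.dataset.* into these same keywords.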
36 changes: 22 additions & 14 deletions anomalib/models/cflow/torch_model.py
@@ -14,12 +14,11 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
 
-from typing import List, Union
+from typing import List, Tuple
 
 import einops
 import torch
 import torchvision
-from omegaconf import DictConfig, ListConfig
 from torch import nn
 
 from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
@@ -30,25 +29,36 @@
 class CflowModel(nn.Module):
     """CFLOW: Conditional Normalizing Flows."""
 
-    def __init__(self, hparams: Union[DictConfig, ListConfig]):
+    def __init__(
+        self,
+        input_size: Tuple[int, int],
+        backbone: str,
+        layers: List[str],
+        fiber_batch_size: int = 64,
+        decoder: str = "freia-cflow",
+        condition_vector: int = 128,
+        coupling_blocks: int = 8,
+        clamp_alpha: float = 1.9,
+        permute_soft: bool = False,
+    ):
         super().__init__()
 
-        self.backbone = getattr(torchvision.models, hparams.model.backbone)
-        self.fiber_batch_size = hparams.dataset.fiber_batch_size
-        self.condition_vector: int = hparams.model.condition_vector
-        self.dec_arch = hparams.model.decoder
-        self.pool_layers = hparams.model.layers
+        self.backbone = getattr(torchvision.models, backbone)
+        self.fiber_batch_size = fiber_batch_size
+        self.condition_vector: int = condition_vector
+        self.dec_arch = decoder
+        self.pool_layers = layers
 
         self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
         self.pool_dims = self.encoder.out_dims
         self.decoders = nn.ModuleList(
             [
                 cflow_head(
                     condition_vector=self.condition_vector,
-                    coupling_blocks=hparams.model.coupling_blocks,
-                    clamp_alpha=hparams.model.clamp_alpha,
+                    coupling_blocks=coupling_blocks,
+                    clamp_alpha=clamp_alpha,
                     n_features=pool_dim,
-                    permute_soft=hparams.model.soft_permutation,
+                    permute_soft=permute_soft,
                 )
                 for pool_dim in self.pool_dims
             ]
@@ -58,9 +68,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
         for parameters in self.encoder.parameters():
             parameters.requires_grad = False
 
-        self.anomaly_map_generator = AnomalyMapGenerator(
-            image_size=tuple(hparams.model.input_size), pool_layers=self.pool_layers
-        )
+        self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(input_size), pool_layers=self.pool_layers)
 
     def forward(self, images):
         """Forward-pass images into the network to extract encoder features and compute probability.
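
A side note on the constructor above: the feature extractor is built with pretrained weights and then frozen, so only the normalizing-flow decoders receive gradient updates. The freeze idiom in isolation, with a stand-in module:

from torch import nn

# Stand-in encoder; the real code freezes the torchvision backbone wrapped
# in FeatureExtractor the same way.
encoder = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())
for parameters in encoder.parameters():
    parameters.requires_grad = False

# Nothing in the encoder will be optimized now.
assert not any(p.requires_grad for p in encoder.parameters())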
35 changes: 15 additions & 20 deletions anomalib/models/components/base/anomaly_module.py
@@ -15,51 +15,46 @@
 # and limitations under the License.
 
 from abc import ABC
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional
 
 import pytorch_lightning as pl
-from omegaconf import DictConfig, ListConfig
 from pytorch_lightning.callbacks.base import Callback
 from torch import Tensor, nn
 
 from anomalib.utils.metrics import (
     AdaptiveThreshold,
+    AnomalibMetricCollection,
     AnomalyScoreDistribution,
     MinMax,
-    get_metrics,
 )
 
 
 class AnomalyModule(pl.LightningModule, ABC):
     """AnomalyModule to train, validate, predict and test images.
 
     Acts as a base class for all the Anomaly Modules in the library.
-
-    Args:
-        params (Union[DictConfig, ListConfig]): Configuration
     """
 
-    def __init__(self, params: Union[DictConfig, ListConfig]):
-
+    def __init__(self):
         super().__init__()
-        # Force the type for hparams so that it works with OmegaConfig style of accessing
-        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
-        self.save_hyperparameters(params)
+        self.save_hyperparameters()
+        self.model: nn.Module
         self.loss: Tensor
         self.callbacks: List[Callback]
 
-        self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
-        self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default).cpu()
+        self.adaptive_threshold: bool
+
+        self.image_threshold = AdaptiveThreshold().cpu()
+        self.pixel_threshold = AdaptiveThreshold().cpu()
 
         self.training_distribution = AnomalyScoreDistribution().cpu()
         self.min_max = MinMax().cpu()
 
-        self.model: nn.Module
-
-        # metrics
-        self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
-        self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
-        self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)
+        # Create placeholders for image and pixel metrics.
+        # If set from the config file, MetricsConfigurationCallback will
+        # create the metric collections upon setup.
+        self.image_metrics: AnomalibMetricCollection
+        self.pixel_metrics: AnomalibMetricCollection
 
     def forward(self, batch):  # pylint: disable=arguments-differ
         """Forward-pass input tensor to the module.
@@ -128,7 +123,7 @@ def validation_epoch_end(self, outputs):
         Args:
             outputs: Batch of outputs from the validation step
         """
-        if self.hparams.model.threshold.adaptive:
+        if self.adaptive_threshold:
             self._compute_adaptive_threshold(outputs)
         self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
         self._log_metrics()
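
After this change, AnomalyModule takes no constructor arguments: thresholds start from AdaptiveThreshold defaults, and the image/pixel metric collections are attached later by MetricsConfigurationCallback. A minimal, illustrative subclass (not part of this commit) showing the new pattern:

from torch import nn

from anomalib.models.components import AnomalyModule

class ToyAnomalyModule(AnomalyModule):
    """Hypothetical module used only to illustrate the argument-free base."""

    def __init__(self):
        super().__init__()  # no hparams / params argument anymore
        self.model = nn.Identity()

    def validation_step(self, batch, _):  # pylint: disable=arguments-differ
        batch["anomaly_maps"] = self.model(batch["image"])
        return batch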