Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AnomalyModule and LightningModules to explicitly define class arguments. #315

Merged
merged 42 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
e7142d2
Created metrics callback to access the metric params in the module
samet-akcay May 13, 2022
9020d20
Added metrics callback to get_callback function
samet-akcay May 13, 2022
71bd755
🗑 Removed hparams from visualizer callback
samet-akcay May 13, 2022
abc6415
🗑 Removed hparams from visualizer callback
samet-akcay May 14, 2022
9bf0666
☑️ Check if image and pixel metrics exist in anomaly module in normal…
samet-akcay May 14, 2022
9a400c3
➕ Add tiling callback to setup the tiler without configuring the model.
samet-akcay May 14, 2022
a882d1a
➕ Add log_images_to flag to VisualizerCallback
samet-akcay May 14, 2022
46c100b
🗑 Removed hparams from AnomalyModule and PadimLightning
samet-akcay May 14, 2022
02197cb
🛠 Fix dummy module to be initialized with explicit arguments
samet-akcay May 14, 2022
074b4c4
➕ Refactor dfm model
samet-akcay May 16, 2022
7d863cb
Make thresholding parameters optional in AnomalyModule
samet-akcay May 16, 2022
e513a75
🗑 Remove pixel threshold from DFM
samet-akcay May 16, 2022
547f5c6
create default image and pixel threshold attributes in AnomalyModule
samet-akcay May 16, 2022
c898475
🗑 Remove hparams from DFM model and ➕explicit arguments.
samet-akcay May 16, 2022
8ad0434
🗑 Remove hparams from Patchcore model and ➕explicit arguments.
samet-akcay May 16, 2022
b7cfcfd
🗑 Remove hparams from Stfpm model and ➕explicit arguments.
samet-akcay May 16, 2022
6bbfeb3
🗑 Remove hparams from GANomaly model and ➕explicit arguments.
samet-akcay May 16, 2022
39a48c4
🗑 Remove hparams from Cflow model and ➕explicit arguments.
samet-akcay May 16, 2022
c03c316
Removed normalization and Fix the tests
samet-akcay May 16, 2022
67f9eed
📦 wrapped PadimLightning model from Padim
samet-akcay May 17, 2022
2773882
📦 wrapped PatchcoreLightning model from Patchcore
samet-akcay May 17, 2022
0e41628
📦 wrapped StfpmLightning model from Stfpm
samet-akcay May 17, 2022
ac964f2
📦 wrapped CflowLightning model from Cflow
samet-akcay May 17, 2022
fc7a4db
📦 wrapped CflowLightning model from Cflow
samet-akcay May 17, 2022
821c368
📦 wrapped DfkdeLightning model from Dfkde
samet-akcay May 17, 2022
8b4c612
📦 wrapped Dfm and Ganomaly models
samet-akcay May 17, 2022
4542911
⏪ Revert DummyModule and visualizer tests
samet-akcay May 17, 2022
98c8f80
Fix visualizer tests
samet-akcay May 17, 2022
a4a9225
Fix normalization tests
samet-akcay May 17, 2022
b227fd4
Add metrics to lightning module
samet-akcay May 18, 2022
ec9ed75
🗑 Remove metrics callback
samet-akcay May 18, 2022
c53d82b
🛠️ Fix metrics callback and tests
samet-akcay May 18, 2022
3227de4
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay May 18, 2022
761b217
🛠️ Fix metrics callback and tests
samet-akcay May 18, 2022
50ffd38
🛠 Revert on_fit_start to setup in MetricCallback to properly assign t…
samet-akcay May 19, 2022
22ce635
➕ Added docstring to clarify to-be-deprecated methods
samet-akcay May 19, 2022
cf8f54e
🏷 Renamed MetricCallback to MetricConfigurationCallback
samet-akcay May 19, 2022
847fe75
🏷 Renamed TilerCallback to TilerConfigurationCallback
samet-akcay May 19, 2022
57ef6c3
🏷 Renamed MetricCallback to MetricsConfigurationCallback
samet-akcay May 19, 2022
50631d3
🏷 Renamed tiler.py to tiler_configuration.py
samet-akcay May 19, 2022
c2183ae
⏪ Removed instance check in cdf normalization
samet-akcay May 19, 2022
a333dc6
🚚 Move threshold params to metric configuration callback (#328)
djdameln May 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions anomalib/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ def get_configurable_parameters(
config = update_nncf_config(config)

# thresholding
if "pixel_default" not in config.model.threshold.keys():
config.model.threshold.pixel_default = config.model.threshold.image_default
if "metrics" in config.keys():
if "pixel_default" not in config.metrics.threshold.keys():
config.metrics.threshold.pixel_default = config.metrics.threshold.image_default

return config
29 changes: 6 additions & 23 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@

from anomalib.models.components import AnomalyModule

# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF.
# Can't not wrap `spatial_softmax2d` if use import_module.
from anomalib.models.padim.lightning_model import PadimLightning # noqa: F401


def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
"""Load model from the configuration file.
Expand All @@ -37,10 +33,6 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
`anomalib.models.<model_name>.model.<Model_name>Lightning`
`anomalib.models.stfpm.model.StfpmLightning`
and for OpenVINO
`anomalib.models.<model-name>.model.<Model_name>OpenVINO`
`anomalib.models.stfpm.model.StfpmOpenVINO`
Args:
config (Union[DictConfig, ListConfig]): Config.yaml loaded using OmegaConf
Expand All @@ -50,24 +42,15 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly Model
"""
openvino_model_list: List[str] = ["stfpm"]
torch_model_list: List[str] = ["padim", "stfpm", "dfkde", "dfm", "patchcore", "cflow", "ganomaly"]
model_list: List[str] = ["cflow", "dfkde", "dfm", "ganomaly", "padim", "patchcore", "stfpm"]
model: AnomalyModule

if "openvino" in config.keys() and config.openvino:
if config.model.name in openvino_model_list:
module = import_module(f"anomalib.models.{config.model.name}.model")
model = getattr(module, f"{config.model.name.capitalize()}OpenVINO")
else:
raise ValueError(f"Unknown model {config.model.name} for OpenVINO model!")
else:
if config.model.name in torch_model_list:
module = import_module(f"anomalib.models.{config.model.name}")
model = getattr(module, f"{config.model.name.capitalize()}Lightning")
else:
raise ValueError(f"Unknown model {config.model.name}!")
if config.model.name in model_list:
module = import_module(f"anomalib.models.{config.model.name}")
model = getattr(module, f"{config.model.name.capitalize()}Lightning")(config)

model = model(config)
else:
raise ValueError(f"Unknown model {config.model.name}!")

if "init_weights" in config.keys() and config.init_weights:
model.load_state_dict(load(os.path.join(config.project.path, config.init_weights))["state_dict"], strict=False)
Expand Down
8 changes: 4 additions & 4 deletions anomalib/models/cflow/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,6 @@ model:
metric: pixel_AUROC
mode: max
normalization_method: min_max # options: [null, min_max, cdf]
threshold:
image_default: 0
pixel_default: 0
adaptive: true

metrics:
image:
Expand All @@ -45,6 +41,10 @@ metrics:
pixel:
- F1Score
- AUROC
threshold:
image_default: 0
pixel_default: 0
adaptive: true

project:
seed: 0
Expand Down
121 changes: 91 additions & 30 deletions anomalib/models/cflow/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
# and limitations under the License.

import logging
from typing import List, Tuple, Union

import einops
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import optim

from anomalib.models.cflow.torch_model import CflowModel
Expand All @@ -31,45 +34,42 @@

logger = logging.getLogger(__name__)

__all__ = ["CflowLightning"]
__all__ = ["Cflow", "CflowLightning"]


class CflowLightning(AnomalyModule):
@MODEL_REGISTRY
class Cflow(AnomalyModule):
"""PL Lightning Module for the CFLOW algorithm."""

def __init__(self, hparams):
super().__init__(hparams)
def __init__(
self,
input_size: Tuple[int, int],
backbone: str,
layers: List[str],
fiber_batch_size: int = 64,
decoder: str = "freia-cflow",
condition_vector: int = 128,
coupling_blocks: int = 8,
clamp_alpha: float = 1.9,
permute_soft: bool = False,
):
super().__init__()
logger.info("Initializing Cflow Lightning model.")

self.model: CflowModel = CflowModel(hparams)
self.model: CflowModel = CflowModel(
input_size=input_size,
backbone=backbone,
layers=layers,
fiber_batch_size=fiber_batch_size,
decoder=decoder,
condition_vector=condition_vector,
coupling_blocks=coupling_blocks,
clamp_alpha=clamp_alpha,
permute_soft=permute_soft,
)
self.loss_val = 0
self.automatic_optimization = False

def configure_callbacks(self):
"""Configure model-specific callbacks."""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures optimizers for each decoder.
Returns:
Optimizer: Adam optimizer for each decoder
"""
decoders_parameters = []
for decoder_idx in range(len(self.model.pool_layers)):
decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))

optimizer = optim.Adam(
params=decoders_parameters,
lr=self.hparams.model.lr,
)
return optimizer

def training_step(self, batch, _): # pylint: disable=arguments-differ
"""Training Step of CFLOW.
Expand Down Expand Up @@ -159,3 +159,64 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
batch["anomaly_maps"] = self.model(batch["image"])

return batch


class CflowLightning(Cflow):
"""PL Lightning Module for the CFLOW algorithm.
Args:
hparams (Union[DictConfig, ListConfig]): Model params
"""

def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
super().__init__(
input_size=hparams.model.input_size,
backbone=hparams.model.backbone,
layers=hparams.model.layers,
fiber_batch_size=hparams.dataset.fiber_batch_size,
decoder=hparams.model.decoder,
condition_vector=hparams.model.condition_vector,
coupling_blocks=hparams.model.coupling_blocks,
clamp_alpha=hparams.model.clamp_alpha,
permute_soft=hparams.model.soft_permutation,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)

def configure_callbacks(self):
"""Configure model-specific callbacks.
Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure callback method will be
deprecated, and callbacks will be configured from either
config.yaml file or from CLI.
"""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures optimizers for each decoder.
Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure optimizers method will be
deprecated, and optimizers will be configured from either
config.yaml file or from CLI.
Returns:
Optimizer: Adam optimizer for each decoder
"""
decoders_parameters = []
for decoder_idx in range(len(self.model.pool_layers)):
decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))

optimizer = optim.Adam(
params=decoders_parameters,
lr=self.hparams.model.lr,
)
return optimizer
36 changes: 22 additions & 14 deletions anomalib/models/cflow/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List, Union
from typing import List, Tuple

import einops
import torch
import torchvision
from omegaconf import DictConfig, ListConfig
from torch import nn

from anomalib.models.cflow.anomaly_map import AnomalyMapGenerator
Expand All @@ -30,25 +29,36 @@
class CflowModel(nn.Module):
"""CFLOW: Conditional Normalizing Flows."""

def __init__(self, hparams: Union[DictConfig, ListConfig]):
def __init__(
self,
input_size: Tuple[int, int],
backbone: str,
layers: List[str],
fiber_batch_size: int = 64,
decoder: str = "freia-cflow",
condition_vector: int = 128,
coupling_blocks: int = 8,
clamp_alpha: float = 1.9,
permute_soft: bool = False,
):
super().__init__()

self.backbone = getattr(torchvision.models, hparams.model.backbone)
self.fiber_batch_size = hparams.dataset.fiber_batch_size
self.condition_vector: int = hparams.model.condition_vector
self.dec_arch = hparams.model.decoder
self.pool_layers = hparams.model.layers
self.backbone = getattr(torchvision.models, backbone)
self.fiber_batch_size = fiber_batch_size
self.condition_vector: int = condition_vector
self.dec_arch = decoder
self.pool_layers = layers

self.encoder = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.pool_layers)
self.pool_dims = self.encoder.out_dims
self.decoders = nn.ModuleList(
[
cflow_head(
condition_vector=self.condition_vector,
coupling_blocks=hparams.model.coupling_blocks,
clamp_alpha=hparams.model.clamp_alpha,
coupling_blocks=coupling_blocks,
clamp_alpha=clamp_alpha,
n_features=pool_dim,
permute_soft=hparams.model.soft_permutation,
permute_soft=permute_soft,
)
for pool_dim in self.pool_dims
]
Expand All @@ -58,9 +68,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
for parameters in self.encoder.parameters():
parameters.requires_grad = False

self.anomaly_map_generator = AnomalyMapGenerator(
image_size=tuple(hparams.model.input_size), pool_layers=self.pool_layers
)
self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(input_size), pool_layers=self.pool_layers)

def forward(self, images):
"""Forward-pass images into the network to extract encoder features and compute probability.
Expand Down
35 changes: 15 additions & 20 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,51 +15,46 @@
# and limitations under the License.

from abc import ABC
from typing import Any, List, Optional, Union
from typing import Any, List, Optional

import pytorch_lightning as pl
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks.base import Callback
from torch import Tensor, nn

from anomalib.utils.metrics import (
AdaptiveThreshold,
AnomalibMetricCollection,
AnomalyScoreDistribution,
MinMax,
get_metrics,
)


class AnomalyModule(pl.LightningModule, ABC):
"""AnomalyModule to train, validate, predict and test images.
Acts as a base class for all the Anomaly Modules in the library.
Args:
params (Union[DictConfig, ListConfig]): Configuration
"""

def __init__(self, params: Union[DictConfig, ListConfig]):

def __init__(self):
super().__init__()
# Force the type for hparams so that it works with OmegaConfig style of accessing
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(params)
self.save_hyperparameters()
self.model: nn.Module
self.loss: Tensor
self.callbacks: List[Callback]

self.image_threshold = AdaptiveThreshold(self.hparams.model.threshold.image_default).cpu()
self.pixel_threshold = AdaptiveThreshold(self.hparams.model.threshold.pixel_default).cpu()
self.adaptive_threshold: bool

self.image_threshold = AdaptiveThreshold().cpu()
self.pixel_threshold = AdaptiveThreshold().cpu()

self.training_distribution = AnomalyScoreDistribution().cpu()
self.min_max = MinMax().cpu()

self.model: nn.Module

# metrics
self.image_metrics, self.pixel_metrics = get_metrics(self.hparams)
self.image_metrics.set_threshold(self.hparams.model.threshold.image_default)
self.pixel_metrics.set_threshold(self.hparams.model.threshold.pixel_default)
# Create placeholders for image and pixel metrics.
# If set from the config file, MetricsConfigurationCallback will
# create the metric collections upon setup.
self.image_metrics: AnomalibMetricCollection
self.pixel_metrics: AnomalibMetricCollection

def forward(self, batch): # pylint: disable=arguments-differ
"""Forward-pass input tensor to the module.
Expand Down Expand Up @@ -128,7 +123,7 @@ def validation_epoch_end(self, outputs):
Args:
outputs: Batch of outputs from the validation step
"""
if self.hparams.model.threshold.adaptive:
if self.adaptive_threshold:
self._compute_adaptive_threshold(outputs)
self._collect_outputs(self.image_metrics, self.pixel_metrics, outputs)
self._log_metrics()
Expand Down
Loading