✨ Add DRAEM Model (#344)
* initial implementation of DRAEM algo

* fix preprocessor

* fix config and update license

* add imgaug to requirements

* fix inputs of reconstruction loss

* add readme

* add architecture image

* use shorter license header

* allow multiple image extensions

* replace loss functions

* ssim_kornia_loss -> ssim_loss

* update license headers

* clarify anomaly source dataset in readme

* Fix model registration

* remove comments

* update variable names

* move helpers to utils directory

* add init

* update third party software

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
djdameln and samet-akcay committed Jun 8, 2022
1 parent 1c992fb commit 1dcbe1a
Showing 17 changed files with 1,124 additions and 8 deletions.
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
@@ -42,7 +42,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
     Returns:
         AnomalyModule: Anomaly Model
     """
-    model_list: List[str] = ["cflow", "dfkde", "dfm", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"]
+    model_list: List[str] = ["cflow", "dfkde", "dfm", "draem", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"]
     model: AnomalyModule

     if config.model.name in model_list:
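
The hunk above only adds `"draem"` to the list of accepted model names. As a rough illustration of why that registration matters, here is a minimal sketch of name-based lookup. The helper `resolve_model_class_name` is hypothetical, and the `<Name>Lightning` naming rule is an assumption generalized from the `DraemLightning` export in this commit; the real `get_model` body is not shown here and may differ.

```python
# Names accepted after this commit (copied from the diff above).
MODEL_NAMES = ["cflow", "dfkde", "dfm", "draem", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"]


def resolve_model_class_name(name):
    """Map a configured model name to a class name (hypothetical helper).

    Assumes the convention "<Name>Lightning", as in "draem" -> "DraemLightning".
    """
    if name not in MODEL_NAMES:
        raise ValueError(f"Unknown model {name!r}. Expected one of {MODEL_NAMES}.")
    return f"{name.capitalize()}Lightning"


print(resolve_model_class_name("draem"))
```

Without the new list entry, a config with `model.name: draem` would fall through to the error branch in this sketch.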
29 changes: 29 additions & 0 deletions anomalib/models/draem/LICENSE
@@ -0,0 +1,29 @@
Copyright (c) 2022 Intel Corporation
SPDX-License-Identifier: Apache-2.0

Some files in this folder are based on the original DRAEM implementation by VitjanZ

Original license:
----------------

MIT License

Copyright (c) 2021 VitjanZ

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
22 changes: 22 additions & 0 deletions anomalib/models/draem/README.md
@@ -0,0 +1,22 @@
# DRÆM – A discriminatively trained reconstruction embedding for surface anomaly detection

This is the implementation of the [DRAEM](https://arxiv.org/pdf/2108.07610v2.pdf) paper.

Model Type: Segmentation

## Description

DRAEM is a reconstruction-based algorithm that consists of a reconstructive subnetwork and a discriminative subnetwork. DRAEM is trained on simulated anomaly images, generated by augmenting normal input images from the training set: a random Perlin noise mask selects the regions to corrupt, and the content of those regions is taken from an unrelated source of image data. The reconstructive subnetwork is an autoencoder that is trained to reconstruct the original input images from the augmented images, using a combination of L2 loss and Structural Similarity (SSIM) loss. The input of the discriminative subnetwork is the channel-wise concatenation of the (augmented) input image and the output of the reconstructive subnetwork. Its output is an anomaly map that contains the predicted anomaly score for each pixel location. The discriminative subnetwork is trained using Focal Loss.

For optimal results, specify the path to a folder of image data that will be used as the source of the anomalous pixel regions in the simulated anomaly images. Set the path through the `model.anomaly_source_path` parameter in the `config.yaml` file. If the parameter is left empty, random noise is used instead. The authors of the original paper recommend using the [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/) dataset as the anomaly source.
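
The anomaly simulation described above can be sketched in a few lines. This is a minimal illustration, assuming the blending rule given in the DRAEM paper, `I_a = (1 - M) * I + (1 - beta) * (M * I) + beta * (M * A)`. The function name, the flat-list image representation and the toy values are illustrative only, and this is not the `Augmenter` used by the model.

```python
def blend_anomaly(image, source, mask, beta):
    """Blend `source` into `image` wherever `mask` is 1.

    All inputs are flat lists of equal length, one value per pixel.
    Outside the mask the image is untouched; inside it, the image and the
    anomaly source are mixed with weights (1 - beta) and beta.
    """
    return [
        (1 - m) * i + (1 - beta) * m * i + beta * m * a
        for i, a, m in zip(image, source, mask)
    ]


image = [0.2, 0.4, 0.6, 0.8]   # normal training image
source = [1.0, 1.0, 0.0, 0.0]  # unrelated anomaly source, e.g. a texture
mask = [0, 1, 1, 0]            # binary mask; DRAEM derives it from Perlin noise
beta = 0.5                     # blend factor, chosen arbitrarily here

augmented = blend_anomaly(image, source, mask, beta)
print(augmented)
```

The mask doubles as the ground-truth segmentation target for the discriminative subnetwork, while the untouched `image` is the reconstruction target.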

## Architecture

![DRAEM Architecture](../../../docs/source/images/draem/architecture.png "DRAEM Architecture")

## Usage

`python tools/train.py --model draem`

## Benchmark

Benchmarking results are not yet available for this algorithm. Please check again later.
8 changes: 8 additions & 0 deletions anomalib/models/draem/__init__.py
@@ -0,0 +1,8 @@
"""DRAEM model."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import DraemLightning

__all__ = ["DraemLightning"]
102 changes: 102 additions & 0 deletions anomalib/models/draem/config.yaml
@@ -0,0 +1,102 @@
dataset:
  name: mvtec #options: [mvtec, btech, folder]
  format: mvtec
  path: ./datasets/MVTec
  category: bottle
  task: segmentation
  image_size: 256
  train_batch_size: 8
  test_batch_size: 32
  num_workers: 8
  transform_config:
    train: ./anomalib/models/draem/transform_config.yaml
    val: ./anomalib/models/draem/transform_config.yaml
  create_validation_set: false
  tiling:
    apply: false
    tile_size: null
    stride: null
    remove_border_count: 0
    use_random_tiling: False
    random_tile_count: 16

model:
  name: draem
  anomaly_source_path: null # optional, e.g. ./datasets/dtd
  lr: 0.0001
  early_stopping:
    patience: 50
    metric: pixel_AUROC
    mode: max
  normalization_method: min_max # options: [none, min_max, cdf]

metrics:
  image:
    - F1Score
    - AUROC
  pixel:
    - F1Score
    - AUROC
  threshold:
    image_default: 3
    pixel_default: 3
    adaptive: true

project:
  seed: 42
  path: ./results
  log_images_to: ["local"]
  logger: false # options: [tensorboard, wandb, csv] or combinations.

optimization:
  openvino:
    apply: false

# PL Trainer Args. Don't add extra parameter here.
trainer:
  accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
  accumulate_grad_batches: 1
  amp_backend: native
  auto_lr_find: false
  auto_scale_batch_size: false
  auto_select_gpus: false
  benchmark: false
  check_val_every_n_epoch: 1
  default_root_dir: null
  detect_anomaly: false
  deterministic: false
  devices: 1
  enable_checkpointing: true
  enable_model_summary: true
  enable_progress_bar: true
  fast_dev_run: false
  gpus: null # Set automatically
  gradient_clip_val: 0
  ipus: null
  limit_predict_batches: 1.0
  limit_test_batches: 1.0
  limit_train_batches: 1.0
  limit_val_batches: 1.0
  log_every_n_steps: 50
  log_gpu_memory: null
  max_epochs: 100
  max_steps: -1
  max_time: null
  min_epochs: null
  min_steps: null
  move_metrics_to_cpu: false
  multiple_trainloader_mode: max_size_cycle
  num_nodes: 1
  num_processes: null
  num_sanity_val_steps: 0
  overfit_batches: 0.0
  plugins: null
  precision: 32
  profiler: null
  reload_dataloaders_every_n_epochs: 0
  replace_sampler_ddp: true
  strategy: null
  sync_batchnorm: false
  tpu_cores: null
  track_grad_norm: -1
  val_check_interval: 1.0
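
The `model.early_stopping` block in this config (`patience`, `metric`, `mode: max`) drives the early-stopping callback. As a rough illustration of what those settings mean, here is a toy patience rule for a metric that should increase. `PatienceTracker` is a hypothetical name and this is not PyTorch Lightning's `EarlyStopping` implementation, which has more options (for example a minimum improvement threshold).

```python
class PatienceTracker:
    """Toy early-stopping rule for a metric that should increase (mode: max)."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("-inf")
        self.stale_epochs = 0  # consecutive epochs without a new best value

    def should_stop(self, value):
        """Record one validation result; return True once patience runs out."""
        if value > self.best:
            self.best = value
            self.stale_epochs = 0
        else:
            self.stale_epochs += 1
        return self.stale_epochs >= self.patience


tracker = PatienceTracker(patience=2)
history = [0.80, 0.85, 0.84, 0.85]  # e.g. pixel AUROC per validation epoch
decisions = [tracker.should_stop(value) for value in history]
print(decisions)
```

With `patience: 50` and `max_epochs: 100` as configured above, training would end early only after 50 consecutive validation epochs without a new best `pixel_AUROC`.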
108 changes: 108 additions & 0 deletions anomalib/models/draem/lightning_model.py
@@ -0,0 +1,108 @@
"""DRÆM – A discriminatively trained reconstruction embedding for surface anomaly detection.
Paper https://arxiv.org/abs/2108.07610
"""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Optional, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import MODEL_REGISTRY

from anomalib.models.components import AnomalyModule
from anomalib.models.draem.loss import DraemLoss
from anomalib.models.draem.torch_model import DraemModel
from anomalib.models.draem.utils import Augmenter

logger = logging.getLogger(__name__)

__all__ = ["Draem", "DraemLightning"]


@MODEL_REGISTRY
class Draem(AnomalyModule):
    """DRÆM: A discriminatively trained reconstruction embedding for surface anomaly detection.

    Args:
        anomaly_source_path (Optional[str]): Path to folder that contains the anomaly source images. Random noise will
            be used if left empty.
    """

    def __init__(self, anomaly_source_path: Optional[str] = None):
        super().__init__()

        self.augmenter = Augmenter(anomaly_source_path)
        self.model = DraemModel()
        self.loss = DraemLoss()

    def training_step(self, batch, _):  # pylint: disable=arguments-differ
        """Training Step of DRAEM.

        Feeds the simulated anomaly image through the network and computes the training loss
        against the original image and the simulated anomaly mask.

        Args:
            batch (Dict[str, Any]): Batch containing image filename, image, label and mask

        Returns:
            Loss dictionary
        """
        input_image = batch["image"]
        # Apply corruption to input image
        augmented_image, anomaly_mask = self.augmenter.augment_batch(input_image)
        # Generate model prediction
        reconstruction, prediction = self.model(augmented_image)
        # Compute loss
        loss = self.loss(input_image, reconstruction, anomaly_mask, prediction)
        return {"loss": loss}

    def validation_step(self, batch, _):
        """Validation step of DRAEM. The Softmax predictions of the anomalous class are used as anomaly map.

        Args:
            batch: Batch of input images

        Returns:
            Dictionary to which predicted anomaly maps have been added.
        """
        prediction = self.model(batch["image"])
        # Channel 1 holds the anomalous class
        batch["anomaly_maps"] = prediction[:, 1, :, :]
        return batch


class DraemLightning(Draem):
    """DRÆM: A discriminatively trained reconstruction embedding for surface anomaly detection.

    Args:
        hparams (Union[DictConfig, ListConfig]): Model parameters
    """

    def __init__(self, hparams: Union[DictConfig, ListConfig]):
        super().__init__(anomaly_source_path=hparams.model.anomaly_source_path)
        self.hparams: Union[DictConfig, ListConfig]  # type: ignore
        self.save_hyperparameters(hparams)

    def configure_callbacks(self):
        """Configure model-specific callbacks.

        Note:
            This method is used for the existing CLI.
            When PL CLI is introduced, configure callback method will be
            deprecated, and callbacks will be configured from either
            config.yaml file or from CLI.
        """
        early_stopping = EarlyStopping(
            monitor=self.hparams.model.early_stopping.metric,
            patience=self.hparams.model.early_stopping.patience,
            mode=self.hparams.model.early_stopping.mode,
        )
        return [early_stopping]

    def configure_optimizers(self):  # pylint: disable=arguments-differ
        """Configure the Adam optimizer."""
        return torch.optim.Adam(params=self.model.parameters(), lr=self.hparams.model.lr)
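
`validation_step` takes channel 1 of the model output as the anomaly map, which its docstring describes as the softmax prediction of the anomalous class. A minimal sketch of that per-pixel computation follows, assuming a two-channel output ordered as (normal, anomalous); the helper name and the toy logits are illustrative, and the softmax inside `DraemModel` is not shown in this commit view.

```python
import math


def anomaly_probability(logit_normal, logit_anomalous):
    """Two-class softmax, returning the probability of the anomalous class."""
    peak = max(logit_normal, logit_anomalous)  # subtract the max for numerical stability
    exp_normal = math.exp(logit_normal - peak)
    exp_anomalous = math.exp(logit_anomalous - peak)
    return exp_anomalous / (exp_normal + exp_anomalous)


# Per-pixel logits for a 2x2 image, stored as (normal, anomalous) pairs.
logits = [
    [(2.0, -1.0), (0.0, 0.0)],
    [(-1.0, 3.0), (1.0, 1.5)],
]
anomaly_map = [[anomaly_probability(n, a) for n, a in row] for row in logits]
print(anomaly_map)
```

Each entry lies between 0 and 1, so the map can be thresholded or normalized downstream (for example by the `min_max` normalization selected in `config.yaml`).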
29 changes: 29 additions & 0 deletions anomalib/models/draem/loss.py
@@ -0,0 +1,29 @@
"""Loss function for the DRAEM model implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from kornia.losses import FocalLoss, SSIMLoss
from torch import nn


class DraemLoss(nn.Module):
    """Overall loss function of the DRAEM model.

    The total loss consists of the sum of the L2 loss and the Structural Similarity loss between the reconstructed
    image and the input image, and the Focal loss between the predicted and GT anomaly masks.
    """

    def __init__(self):
        super().__init__()

        self.l2_loss = nn.modules.loss.MSELoss()
        self.focal_loss = FocalLoss(alpha=1, reduction="mean")
        self.ssim_loss = SSIMLoss(window_size=11)

    def forward(self, input_image, reconstruction, anomaly_mask, prediction):
        """Compute the loss over a batch for the DRAEM model."""
        # Reconstruction terms compare the reconstructed image with the clean input image
        l2_loss_val = self.l2_loss(reconstruction, input_image)
        # Segmentation term compares the predicted mask with the simulated anomaly mask
        focal_loss_val = self.focal_loss(prediction, anomaly_mask.squeeze(1).long())
        # The SSIM term is weighted by a factor of 2
        ssim_loss_val = self.ssim_loss(reconstruction, input_image) * 2
        return l2_loss_val + ssim_loss_val + focal_loss_val
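
To make the weighting in `DraemLoss.forward` concrete, here is a small stand-alone sketch of the focal term for a single pixel and of the final sum. It uses the standard focal loss formula with `alpha=1` as in the code above; `gamma=2.0` is an assumption (it is not set explicitly in this commit), and both function names are illustrative rather than part of kornia or anomalib.

```python
import math


def binary_focal_loss(p_anomalous, target, alpha=1.0, gamma=2.0):
    """Focal loss for one pixel, given the predicted anomaly probability."""
    # Probability assigned to the true class of this pixel.
    p_t = p_anomalous if target == 1 else 1.0 - p_anomalous
    # The (1 - p_t) ** gamma factor shrinks the loss of well-classified pixels.
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


def total_loss(l2, ssim, focal):
    """Combine the three terms with the same weights as DraemLoss.forward."""
    return l2 + 2 * ssim + focal


easy = binary_focal_loss(0.95, target=1)  # anomalous pixel, confidently detected
hard = binary_focal_loss(0.30, target=1)  # anomalous pixel, mostly missed
print(easy, hard, total_loss(0.1, 0.2, hard))
```

The confident prediction contributes far less than the missed one, which is the point of using a focal term when most pixels are normal and easy to classify.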