Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DRAEM Model #344

Merged
merged 22 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly Model
"""
model_list: List[str] = ["cflow", "dfkde", "dfm", "ganomaly", "padim", "patchcore", "stfpm"]
model_list: List[str] = ["cflow", "dfkde", "dfm", "draem", "ganomaly", "padim", "patchcore", "stfpm"]
model: AnomalyModule

if config.model.name in model_list:
Expand Down
29 changes: 29 additions & 0 deletions anomalib/models/draem/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
Copyright (c) 2022 Intel Corporation
SPDX-License-Identifier: Apache-2.0

Some files in this folder are based on the original DRAEM implementation by VitjanZ

Original license:
----------------

MIT License

Copyright (c) 2021 VitjanZ

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
21 changes: 21 additions & 0 deletions anomalib/models/draem/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# DRÆM – A discriminatively trained reconstruction embedding for surface anomaly detection

This is the implementation of the [DRAEM](https://arxiv.org/pdf/2108.07610v2.pdf) paper.

Model Type: Segmentation

## Description

DRAEM is a reconstruction based algorithm that consists of a reconstructive subnetwork and a discriminative subnetwork. DRAEM is trained on simulated anomaly images, generated by augmenting normal input images from the training set with a random Perlin noise mask extracted from an unrelated source of image data. The reconstructive subnetwork is an autoencoder architecture that is trained to reconstruct the original input images from the augmented images. The reconstructive submodel is trained using a combination of L2 loss and Structural Similarity loss. The input of the discriminative subnetwork consists of the channel-wise concatenation of the (augmented) input image and the output of the reconstructive subnetwork. The output of the discriminative subnetwork is an anomaly map that contains the predicted anomaly scores for each pixel location. The discriminative subnetwork is trained using Focal Loss.


## Architecture
![DRAEM Architecture](../../../docs/source/images/draem/architecture.png "DRAEM Architecture")

## Usage

`python tools/train.py --model draem`

## Benchmark

Benchmarking results are not yet available for this algorithm. Please check again later.
8 changes: 8 additions & 0 deletions anomalib/models/draem/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""DRAEM model."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import DraemLightning

__all__ = ["DraemLightning"]
147 changes: 147 additions & 0 deletions anomalib/models/draem/augmenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""Augmenter module to generates out-of-distribution samples for the DRAEM implementation."""
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

# Original Code
# Copyright (c) 2022 VitjanZ
# https://github.com/VitjanZ/DRAEM.
# SPDX-License-Identifier: MIT
#
# Modified
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import glob
import random
from typing import Optional, Tuple

import cv2
import imgaug.augmenters as iaa
import numpy as np
import torch
from torch import Tensor
from torchvision.datasets.folder import IMG_EXTENSIONS

from anomalib.models.draem.perlin import rand_perlin_2d_np


class Augmenter:
"""Class that generates noisy augmentations of input images.

Args:
anomaly_source_path (Optional[str]): Path to a folder of images that will be used as source of the anomalous
noise. If not specified, random noise will be used instead.
"""

def __init__(self, anomaly_source_path: Optional[str] = None):

self.anomaly_source_paths = []
if anomaly_source_path is not None:
for img_ext in IMG_EXTENSIONS:
self.anomaly_source_paths.extend(glob.glob(anomaly_source_path + "/**/*" + img_ext, recursive=True))

self.augmenters = [
iaa.GammaContrast((0.5, 2.0), per_channel=True),
iaa.MultiplyAndAddToBrightness(mul=(0.8, 1.2), add=(-30, 30)),
iaa.pillike.EnhanceSharpness(),
iaa.AddToHueAndSaturation((-50, 50), per_channel=True),
iaa.Solarize(0.5, threshold=(32, 128)),
iaa.Posterize(),
iaa.Invert(),
iaa.pillike.Autocontrast(),
iaa.pillike.Equalize(),
iaa.Affine(rotate=(-45, 45)),
]
self.rot = iaa.Sequential([iaa.Affine(rotate=(-90, 90))])
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

def rand_augmenter(self) -> iaa.Sequential:
"""Selects 3 random transforms that will be applied to the anomaly source images.

Returns:
A selection of 3 transforms.
"""
aug_ind = np.random.choice(np.arange(len(self.augmenters)), 3, replace=False)
aug = iaa.Sequential([self.augmenters[aug_ind[0]], self.augmenters[aug_ind[1]], self.augmenters[aug_ind[2]]])
return aug

def generate_perturbation(
self, height: int, width: int, anomaly_source_path: Optional[str]
) -> Tuple[np.ndarray, np.ndarray]:
"""Generate an image containing a random anomalous perturbation using a source image.

Args:
height (int): height of the generated image.
width: (int): width of the generated image.
anomaly_source_path (Optional[str]): Path to an image file. If not provided, random noise will be used
instead.

Returns:
Image containing a random anomalous perturbation, and the corresponding ground truth anomaly mask.
"""
# Generate random perlin noise
perlin_scale = 6
min_perlin_scale = 0

perlin_scalex = 2 ** random.randint(min_perlin_scale, perlin_scale)
perlin_scaley = 2 ** random.randint(min_perlin_scale, perlin_scale)

perlin_noise = rand_perlin_2d_np((height, width), (perlin_scalex, perlin_scaley))
perlin_noise = self.rot(image=perlin_noise)

# Create mask from perlin noise
mask = np.where(perlin_noise > 0.5, np.ones_like(perlin_noise), np.zeros_like(perlin_noise))
mask = np.expand_dims(mask, axis=2).astype(np.float32)

# Load anomaly source image
if anomaly_source_path:
anomaly_source_img = cv2.imread(anomaly_source_path)
anomaly_source_img = cv2.resize(anomaly_source_img, dsize=(width, height))
else: # if no anomaly source is specified, we use the perlin noise as anomalous source
anomaly_source_img = np.expand_dims(perlin_noise, 2).repeat(3, 2)
anomaly_source_img = (anomaly_source_img * 255).astype(np.uint8)

# Augment anomaly source image
aug = self.rand_augmenter()
anomaly_img_augmented = aug(image=anomaly_source_img)

# Create anomalous perturbation that we will apply to the image
perturbation = anomaly_img_augmented.astype(np.float32) * mask / 255.0
ashwinvaidya17 marked this conversation as resolved.
Show resolved Hide resolved

return perturbation, mask

def augment_batch(self, batch: Tensor) -> Tuple[Tensor, Tensor]:
"""Generate anomalous augmentations for a batch of input images.

Args:
batch (Tensor): Batch of input images

Returns:
- Augmented image to which anomalous perturbations have been added.
- Ground truth masks corresponding to the anomalous perturbations.
"""
batch_size, channels, height, width = batch.shape

# Collect perturbations
perturbations_list = []
masks_list = []
for _ in range(batch_size):
if random.random() > 0.5: # include 50% normal samples
perturbations_list.append(torch.zeros((channels, height, width)))
masks_list.append(torch.zeros((1, height, width)))
else:
anomaly_source_path = (
random.sample(self.anomaly_source_paths, 1)[0] if len(self.anomaly_source_paths) > 0 else None
)
perturbation, mask = self.generate_perturbation(height, width, anomaly_source_path)
perturbations_list.append(Tensor(perturbation).permute((2, 0, 1)))
masks_list.append(Tensor(mask).permute((2, 0, 1)))

perturbations = torch.stack(perturbations_list).to(batch.device)
masks = torch.stack(masks_list).to(batch.device)

# Apply perturbations batch wise
beta = torch.rand(batch_size) * 0.8
beta = beta.view(batch_size, 1, 1, 1).expand_as(batch).to(batch.device)

augmented_batch = batch * (1 - masks) + (1 - beta) * perturbations + beta * batch * (masks)

return augmented_batch, masks
102 changes: 102 additions & 0 deletions anomalib/models/draem/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
dataset:
name: mvtec #options: [mvtec, btech, folder]
format: mvtec
path: ./datasets/MVTec
category: bottle
task: segmentation
image_size: 256
train_batch_size: 8
test_batch_size: 32
num_workers: 8
transform_config:
train: ./anomalib/models/draem/transform_config.yaml
val: ./anomalib/models/draem/transform_config.yaml
create_validation_set: false
tiling:
apply: false
tile_size: null
stride: null
remove_border_count: 0
use_random_tiling: False
random_tile_count: 16

model:
name: draem
anomaly_source_path: null # optional, e.g. ./datasets/dtd
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
lr: 0.0001
early_stopping:
patience: 50
metric: pixel_AUROC
mode: max
normalization_method: min_max # options: [none, min_max, cdf]

metrics:
image:
- F1Score
- AUROC
pixel:
- F1Score
- AUROC
threshold:
image_default: 3
pixel_default: 3
adaptive: true

project:
seed: 42
path: ./results
log_images_to: ["local"]
logger: false # options: [tensorboard, wandb, csv] or combinations.

optimization:
openvino:
apply: false

# PL Trainer Args. Don't add extra parameter here.
trainer:
accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: false
auto_scale_batch_size: false
auto_select_gpus: false
benchmark: false
check_val_every_n_epoch: 1
default_root_dir: null
detect_anomaly: false
deterministic: false
devices: 1
enable_checkpointing: true
enable_model_summary: true
enable_progress_bar: true
fast_dev_run: false
gpus: null # Set automatically
gradient_clip_val: 0
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
log_gpu_memory: null
max_epochs: 100
max_steps: -1
max_time: null
min_epochs: null
min_steps: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: null
num_sanity_val_steps: 0
overfit_batches: 0.0
plugins: null
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: true
strategy: null
sync_batchnorm: false
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0
Loading