Skip to content

Commit

Permalink
➕ Add FastFlow Model (#336)
Browse files Browse the repository at this point in the history
* Add relative imports to freia __init__

* 🎬 Started fastflow implementation

* 🎬 Started fastflow implementation

* Create anomaly map generator for fastflow

* Add fastflow to the list of available trainable models

* Add trainer parameters to fastflow.config

* Modify FastflowModel and FastFlowLightningModule

* Log performance metrics to the progress bar

* Added configs for other backbones

* Added readme

* Modify lighning logger message

* Added architecture figure

* Added fastflow results

* fix typo in fastflow readme

* Silence mypy issues caused by different signature in training_step

* Added DeiT results

* Added fastflow tests

* Fix torchmetrics to v0.8

* Addressed PR comments.

* Set num_workers=8 for every algos

* Update readme.md and config file

* Remove leftover todo and update benchmark

* Update README files

* 🗑  Remove fastflow from nightly tests

* ➕ Add @gathierry to the third-party-programs.txt

* Add fastflow to inferencer tests
  • Loading branch information
samet-akcay committed Jun 7, 2022
1 parent f2cf458 commit 8adfc6c
Show file tree
Hide file tree
Showing 24 changed files with 736 additions and 19 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ python tools/train.py --model padim
where the currently available models are:

- [CFlow](anomalib/models/cflow)
- [DFM](anomalib/models/dfm)
- [DFKDE](anomalib/models/dfkde)
- [FastFlow](anomalib/models/fastflow)
- [PatchCore](anomalib/models/patchcore)
- [PADIM](anomalib/models/padim)
- [STFPM](anomalib/models/stfpm)
- [DFM](anomalib/models/dfm)
- [DFKDE](anomalib/models/dfkde)
- [GANomaly](anomalib/models/ganomaly)

### Custom Dataset
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly Model
"""
model_list: List[str] = ["cflow", "dfkde", "dfm", "ganomaly", "padim", "patchcore", "stfpm"]
model_list: List[str] = ["cflow", "dfkde", "dfm", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"]
model: AnomalyModule

if config.model.name in model_list:
Expand Down
6 changes: 4 additions & 2 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def _outputs_to_cpu(self, output):

def _log_metrics(self):
"""Log computed performance metrics."""
self.log_dict(self.image_metrics)
if self.pixel_metrics.update_called:
self.log_dict(self.pixel_metrics)
self.log_dict(self.pixel_metrics, prog_bar=True)
self.log_dict(self.image_metrics, prog_bar=False)
else:
self.log_dict(self.image_metrics, prog_bar=True)
5 changes: 5 additions & 0 deletions anomalib/models/components/freia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@
# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#

from .framework import SequenceINN
from .modules import AllInOneBlock

__all__ = ["SequenceINN", "AllInOneBlock"]
2 changes: 1 addition & 1 deletion anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
image_size: 256
train_batch_size: 32
test_batch_size: 32
num_workers: 36
num_workers: 8
transform_config:
train: null
val: null
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/dfm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
image_size: 256
train_batch_size: 32
test_batch_size: 32
num_workers: 36
num_workers: 8
transform_config:
train: null
val: null
Expand Down
124 changes: 124 additions & 0 deletions anomalib/models/fastflow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# FastFlow: Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows

This is the implementation of the [FastFlow](https://arxiv.org/abs/2111.07677) paper. This code is developed inspired by [https://github.com/gathierry/FastFlow](https://github.com/gathierry/FastFlow).

Model Type: Segmentation

## Description

FastFlow is a two-dimensional normalizing flow-based probability distribution estimator. It can be used as a plug-in module with any deep feature extractor, such as ResNet and vision transformer, for unsupervised anomaly detection and localisation. In the training phase, FastFlow learns to transform the input visual feature into a tractable distribution, and in the inference phase, it assesses the likelihood of identifying anomalies.

## Architecture

![FastFlow Architecture](../../../docs/source/images/fastflow/architecture.jpg "FastFlow Architecture")

## Usage

`python tools/train.py --model fastflow`

## Benchmark

All results gathered with seed `0`.

## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

---
**NOTE**

When the numbers are produced, early stopping callback (patience: 3) is used. It might be possible to achieve higher-metrics by increasing the patience.

---

### Image-Level AUC

| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 1.000 | 1.000 | 0.905 | 0.986 |
| Cable | 0.891 | 0.962 | 0.942 | 0.839 |
| Capsule | 0.900 | 0.963 | 0.819 | 0.913 |
| Carpet | 0.979 | 0.994 | 0.999 | 1.000 |
| Grid | 0.988 | 1.000 | 0.991 | 0.979 |
| Hazelnut | 0.846 | 0.994 | 0.900 | 0.948 |
| Leather | 1.000 | 0.999 | 0.999 | 0.991 |
| Metal_nut | 0.963 | 0.995 | 0.911 | 0.963 |
| Pill | 0.916 | 0.942 | 0.910 | 0.916 |
| Screw | 0.521 | 0.839 | 0.705 | 0.791 |
| Tile | 0.967 | 1.000 | 0.993 | 0.998 |
| Toothbrush | 0.844 | 0.836 | 0.850 | 0.886 |
| Transistor | 0.938 | 0.979 | 0.993 | 0.983 |
| Wood | 0.978 | 0.992 | 0.979 | 0.989 |
| Zipper | 0.878 | 0.951 | 0.981 | 0.977 |
| Average | | | | |


### Pixel-Level AUC

| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 0.983 | 0.986 | 0.991 | 0.984 |
| Cable | 0.954 | 0.972 | 0.973 | 0.981 |
| Capsule | 0.985 | 0.990 | 0.979 | 0.991 |
| Carpet | 0.983 | 0.991 | 0.991 | 0.992 |
| Grid | 0.985 | 0.992 | 0.980 | 0.979 |
| Hazelnut | 0.953 | 0.980 | 0.989 | 0.993 |
| Leather | 0.996 | 0.996 | 0.995 | 0.996 |
| Metal_nut | 0.972 | 0.988 | 0.978 | 0.973 |
| Pill | 0.972 | 0.976 | 0.985 | 0.992 |
| Screw | 0.926 | 0.966 | 0.945 | 0.979 |
| Tile | 0.944 | 0.966 | 0.951 | 0.960 |
| Toothbrush | 0.979 | 0.980 | 0.985 | 0.992 |
| Transistor | 0.964 | 0.971 | 0.949 | 0.960 |
| Wood | 0.956 | 0.941 | 0.952 | 0.954 |
| Zipper | 0.965 | 0.985 | 0.978 | 0.979 |
| Average | | | | |



### Image F1 Score
| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 0.976 | 0.952 | 0.741 | 0.977 |
| Cable | 0.851 | 0.918 | 0.848 | 0.835 |
| Capsule | 0.937 | 0.952 | 0.905 | 0.928 |
| Carpet | 0.955 | 0.983 | 0.994 | 0.973 |
| Grid | 0.941 | 0.974 | 0.982 | 0.948 |
| Hazelnut | 0.852 | 0.979 | 0.828 | 0.900 |
| Leather | 0.995 | 0.974 | 0.995 | 0.963 |
| Metal_nut | 0.925 | 0.969 | 0.899 | 0.916 |
| Pill | 0.946 | 0.949 | 0.949 | 0.616 |
| Screw | 0.853 | 0.893 | 0.868 | 0.979 |
| Tile | 0.947 | 0.994 | 0.976 | 0.994 |
| Toothbrush | 0.875 | 0.870 | 0.833 | 0.833 |
| Transistor | 0.779 | 0.854 | 0.873 | 0.909 |
| Wood | 0.983 | 0.968 | 0.944 | 0.967 |
| Zipper | 0.921 | 0.975 | 0.958 | 0.933 |
| Average | | | | |

### Pixel F1 Score
| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 0.670 | 0.733 | 0.753 | 0.725 |
| Cable | 0.547 | 0.564 | 0.487 | 0.608 |
| Capsule | 0.472 | 0.490 | 0.399 | 0.497 |
| Carpet | 0.573 | 0.598 | 0.586 | 0.606 |
| Grid | 0.412 | 0.481 | 0.393 | 0.410 |
| Hazelnut | 0.522 | 0.545 | 0.643 | 0.706 |
| Leather | 0.560 | 0.576 | 0.504 | 0.516 |
| Metal_nut | 0.728 | 0.754 | 0.766 | 0.737 |
| Pill | 0.589 | 0.611 | 0.709 | 0.617 |
| Screw | 0.061 | 0.660 | 0.269 | 0.370 |
| Tile | 0.569 | 0.660 | 0.655 | 0.660 |
| Toothbrush | 0.479 | 0.481 | 0.524 | 0.535 |
| Transistor | 0.558 | 0.573 | 0.527 | 0.567 |
| Wood | 0.557 | 0.488 | 0.614 | 0.572 |
| Zipper | 0.492 | 0.621 | 0.522 | 0.504 |
| Average | | | | |


### Sample Results

![Sample Result 1](../../../docs/source/images/fastflow/results/0.png "Sample Result 1")

![Sample Result 2](../../../docs/source/images/fastflow/results/1.png "Sample Result 2")

![Sample Result 3](../../../docs/source/images/fastflow/results/2.png "Sample Result 3")
10 changes: 10 additions & 0 deletions anomalib/models/fastflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""FastFlow Algorithm Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .lightning_model import Fastflow, FastflowLightning
from .torch_model import FastflowLoss, FastflowModel

__all__ = ["FastflowModel", "FastflowLoss", "FastflowLightning", "Fastflow"]
49 changes: 49 additions & 0 deletions anomalib/models/fastflow/anomaly_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""FastFlow Anomaly Map Generator Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor


class AnomalyMapGenerator:
"""Generate Anomaly Heatmap."""

def __init__(self, input_size: Union[ListConfig, Tuple]):
self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)

def __call__(self, hidden_variables: List[Tensor]) -> Tensor:
"""Generate Anomaly Heatmap.
This implementation generates the heatmap based on the flow maps
computed from the normalizing flow (NF) FastFlow blocks. Each block
yields a flow map, which overall is stacked and averaged to an anomaly
map.
Args:
hidden_variables (List[Tensor]): List of hidden variables from each NF FastFlow block.
Returns:
Tensor: Anomaly Map.
"""
flow_maps: List[Tensor] = []
for hidden_variable in hidden_variables:
log_prob = -torch.mean(hidden_variable**2, dim=1, keepdim=True) * 0.5
prob = torch.exp(log_prob)
flow_map = F.interpolate(
input=-prob,
size=self.input_size,
mode="bilinear",
align_corners=False,
)
flow_maps.append(flow_map)
flow_maps = torch.stack(flow_maps, dim=-1)
anomaly_map = torch.mean(flow_maps, dim=-1)

return anomaly_map
102 changes: 102 additions & 0 deletions anomalib/models/fastflow/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
dataset:
name: mvtec #options: [mvtec, btech, folder]
format: mvtec
path: ./datasets/MVTec
task: segmentation
category: bottle
image_size: 256 # options: [256, 256, 448, 384] - for each supported backbone
train_batch_size: 32
test_batch_size: 32
num_workers: 8
transform_config:
train: null
val: null
create_validation_set: false
tiling:
apply: false
tile_size: null
stride: null
remove_border_count: 0
use_random_tiling: False
random_tile_count: 16

model:
name: fastflow
backbone: resnet18 # options: [resnet18, wide_resnet50_2, cait_m48_448, deit_base_distilled_patch16_384]
flow_steps: 8 # options: [8, 8, 20, 20] - for each supported backbone
hidden_ratio: 1.0 # options: [1.0, 1.0, 0.16, 0.16] - for each supported backbone
conv3x3_only: True # options: [True, False, False, False] - for each supported backbone
lr: 0.001
weight_decay: 0.00001
early_stopping:
patience: 3
metric: pixel_AUROC
mode: max
normalization_method: min_max # options: [null, min_max, cdf]

metrics:
image:
- F1Score
- AUROC
pixel:
- F1Score
- AUROC
threshold:
image_default: 0
pixel_default: 0
adaptive: true

project:
seed: 0
path: ./results
log_images_to: [local]
logger: false # options: [tensorboard, wandb, csv] or combinations.

# PL Trainer Args. Don't add extra parameter here.
trainer:
accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: false
auto_scale_batch_size: false
auto_select_gpus: false
benchmark: false
check_val_every_n_epoch: 1 # Don't validate before extracting features.
default_root_dir: null
detect_anomaly: false
deterministic: false
devices: 1
enable_checkpointing: true
enable_model_summary: true
enable_progress_bar: true
fast_dev_run: false
gpus: null # Set automatically
gradient_clip_val: 0
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
log_gpu_memory: null
max_epochs: 500
max_steps: -1
max_time: null
min_epochs: null
min_steps: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: null
num_sanity_val_steps: 0
overfit_batches: 0.0
plugins: null
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: true
strategy: null
sync_batchnorm: false
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0 # Don't validate before extracting features.
Loading

0 comments on commit 8adfc6c

Please sign in to comment.