Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

➕ Add FastFlow Model #336

Merged
merged 30 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
838494c
Add relative imports to freia __init__
samet-akcay May 26, 2022
09f5d9e
🎬 Started fastflow implementation
samet-akcay May 26, 2022
89e599c
🎬 Started fastflow implementation
samet-akcay May 26, 2022
0723896
Create anomaly map generator for fastflow
samet-akcay May 28, 2022
52fcb8b
Add fastflow to the list of available trainable models
samet-akcay May 28, 2022
80ffcde
Add trainer parameters to fastflow.config
samet-akcay May 28, 2022
ca34609
Modify FastflowModel and FastFlowLightningModule
samet-akcay May 28, 2022
4546bdf
Log performance metrics to the progress bar
samet-akcay May 28, 2022
c3a21ce
Added configs for other backbones
samet-akcay May 28, 2022
545e29b
Added readme
samet-akcay May 29, 2022
77eda6d
Modify lighning logger message
samet-akcay May 29, 2022
d51516f
Added architecture figure
samet-akcay May 29, 2022
e43d160
Added fastflow results
samet-akcay May 29, 2022
98b1af7
fix typo in fastflow readme
samet-akcay May 29, 2022
57df26b
Silence mypy issues caused by different signature in training_step
samet-akcay May 31, 2022
1322150
Added DeiT results
samet-akcay May 31, 2022
65a2c43
Added fastflow tests
samet-akcay May 31, 2022
1e891a0
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay May 31, 2022
1f9677b
Fix torchmetrics to v0.8
samet-akcay Jun 1, 2022
dbd5f26
Addressed PR comments.
samet-akcay Jun 1, 2022
652c592
Resolve conflicts
samet-akcay Jun 1, 2022
b684657
Set num_workers=8 for every algos
samet-akcay Jun 1, 2022
8ad1d27
Update readme.md and config file
samet-akcay Jun 1, 2022
c53adeb
Merged development and resolved conflicts
samet-akcay Jun 1, 2022
5725366
Remove leftover todo and update benchmark
samet-akcay Jun 1, 2022
4ad4886
Update README files
samet-akcay Jun 1, 2022
fc0ffde
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Jun 3, 2022
920bb4a
🗑 Remove fastflow from nightly tests
samet-akcay Jun 3, 2022
a30f4e6
➕ Add @gathierry to the third-party-programs.txt
samet-akcay Jun 3, 2022
729e043
Add fastflow to inferencer tests
samet-akcay Jun 6, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,12 @@ python tools/train.py --model padim
where the currently available models are:

- [CFlow](anomalib/models/cflow)
- [DFM](anomalib/models/dfm)
- [DFKDE](anomalib/models/dfkde)
- [FastFlow](anomalib/models/fastflow)
- [PatchCore](anomalib/models/patchcore)
- [PADIM](anomalib/models/padim)
- [STFPM](anomalib/models/stfpm)
- [DFM](anomalib/models/dfm)
- [DFKDE](anomalib/models/dfkde)
- [GANomaly](anomalib/models/ganomaly)

### Custom Dataset
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly Model
"""
model_list: List[str] = ["cflow", "dfkde", "dfm", "ganomaly", "padim", "patchcore", "stfpm"]
model_list: List[str] = ["cflow", "dfkde", "dfm", "fastflow", "ganomaly", "padim", "patchcore", "stfpm"]
model: AnomalyModule

if config.model.name in model_list:
Expand Down
6 changes: 4 additions & 2 deletions anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def _outputs_to_cpu(self, output):

def _log_metrics(self):
"""Log computed performance metrics."""
self.log_dict(self.image_metrics)
if self.pixel_metrics.update_called:
self.log_dict(self.pixel_metrics)
self.log_dict(self.pixel_metrics, prog_bar=True)
self.log_dict(self.image_metrics, prog_bar=False)
else:
self.log_dict(self.image_metrics, prog_bar=True)
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions anomalib/models/components/freia/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@
# Copyright (c) 2018-2022 Lynton Ardizzone, Visual Learning Lab Heidelberg.
# SPDX-License-Identifier: MIT
#

from .framework import SequenceINN
from .modules import AllInOneBlock

__all__ = ["SequenceINN", "AllInOneBlock"]
2 changes: 1 addition & 1 deletion anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
image_size: 256
train_batch_size: 32
test_batch_size: 32
num_workers: 36
num_workers: 8
transform_config:
train: null
val: null
Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/dfm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dataset:
image_size: 256
train_batch_size: 32
test_batch_size: 32
num_workers: 36
num_workers: 8
transform_config:
train: null
val: null
Expand Down
124 changes: 124 additions & 0 deletions anomalib/models/fastflow/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# FastFlow: Unsupervised Anomaly Detection and Localization via 2D Normalizing Flows

This is the implementation of the [FastFlow](https://arxiv.org/abs/2111.07677) paper. This code is developed inspired by [https://github.com/gathierry/FastFlow](https://github.com/gathierry/FastFlow).

Model Type: Segmentation

## Description

FastFlow is a two-dimensional normalizing flow-based probability distribution estimator. It can be used as a plug-in module with any deep feature extractor, such as ResNet and vision transformer, for unsupervised anomaly detection and localisation. In the training phase, FastFlow learns to transform the input visual feature into a tractable distribution, and in the inference phase, it assesses the likelihood of identifying anomalies.

## Architecture

![FastFlow Architecture](../../../docs/source/images/fastflow/architecture.jpg "FastFlow Architecture")

## Usage

`python tools/train.py --model fastflow`

## Benchmark

All results gathered with seed `0`.

## [MVTec AD Dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad)

---
**NOTE**

When the numbers are produced, early stopping callback (patience: 3) is used. It might be possible to achieve higher-metrics by increasing the patience.

---

### Image-Level AUC

| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 1.000 | 1.000 | 0.905 | 0.986 |
| Cable | 0.891 | 0.962 | 0.942 | 0.839 |
| Capsule | 0.900 | 0.963 | 0.819 | 0.913 |
| Carpet | 0.979 | 0.994 | 0.999 | 1.000 |
| Grid | 0.988 | 1.000 | 0.991 | 0.979 |
| Hazelnut | 0.846 | 0.994 | 0.900 | 0.948 |
| Leather | 1.000 | 0.999 | 0.999 | 0.991 |
| Metal_nut | 0.963 | 0.995 | 0.911 | 0.963 |
| Pill | 0.916 | 0.942 | 0.910 | 0.916 |
| Screw | 0.521 | 0.839 | 0.705 | 0.791 |
| Tile | 0.967 | 1.000 | 0.993 | 0.998 |
| Toothbrush | 0.844 | 0.836 | 0.850 | 0.886 |
| Transistor | 0.938 | 0.979 | 0.993 | 0.983 |
| Wood | 0.978 | 0.992 | 0.979 | 0.989 |
| Zipper | 0.878 | 0.951 | 0.981 | 0.977 |
| Average | | | | |
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the average score? Also, a minor issue but this format is a bit different from our other models (rows <-> columns). Maybe for consistency we should decide on a single format. Maybe this can be addressed in a different PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some categories that I need to run again to exactly produce the numbers, which is the reason why I left it blank. As agreed, I could add these later.

Regarding the table format, I would prefer this because 15 MVTec categories doesn't fit to the screen when they are placed into columns.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could perhaps modify all the tables once we have all the benchmark results merged



### Pixel-Level AUC

| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 0.983 | 0.986 | 0.991 | 0.984 |
| Cable | 0.954 | 0.972 | 0.973 | 0.981 |
| Capsule | 0.985 | 0.990 | 0.979 | 0.991 |
| Carpet | 0.983 | 0.991 | 0.991 | 0.992 |
| Grid | 0.985 | 0.992 | 0.980 | 0.979 |
| Hazelnut | 0.953 | 0.980 | 0.989 | 0.993 |
| Leather | 0.996 | 0.996 | 0.995 | 0.996 |
| Metal_nut | 0.972 | 0.988 | 0.978 | 0.973 |
| Pill | 0.972 | 0.976 | 0.985 | 0.992 |
| Screw | 0.926 | 0.966 | 0.945 | 0.979 |
| Tile | 0.944 | 0.966 | 0.951 | 0.960 |
| Toothbrush | 0.979 | 0.980 | 0.985 | 0.992 |
| Transistor | 0.964 | 0.971 | 0.949 | 0.960 |
| Wood | 0.956 | 0.941 | 0.952 | 0.954 |
| Zipper | 0.965 | 0.985 | 0.978 | 0.979 |
| Average | | | | |



### Image F1 Score
| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 0.976 | 0.952 | 0.741 | 0.977 |
| Cable | 0.851 | 0.918 | 0.848 | 0.835 |
| Capsule | 0.937 | 0.952 | 0.905 | 0.928 |
| Carpet | 0.955 | 0.983 | 0.994 | 0.973 |
| Grid | 0.941 | 0.974 | 0.982 | 0.948 |
| Hazelnut | 0.852 | 0.979 | 0.828 | 0.900 |
| Leather | 0.995 | 0.974 | 0.995 | 0.963 |
| Metal_nut | 0.925 | 0.969 | 0.899 | 0.916 |
| Pill | 0.946 | 0.949 | 0.949 | 0.616 |
| Screw | 0.853 | 0.893 | 0.868 | 0.979 |
| Tile | 0.947 | 0.994 | 0.976 | 0.994 |
| Toothbrush | 0.875 | 0.870 | 0.833 | 0.833 |
| Transistor | 0.779 | 0.854 | 0.873 | 0.909 |
| Wood | 0.983 | 0.968 | 0.944 | 0.967 |
| Zipper | 0.921 | 0.975 | 0.958 | 0.933 |
| Average | | | | |

### Pixel F1 Score
| | ResNet-18 | Wide ResNet50 | DeiT | CaiT |
| ---------- | :-------: | :-----------: | :---: | :---: |
| Bottle | 0.670 | 0.733 | 0.753 | 0.725 |
| Cable | 0.547 | 0.564 | 0.487 | 0.608 |
| Capsule | 0.472 | 0.490 | 0.399 | 0.497 |
| Carpet | 0.573 | 0.598 | 0.586 | 0.606 |
| Grid | 0.412 | 0.481 | 0.393 | 0.410 |
| Hazelnut | 0.522 | 0.545 | 0.643 | 0.706 |
| Leather | 0.560 | 0.576 | 0.504 | 0.516 |
| Metal_nut | 0.728 | 0.754 | 0.766 | 0.737 |
| Pill | 0.589 | 0.611 | 0.709 | 0.617 |
| Screw | 0.061 | 0.660 | 0.269 | 0.370 |
| Tile | 0.569 | 0.660 | 0.655 | 0.660 |
| Toothbrush | 0.479 | 0.481 | 0.524 | 0.535 |
| Transistor | 0.558 | 0.573 | 0.527 | 0.567 |
| Wood | 0.557 | 0.488 | 0.614 | 0.572 |
| Zipper | 0.492 | 0.621 | 0.522 | 0.504 |
| Average | | | | |


### Sample Results

![Sample Result 1](../../../docs/source/images/fastflow/results/0.png "Sample Result 1")

![Sample Result 2](../../../docs/source/images/fastflow/results/1.png "Sample Result 2")

![Sample Result 3](../../../docs/source/images/fastflow/results/2.png "Sample Result 3")
10 changes: 10 additions & 0 deletions anomalib/models/fastflow/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""FastFlow Algorithm Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from .lightning_model import Fastflow, FastflowLightning
from .torch_model import FastflowLoss, FastflowModel

__all__ = ["FastflowModel", "FastflowLoss", "FastflowLightning", "Fastflow"]
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
49 changes: 49 additions & 0 deletions anomalib/models/fastflow/anomaly_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""FastFlow Anomaly Map Generator Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import List, Tuple, Union

import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor


class AnomalyMapGenerator:
"""Generate Anomaly Heatmap."""

def __init__(self, input_size: Union[ListConfig, Tuple]):
self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)

def __call__(self, hidden_variables: List[Tensor]) -> Tensor:
"""Generate Anomaly Heatmap.

This implementation generates the heatmap based on the flow maps
computed from the normalizing flow (NF) FastFlow blocks. Each block
yields a flow map, which overall is stacked and averaged to an anomaly
map.

Args:
hidden_variables (List[Tensor]): List of hidden variables from each NF FastFlow block.

Returns:
Tensor: Anomaly Map.
"""
flow_maps: List[Tensor] = []
for hidden_variable in hidden_variables:
log_prob = -torch.mean(hidden_variable**2, dim=1, keepdim=True) * 0.5
prob = torch.exp(log_prob)
flow_map = F.interpolate(
input=-prob,
size=self.input_size,
mode="bilinear",
align_corners=False,
)
flow_maps.append(flow_map)
flow_maps = torch.stack(flow_maps, dim=-1)
anomaly_map = torch.mean(flow_maps, dim=-1)

return anomaly_map
102 changes: 102 additions & 0 deletions anomalib/models/fastflow/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
dataset:
name: mvtec #options: [mvtec, btech, folder]
format: mvtec
path: ./datasets/MVTec
task: segmentation
category: bottle
image_size: 256 # options: [256, 256, 448, 384] - for each supported backbone
train_batch_size: 32
test_batch_size: 32
num_workers: 8
transform_config:
train: null
val: null
create_validation_set: false
tiling:
apply: false
tile_size: null
stride: null
remove_border_count: 0
use_random_tiling: False
random_tile_count: 16

model:
name: fastflow
backbone: resnet18 # options: [resnet18, wide_resnet50_2, cait_m48_448, deit_base_distilled_patch16_384]
flow_steps: 8 # options: [8, 8, 20, 20] - for each supported backbone
hidden_ratio: 1.0 # options: [1.0, 1.0, 0.16, 0.16] - for each supported backbone
conv3x3_only: True # options: [True, False, False, False] - for each supported backbone
lr: 0.001
weight_decay: 0.00001
early_stopping:
patience: 3
metric: pixel_AUROC
mode: max
normalization_method: min_max # options: [null, min_max, cdf]

metrics:
image:
- F1Score
- AUROC
pixel:
- F1Score
- AUROC
threshold:
image_default: 0
pixel_default: 0
adaptive: true

project:
seed: 0
path: ./results
log_images_to: [local]
logger: false # options: [tensorboard, wandb, csv] or combinations.

# PL Trainer Args. Don't add extra parameter here.
trainer:
accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
accumulate_grad_batches: 1
amp_backend: native
auto_lr_find: false
auto_scale_batch_size: false
auto_select_gpus: false
benchmark: false
check_val_every_n_epoch: 1 # Don't validate before extracting features.
default_root_dir: null
detect_anomaly: false
deterministic: false
devices: 1
enable_checkpointing: true
enable_model_summary: true
enable_progress_bar: true
fast_dev_run: false
gpus: null # Set automatically
gradient_clip_val: 0
ipus: null
limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 50
log_gpu_memory: null
max_epochs: 500
max_steps: -1
max_time: null
min_epochs: null
min_steps: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
num_nodes: 1
num_processes: null
num_sanity_val_steps: 0
overfit_batches: 0.0
plugins: null
precision: 32
profiler: null
reload_dataloaders_every_n_epochs: 0
replace_sampler_ddp: true
strategy: null
sync_batchnorm: false
tpu_cores: null
track_grad_norm: -1
val_check_interval: 1.0 # Don't validate before extracting features.
Loading