Create Anomalib CLI #378

Merged Jun 24, 2022 · 58 commits

Commits
d76acf6
Refactored MVTec datamodule
samet-akcay Jun 8, 2022
3b0dd3c
🏷️ Rename BTech datamodule
samet-akcay Jun 8, 2022
c376215
🏷️ Rename Folder datamodule
samet-akcay Jun 8, 2022
7cec5c7
Create datamodule jupyter notebook
samet-akcay Jun 9, 2022
0537b06
Added mvtec into jupyter notebook
samet-akcay Jun 9, 2022
39d7de7
Apply black formatter
samet-akcay Jun 9, 2022
e5d09bb
Finished MVTec
samet-akcay Jun 9, 2022
422fa1c
🏷 Rename the folder
samet-akcay Jun 9, 2022
10b93f0
🏷 Renamed anomaly-datamodule.ipynb to mvtec.ipynb
samet-akcay Jun 9, 2022
599bc46
➕ Created BTech notebook
samet-akcay Jun 9, 2022
f8dc4a7
🚚 Move the main description from mvtec to README.md
samet-akcay Jun 9, 2022
ae1b081
Polish btech jupyter notebook
samet-akcay Jun 9, 2022
d71ce81
➕ Created folder jupyter notebook
samet-akcay Jun 9, 2022
a57f74e
Merge branch 'development' into nb/sa/add-dataset-module-notebook
samet-akcay Jun 9, 2022
4f0688c
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Jun 13, 2022
9986ffc
Address PR comments.
samet-akcay Jun 13, 2022
4d18d82
Merge branch 'nb/sa/add-dataset-module-notebook' of github.com:openvi…
samet-akcay Jun 13, 2022
68a9354
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Jun 14, 2022
a6885cd
Format the notebooks
samet-akcay Jun 14, 2022
02fd4fd
Added black, isort, flake8 and pylint
samet-akcay Jun 14, 2022
a17f58b
➕ Add mdformat to dev requirements
samet-akcay Jun 14, 2022
b1b1c3d
🛠 Update the notebook markdowns
samet-akcay Jun 14, 2022
3b4e46e
Configured pre-commit for notebooks
samet-akcay Jun 14, 2022
f7274ec
➕ Add DATAMODULE_REGISTRY to the datamodules.
samet-akcay Jun 15, 2022
196fcbd
➕ Added relative imports to __init__ modules.
samet-akcay Jun 15, 2022
24f83c5
Register callbacks via @CALLBACK_REGISTRY
samet-akcay Jun 15, 2022
6c5285c
➕ Added missing callback imports
samet-akcay Jun 15, 2022
4bf0ae4
➕ Add AnomalibCLI
samet-akcay Jun 15, 2022
fe9b6a1
➕ Add AnomalibCLI
samet-akcay Jun 15, 2022
a2bad29
➕ Add `normalization_method` to `MetricsConfigurationCallback`
samet-akcay Jun 15, 2022
ca3046c
➕ Add padim config file
samet-akcay Jun 15, 2022
09234f2
🛠 Fix pytorch-lightning version to add extra
samet-akcay Jun 15, 2022
57d8322
Update torchmetrics version requirement
samet-akcay Jun 16, 2022
9ca510d
Add anomalib CLI entrypoint
samet-akcay Jun 16, 2022
e4c0ef9
Remove --save_images flag from the cli
samet-akcay Jun 16, 2022
f023f6c
Refactor trainer.py and create main() function
samet-akcay Jun 16, 2022
8be525c
Updated anomalib entrypoint. Removed trainer.py from cli
samet-akcay Jun 16, 2022
afc8c72
Add logger to cli
samet-akcay Jun 16, 2022
e827f7d
➕ Add patchcore config
samet-akcay Jun 16, 2022
5594557
➕ Add Cflow CLI config
samet-akcay Jun 16, 2022
42a389d
🚚 Move padim and patchcore configs to config directory
samet-akcay Jun 16, 2022
790f588
Add cflow, dfkde, dfm, draem, fastflow, padim and patchcore
samet-akcay Jun 16, 2022
1bd62a2
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Jun 16, 2022
4653b54
➕ Add ganomaly config
samet-akcay Jun 16, 2022
39ab361
add reverse distillation and stfpm configs
samet-akcay Jun 16, 2022
063a394
Resolved merge conflicts
samet-akcay Jun 17, 2022
6bc525c
Update notebooks
samet-akcay Jun 17, 2022
90af89d
Merge branch 'development' of github.com:openvinotoolkit/anomalib int…
samet-akcay Jun 17, 2022
487f8a5
Create the project directory only during training.
samet-akcay Jun 20, 2022
8cbf682
🏷 Renamed config directory to configs
samet-akcay Jun 21, 2022
b606eed
Fix the project directory creation logic.
samet-akcay Jun 21, 2022
d140396
Add the training CLI command to the README.md
samet-akcay Jun 21, 2022
bc304e5
🛠 Fix incorrect statement.
samet-akcay Jun 21, 2022
13cfde1
🧹 Cleanup
samet-akcay Jun 22, 2022
ba19444
Fix v.4.0 to v.0.4.0
samet-akcay Jun 23, 2022
67ada2c
Added more description to AnomalibCLI docstring.
samet-akcay Jun 23, 2022
0a37c2c
Added TODO comments to address later.
samet-akcay Jun 23, 2022
6bb7f41
Extracted methods to simplify `before_instantiate_classes` method
samet-akcay Jun 23, 2022
17 changes: 16 additions & 1 deletion README.md
@@ -68,6 +68,7 @@ pip install -e .
```

## Training
### ⚠️ Anomalib < v.0.4.0

By default [`python tools/train.py`](https://gitlab-icv.inn.intel.com/algo_rnd_team/anomaly/-/blob/development/train.py)
runs [PADIM](https://arxiv.org/abs/2011.08785) model on `leather` category from the [MVTec AD](https://www.mvtec.com/company/research/datasets/mvtec-ad) [(CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/) dataset.
@@ -141,8 +142,22 @@ dataset:
use_random_tiling: False
random_tile_count: 16
```

### ⚠️ Anomalib > v.0.4.0 Beta - Subject to Change
We introduce a new CLI approach based on the [PyTorch Lightning CLI](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html). To train a model with the new CLI, run the following command:
```bash
anomalib fit --config <path/to/new/config/file>
```

For instance, to train a [PatchCore](https://github.com/openvinotoolkit/anomalib/tree/development/anomalib/models/patchcore) model, run:
```bash
anomalib fit --config ./configs/model/patchcore.yaml
```

The new CLI approach offers significantly more flexibility; the details are explained in the [documentation](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html).
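
Since the entrypoint is built on `LightningCLI`, individual settings from the config file can also be overridden directly on the command line. A minimal sketch, assuming Lightning's standard argument paths (the exact keys depend on the model and datamodule signatures):

```bash
# Argument paths below are assumptions based on Lightning's conventions;
# run `anomalib fit --help` to see the options that actually exist.
anomalib fit --config ./configs/model/padim.yaml --trainer.max_epochs 1 --data.category bottle
```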

## Inference
### ⚠️ Anomalib < v.0.4.0
Anomalib provides several tools for performing inference with a trained model. The script in [`tools/inference`](tools/inference.py) shows how these tools can be used to generate a prediction for an input image.

If the specified weight path points to a PyTorch Lightning checkpoint file (`.ckpt`), inference will run in PyTorch. If the path points to an ONNX graph (`.onnx`) or OpenVINO IR (`.bin` or `.xml`), inference will run in OpenVINO.
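
A typical invocation would look something like the sketch below. The flag names and paths here are illustrative assumptions rather than the script's verified interface, so consult `python tools/inference.py --help` for the actual arguments:

```bash
# Illustrative only -- flag names and paths are assumed, not taken from this PR.
python tools/inference.py \
    --config anomalib/models/padim/config.yaml \
    --weight_path results/padim/mvtec/leather/weights/model.ckpt \
    --image_path datasets/MVTec/leather/test/color/000.png
```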
2 changes: 2 additions & 0 deletions anomalib/data/btech.py
@@ -33,6 +33,7 @@
import pandas as pd
from pandas.core.frame import DataFrame
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch import Tensor
from torch.utils.data import DataLoader
@@ -257,6 +258,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
return item


@DATAMODULE_REGISTRY
class BTech(LightningDataModule):
"""BTechDataModule Lightning Data Module."""

2 changes: 2 additions & 0 deletions anomalib/data/folder.py
@@ -27,6 +27,7 @@
import numpy as np
from pandas.core.frame import DataFrame
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
@@ -301,6 +302,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
return item


@DATAMODULE_REGISTRY
class Folder(LightningDataModule):
"""Folder Lightning Data Module."""

2 changes: 2 additions & 0 deletions anomalib/data/mvtec.py
@@ -50,6 +50,7 @@
import pandas as pd
from pandas.core.frame import DataFrame
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch import Tensor
from torch.utils.data import DataLoader
@@ -280,6 +281,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
return item


@DATAMODULE_REGISTRY
class MVTec(LightningDataModule):
"""MVTec AD Lightning Data Module."""

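The `@DATAMODULE_REGISTRY` decorator added to `BTech`, `Folder`, and `MVTec` above registers each datamodule with Lightning's CLI registry, so the class can be referenced by name from a CLI config instead of being wired up manually. A minimal sketch of the mechanism using a toy datamodule (not anomalib code):

```python
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY


@DATAMODULE_REGISTRY
class ToyData(LightningDataModule):
    """Toy datamodule; the decorator adds it to the registry at import time."""

    def __init__(self, batch_size: int = 32) -> None:
        super().__init__()
        self.batch_size = batch_size


# The registry maps class names to classes, which is how the CLI resolves
# a datamodule referenced in a YAML config.
print(DATAMODULE_REGISTRY)  # includes "ToyData" after the definition above
```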
23 changes: 23 additions & 0 deletions anomalib/models/__init__.py
@@ -22,7 +22,30 @@
from omegaconf import DictConfig, ListConfig
from torch import load

from anomalib.models.cflow import Cflow
from anomalib.models.components import AnomalyModule
from anomalib.models.dfkde import Dfkde
from anomalib.models.dfm import Dfm
from anomalib.models.draem import Draem
from anomalib.models.fastflow import Fastflow
from anomalib.models.ganomaly import Ganomaly
from anomalib.models.padim import Padim
from anomalib.models.patchcore import Patchcore
from anomalib.models.reverse_distillation import ReverseDistillation
from anomalib.models.stfpm import Stfpm

__all__ = [
"Cflow",
"Dfkde",
"Dfm",
"Draem",
"Fastflow",
"Ganomaly",
"Padim",
"Patchcore",
"ReverseDistillation",
"Stfpm",
]

logger = logging.getLogger(__name__)

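These re-exports make the model classes importable directly from `anomalib.models`, which is what lets the CLI resolve a model by name. A quick usage sketch; the constructor arguments are illustrative assumptions, so check each class signature before relying on them:

```python
from anomalib.models import Padim

# Argument names and values are assumptions for illustration only.
model = Padim(input_size=(256, 256), backbone="resnet18", layers=["layer1", "layer2", "layer3"])
```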
4 changes: 2 additions & 2 deletions anomalib/models/cflow/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import CflowLightning
from .lightning_model import Cflow, CflowLightning

__all__ = ["CflowLightning"]
__all__ = ["Cflow", "CflowLightning"]
48 changes: 26 additions & 22 deletions anomalib/models/cflow/lightning_model.py
@@ -49,6 +49,7 @@ def __init__(
coupling_blocks: int = 8,
clamp_alpha: float = 1.9,
permute_soft: bool = False,
lr: float = 0.0001,
):
super().__init__()

@@ -64,6 +65,31 @@
permute_soft=permute_soft,
)
self.automatic_optimization = False
# TODO: LR should be part of optimizer in config.yaml! Since cflow has custom
# optimizer this is to be addressed later.
self.learning_rate = lr

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures optimizers for each decoder.

Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure optimizers method will be
deprecated, and optimizers will be configured from either
config.yaml file or from CLI.

Returns:
Optimizer: Adam optimizer for each decoder
"""
decoders_parameters = []
for decoder_idx in range(len(self.model.pool_layers)):
decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))

optimizer = optim.Adam(
params=decoders_parameters,
lr=self.learning_rate,
)
return optimizer

def training_step(self, batch, _): # pylint: disable=arguments-differ
"""Training Step of CFLOW.
@@ -193,25 +219,3 @@ def configure_callbacks(self):
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures optimizers for each decoder.

Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure optimizers method will be
deprecated, and optimizers will be configured from either
config.yaml file or from CLI.

Returns:
Optimizer: Adam optimizer for each decoder
"""
decoders_parameters = []
for decoder_idx in range(len(self.model.pool_layers)):
decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))

optimizer = optim.Adam(
params=decoders_parameters,
lr=self.hparams.model.lr,
)
return optimizer
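The TODO above records the intended end state: once optimizers are configured through the CLI, the learning rate would move out of the model and into the optimizer section of the YAML config. Under Lightning's conventions that would look roughly like this hypothetical snippet (not a file from this PR):

```yaml
# Hypothetical future config -- illustrates Lightning's optimizer syntax only.
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 0.0001
```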
4 changes: 2 additions & 2 deletions anomalib/models/dfkde/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import DfkdeLightning
from .lightning_model import Dfkde, DfkdeLightning

__all__ = ["DfkdeLightning"]
__all__ = ["Dfkde", "DfkdeLightning"]
4 changes: 2 additions & 2 deletions anomalib/models/dfm/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import DfmLightning
from .lightning_model import Dfm, DfmLightning

__all__ = ["DfmLightning"]
__all__ = ["Dfm", "DfmLightning"]
4 changes: 2 additions & 2 deletions anomalib/models/draem/__init__.py
@@ -3,6 +3,6 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import DraemLightning
from .lightning_model import Draem, DraemLightning

__all__ = ["DraemLightning"]
__all__ = ["Draem", "DraemLightning"]
2 changes: 1 addition & 1 deletion anomalib/models/fastflow/__init__.py
@@ -8,4 +8,4 @@
from .loss import FastflowLoss
from .torch_model import FastflowModel

__all__ = ["FastflowModel", "FastflowLoss", "FastflowLightning", "Fastflow"]
__all__ = ["FastflowModel", "FastflowLoss", "Fastflow", "FastflowLightning"]
4 changes: 2 additions & 2 deletions anomalib/models/ganomaly/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import GanomalyLightning
from .lightning_model import Ganomaly, GanomalyLightning

__all__ = ["GanomalyLightning"]
__all__ = ["Ganomaly", "GanomalyLightning"]
60 changes: 36 additions & 24 deletions anomalib/models/ganomaly/lightning_model.py
@@ -61,6 +61,9 @@ def __init__(
wadv: int = 1,
wcon: int = 50,
wenc: int = 1,
lr: float = 0.0002,
beta1: float = 0.5,
beta2: float = 0.999,
):
super().__init__()

@@ -82,11 +85,41 @@
self.generator_loss = GeneratorLoss(wadv, wcon, wenc)
self.discriminator_loss = DiscriminatorLoss()

# TODO: LR should be part of optimizer in config.yaml! Since ganomaly has custom
# optimizer this is to be addressed later.
self.learning_rate = lr
self.beta1 = beta1
self.beta2 = beta2

def _reset_min_max(self):
"""Resets min_max scores."""
self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable

def configure_optimizers(self) -> List[optim.Optimizer]:
"""Configures optimizers for each decoder.

Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure optimizers method will be
deprecated, and optimizers will be configured from either
config.yaml file or from CLI.

Returns:
Optimizer: Adam optimizer for each decoder
"""
optimizer_d = optim.Adam(
self.model.discriminator.parameters(),
lr=self.learning_rate,
betas=(self.beta1, self.beta2),
)
optimizer_g = optim.Adam(
self.model.generator.parameters(),
lr=self.learning_rate,
betas=(self.beta1, self.beta2),
)
return [optimizer_d, optimizer_g]

def training_step(self, batch, _, optimizer_idx): # pylint: disable=arguments-differ
"""Training step.

@@ -191,6 +224,9 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]) -> None:
wadv=hparams.model.wadv,
wcon=hparams.model.wcon,
wenc=hparams.model.wenc,
lr=hparams.model.lr,
beta1=hparams.model.beta1,
beta2=hparams.model.beta2,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)
@@ -210,27 +246,3 @@ def configure_callbacks(self):
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]

def configure_optimizers(self) -> List[optim.Optimizer]:
"""Configures optimizers for each decoder.

Note:
This method is used for the existing CLI.
When PL CLI is introduced, configure optimizers method will be
deprecated, and optimizers will be configured from either
config.yaml file or from CLI.

Returns:
Optimizer: Adam optimizer for each decoder
"""
optimizer_d = optim.Adam(
self.model.discriminator.parameters(),
lr=self.hparams.model.lr,
betas=(self.hparams.model.beta1, self.hparams.model.beta2),
)
optimizer_g = optim.Adam(
self.model.generator.parameters(),
lr=self.hparams.model.lr,
betas=(self.hparams.model.beta1, self.hparams.model.beta2),
)
return [optimizer_d, optimizer_g]
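Because `configure_optimizers` returns two optimizers (discriminator first, then generator), Lightning invokes `training_step` once per optimizer and passes the matching `optimizer_idx`. A minimal standalone sketch of that dispatch pattern (toy code, not the anomalib implementation):

```python
import torch
from pytorch_lightning import LightningModule


class ToyGAN(LightningModule):
    """Toy module showing Lightning's multi-optimizer dispatch (not anomalib code)."""

    def __init__(self) -> None:
        super().__init__()
        self.generator = torch.nn.Linear(8, 8)
        self.discriminator = torch.nn.Linear(8, 1)

    def configure_optimizers(self):
        # Order matters: index 0 -> discriminator, index 1 -> generator.
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
        return [opt_d, opt_g]

    def training_step(self, batch, batch_idx, optimizer_idx):
        # Lightning calls this once per optimizer and passes the matching index.
        if optimizer_idx == 0:
            return self.discriminator(batch).mean()  # stand-in for the real D loss
        return self.generator(batch).sum()  # stand-in for the real G loss
```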
4 changes: 2 additions & 2 deletions anomalib/models/padim/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import PadimLightning
from .lightning_model import Padim, PadimLightning

__all__ = ["PadimLightning"]
__all__ = ["Padim", "PadimLightning"]
4 changes: 2 additions & 2 deletions anomalib/models/patchcore/__init__.py
@@ -14,6 +14,6 @@
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import PatchcoreLightning
from .lightning_model import Patchcore, PatchcoreLightning

__all__ = ["PatchcoreLightning"]
__all__ = ["Patchcore", "PatchcoreLightning"]
4 changes: 2 additions & 2 deletions anomalib/models/reverse_distillation/__init__.py
@@ -3,6 +3,6 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import ReverseDistillationLightning
from .lightning_model import ReverseDistillation, ReverseDistillationLightning

__all__ = ["ReverseDistillationLightning"]
__all__ = ["ReverseDistillation", "ReverseDistillationLightning"]