Skip to content

Commit

Permalink
[Refactor] strings and ints into enum.Enum (#1044)
Browse files Browse the repository at this point in the history
[Refactor: 644] Refactor strings and ints into enum.Enum

This tries to refactor strings and ints into enum.Enum adding:
1. LabelName
2. DirType
3. DataFormat
4. AnomalyMapGenerationMode
5. ImageUpscaleMode
6. VisualizationMode

Fixes #644

Signed-off-by: FanJiangIntel <fan.jiang@intel.com>
Signed-off-by: Kang Wenjing <wenjing.kang@intel.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
WenjingKangIntel and samet-akcay committed Apr 28, 2023
1 parent 97b885f commit 2661177
Show file tree
Hide file tree
Showing 19 changed files with 175 additions and 88 deletions.
33 changes: 24 additions & 9 deletions src/anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import logging
from enum import Enum

from omegaconf import DictConfig, ListConfig

Expand All @@ -25,6 +26,20 @@
logger = logging.getLogger(__name__)


class DataFormat(str, Enum):
"""Supported Dataset Types"""

MVTEC = "mvtec"
MVTEC_3D = "mvtec_3d"
BTECH = "btech"
FOLDER = "folder"
FOLDER_3D = "folder_3d"
UCSDPED = "ucsdped"
AVENUE = "avenue"
VISA = "visa"
SHANGHAITECH = "shanghaitech"


def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
"""Get Anomaly Datamodule.
Expand All @@ -43,7 +58,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
if center_crop is not None:
center_crop = (center_crop[0], center_crop[1])

if config.dataset.format.lower() == "mvtec":
if config.dataset.format.lower() == DataFormat.MVTEC:
datamodule = MVTec(
root=config.dataset.path,
category=config.dataset.category,
Expand All @@ -61,7 +76,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "mvtec_3d":
elif config.dataset.format.lower() == DataFormat.MVTEC_3D:
datamodule = MVTec3D(
root=config.dataset.path,
category=config.dataset.category,
Expand All @@ -79,7 +94,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "btech":
elif config.dataset.format.lower() == DataFormat.BTECH:
datamodule = BTech(
root=config.dataset.path,
category=config.dataset.category,
Expand All @@ -97,7 +112,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "folder":
elif config.dataset.format.lower() == DataFormat.FOLDER:
datamodule = Folder(
root=config.dataset.root,
normal_dir=config.dataset.normal_dir,
Expand All @@ -119,7 +134,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "folder_3d":
elif config.dataset.format.lower() == DataFormat.FOLDER_3D:
datamodule = Folder3D(
root=config.dataset.root,
normal_dir=config.dataset.normal_dir,
Expand All @@ -144,7 +159,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "ucsdped":
elif config.dataset.format.lower() == DataFormat.UCSDPED:
datamodule = UCSDped(
root=config.dataset.path,
category=config.dataset.category,
Expand All @@ -162,7 +177,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "avenue":
elif config.dataset.format.lower() == DataFormat.AVENUE:
datamodule = Avenue(
root=config.dataset.path,
gt_dir=config.dataset.gt_dir,
Expand All @@ -180,7 +195,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "visa":
elif config.dataset.format.lower() == DataFormat.VISA:
datamodule = Visa(
root=config.dataset.path,
category=config.dataset.category,
Expand All @@ -198,7 +213,7 @@ def get_datamodule(config: DictConfig | ListConfig) -> AnomalibDataModule:
val_split_mode=config.dataset.val_split_mode,
val_split_ratio=config.dataset.val_split_ratio,
)
elif config.dataset.format.lower() == "shanghaitech":
elif config.dataset.format.lower() == DataFormat.SHANGHAITECH:
datamodule = ShanghaiTech(
root=config.dataset.path,
scene=config.dataset.scene,
Expand Down
5 changes: 3 additions & 2 deletions src/anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from anomalib.data.utils import (
DownloadInfo,
InputNormalizationMethod,
LabelName,
Split,
TestSplitMode,
ValSplitMode,
Expand Down Expand Up @@ -104,8 +105,8 @@ def make_btech_dataset(path: Path, split: str | Split | None = None) -> DataFram
samples.loc[(samples.split == "test") & (samples.label == "ok"), "mask_path"] = ""

# Create label index for normal (0) and anomalous (1) images.
samples.loc[(samples.label == "ok"), "label_index"] = 0
samples.loc[(samples.label != "ok"), "label_index"] = 1
samples.loc[(samples.label == "ok"), "label_index"] = LabelName.NORMAL
samples.loc[(samples.label != "ok"), "label_index"] = LabelName.ABNORMAL
samples.label_index = samples.label_index.astype(int)

# Get the data frame for the split.
Expand Down
28 changes: 16 additions & 12 deletions src/anomalib/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from anomalib.data.base import AnomalibDataModule, AnomalibDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import (
DirType,
InputNormalizationMethod,
LabelName,
Split,
TestSplitMode,
ValSplitMode,
Expand Down Expand Up @@ -58,16 +60,16 @@ def make_folder_dataset(

filenames = []
labels = []
dirs = {"normal": normal_dir}
dirs = {DirType.NORMAL: normal_dir}

if abnormal_dir:
dirs = {**dirs, **{"abnormal": abnormal_dir}}
dirs = {**dirs, **{DirType.ABNORMAL: abnormal_dir}}

if normal_test_dir:
dirs = {**dirs, **{"normal_test": normal_test_dir}}
dirs = {**dirs, **{DirType.NORMAL_TEST: normal_test_dir}}

if mask_dir:
dirs = {**dirs, **{"mask_dir": mask_dir}}
dirs = {**dirs, **{DirType.MASK: mask_dir}}

for dir_type, path in dirs.items():
filename, label = _prepare_files_labels(path, dir_type, extensions)
Expand All @@ -78,22 +80,24 @@ def make_folder_dataset(
samples = samples.sort_values(by="image_path", ignore_index=True)

# Create label index for normal (0) and abnormal (1) images.
samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0
samples.loc[(samples.label == "abnormal"), "label_index"] = 1
samples.loc[
(samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST), "label_index"
] = LabelName.NORMAL
samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL
samples.label_index = samples.label_index.astype("Int64")

# If a path to mask is provided, add it to the sample dataframe.

if mask_dir is not None and abnormal_dir is not None:
samples.loc[samples.label == "abnormal", "mask_path"] = samples.loc[
samples.label == "mask_dir"
samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[
samples.label == DirType.MASK
].image_path.values
samples["mask_path"].fillna("", inplace=True)
samples = samples.astype({"mask_path": "str"})

# make sure all every rgb image has a corresponding mask image.
assert (
samples.loc[samples.label_index == 1]
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and mask images. Make sure the mask files \
Expand All @@ -104,7 +108,7 @@ def make_folder_dataset(

# remove all the rows with temporal image samples that have already been assigned
samples = samples.loc[
(samples.label == "normal") | (samples.label == "abnormal") | (samples.label == "normal_test")
(samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST)
]

# Ensure the pathlib objects are converted to str.
Expand All @@ -114,8 +118,8 @@ def make_folder_dataset(
# Create train/test split.
# By default, all the normal samples are assigned as train.
# and all the abnormal samples are test.
samples.loc[(samples.label == "normal"), "split"] = "train"
samples.loc[(samples.label == "abnormal") | (samples.label == "normal_test"), "split"] = "test"
samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST

# Get the data frame for the split.
if split:
Expand Down
46 changes: 25 additions & 21 deletions src/anomalib/data/folder_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from anomalib.data.base import AnomalibDataModule, AnomalibDepthDataset
from anomalib.data.task_type import TaskType
from anomalib.data.utils import (
DirType,
InputNormalizationMethod,
LabelName,
Split,
TestSplitMode,
ValSplitMode,
Expand Down Expand Up @@ -75,25 +77,25 @@ def make_folder3d_dataset(

filenames = []
labels = []
dirs = {"normal": normal_dir}
dirs = {DirType.NORMAL: normal_dir}

if abnormal_dir:
dirs = {**dirs, **{"abnormal": abnormal_dir}}
dirs = {**dirs, **{DirType.ABNORMAL: abnormal_dir}}

if normal_test_dir:
dirs = {**dirs, **{"normal_test": normal_test_dir}}
dirs = {**dirs, **{DirType.NORMAL_TEST: normal_test_dir}}

if normal_depth_dir:
dirs = {**dirs, **{"normal_depth": normal_depth_dir}}
dirs = {**dirs, **{DirType.NORMAL_DEPTH: normal_depth_dir}}

if abnormal_depth_dir:
dirs = {**dirs, **{"abnormal_depth": abnormal_depth_dir}}
dirs = {**dirs, **{DirType.ABNORMAL_DEPTH: abnormal_depth_dir}}

if normal_test_depth_dir:
dirs = {**dirs, **{"normal_test_depth": normal_test_depth_dir}}
dirs = {**dirs, **{DirType.NORMAL_TEST_DEPTH: normal_test_depth_dir}}

if mask_dir:
dirs = {**dirs, **{"mask_dir": mask_dir}}
dirs = {**dirs, **{DirType.MASK: mask_dir}}

for dir_type, path in dirs.items():
filename, label = _prepare_files_labels(path, dir_type, extensions)
Expand All @@ -104,27 +106,29 @@ def make_folder3d_dataset(
samples = samples.sort_values(by="image_path", ignore_index=True)

# Create label index for normal (0) and abnormal (1) images.
samples.loc[(samples.label == "normal") | (samples.label == "normal_test"), "label_index"] = 0
samples.loc[(samples.label == "abnormal"), "label_index"] = 1
samples.loc[
(samples.label == DirType.NORMAL) | (samples.label == DirType.NORMAL_TEST), "label_index"
] = LabelName.NORMAL
samples.loc[(samples.label == DirType.ABNORMAL), "label_index"] = LabelName.ABNORMAL
samples.label_index = samples.label_index.astype("Int64")

# If a path to mask is provided, add it to the sample dataframe.
if normal_depth_dir is not None:
samples.loc[samples.label == "normal", "depth_path"] = samples.loc[
samples.label == "normal_depth"
samples.loc[samples.label == DirType.NORMAL, "depth_path"] = samples.loc[
samples.label == DirType.NORMAL_DEPTH
].image_path.values
samples.loc[samples.label == "abnormal", "depth_path"] = samples.loc[
samples.label == "abnormal_depth"
samples.loc[samples.label == DirType.ABNORMAL, "depth_path"] = samples.loc[
samples.label == DirType.ABNORMAL_DEPTH
].image_path.values

if normal_test_dir is not None:
samples.loc[samples.label == "normal_test", "depth_path"] = samples.loc[
samples.label == "normal_test_depth"
samples.loc[samples.label == DirType.NORMAL_TEST, "depth_path"] = samples.loc[
samples.label == DirType.NORMAL_TEST_DEPTH
].image_path.values

# make sure every rgb image has a corresponding depth image and that the file exists
assert (
samples.loc[samples.label_index == 1]
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and depth images. Make sure the mask files in 'xyz' \
Expand All @@ -139,8 +143,8 @@ def make_folder3d_dataset(

# If a path to mask is provided, add it to the sample dataframe.
if mask_dir is not None and abnormal_dir is not None:
samples.loc[samples.label == "abnormal", "mask_path"] = samples.loc[
samples.label == "mask_dir"
samples.loc[samples.label == DirType.ABNORMAL, "mask_path"] = samples.loc[
samples.label == DirType.MASK
].image_path.values
samples["mask_path"].fillna("", inplace=True)
samples = samples.astype({"mask_path": "str"})
Expand All @@ -154,7 +158,7 @@ def make_folder3d_dataset(

# remove all the rows with temporal image samples that have already been assigned
samples = samples.loc[
(samples.label == "normal") | (samples.label == "abnormal") | (samples.label == "normal_test")
(samples.label == DirType.NORMAL) | (samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST)
]

# Ensure the pathlib objects are converted to str.
Expand All @@ -164,8 +168,8 @@ def make_folder3d_dataset(
# Create train/test split.
# By default, all the normal samples are assigned as train.
# and all the abnormal samples are test.
samples.loc[(samples.label == "normal"), "split"] = "train"
samples.loc[(samples.label == "abnormal") | (samples.label == "normal_test"), "split"] = "test"
samples.loc[(samples.label == DirType.NORMAL), "split"] = Split.TRAIN
samples.loc[(samples.label == DirType.ABNORMAL) | (samples.label == DirType.NORMAL_TEST), "split"] = Split.TEST

# Get the data frame for the split.
if split:
Expand Down
11 changes: 7 additions & 4 deletions src/anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from anomalib.data.utils import (
DownloadInfo,
InputNormalizationMethod,
LabelName,
Split,
TestSplitMode,
ValSplitMode,
Expand Down Expand Up @@ -119,8 +120,8 @@ def make_mvtec_dataset(
samples["image_path"] = samples.path + "/" + samples.split + "/" + samples.label + "/" + samples.image_path

# Create label index for normal (0) and anomalous (1) images.
samples.loc[(samples.label == "good"), "label_index"] = 0
samples.loc[(samples.label != "good"), "label_index"] = 1
samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
samples.label_index = samples.label_index.astype(int)

# separate masks from samples
Expand All @@ -129,11 +130,13 @@ def make_mvtec_dataset(

# assign mask paths to anomalous test images
samples["mask_path"] = ""
samples.loc[(samples.split == "test") & (samples.label_index == 1), "mask_path"] = mask_samples.image_path.values
samples.loc[
(samples.split == "test") & (samples.label_index == LabelName.ABNORMAL), "mask_path"
] = mask_samples.image_path.values

# assert that the right mask files are associated with the right test images
assert (
samples.loc[samples.label_index == 1]
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and ground truth masks. Make sure the mask files in 'ground_truth' \
Expand Down
9 changes: 5 additions & 4 deletions src/anomalib/data/mvtec_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from anomalib.data.utils import (
DownloadInfo,
InputNormalizationMethod,
LabelName,
Split,
TestSplitMode,
ValSplitMode,
Expand Down Expand Up @@ -135,8 +136,8 @@ def make_mvtec_3d_dataset(
)

# Create label index for normal (0) and anomalous (1) images.
samples.loc[(samples.label == "good"), "label_index"] = 0
samples.loc[(samples.label != "good"), "label_index"] = 1
samples.loc[(samples.label == "good"), "label_index"] = LabelName.NORMAL
samples.loc[(samples.label != "good"), "label_index"] = LabelName.ABNORMAL
samples.label_index = samples.label_index.astype(int)

# separate masks from samples
Expand All @@ -154,7 +155,7 @@ def make_mvtec_3d_dataset(

# assert that the right mask files are associated with the right test images
assert (
samples.loc[samples.label_index == 1]
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.mask_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and ground truth masks. Make sure the mask files in 'ground_truth' \
Expand All @@ -163,7 +164,7 @@ def make_mvtec_3d_dataset(

# assert that the right depth image files are associated with the right test images
assert (
samples.loc[samples.label_index == 1]
samples.loc[samples.label_index == LabelName.ABNORMAL]
.apply(lambda x: Path(x.image_path).stem in Path(x.depth_path).stem, axis=1)
.all()
), "Mismatch between anomalous images and depth images. Make sure the mask files in 'xyz' \
Expand Down
Loading

0 comments on commit 2661177

Please sign in to comment.