Skip to content

Commit

Permalink
Address compatibility issues with OTE (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
samet-akcay committed Dec 14, 2021
1 parent cf8bdf6 commit e0345ee
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 14 deletions.
8 changes: 6 additions & 2 deletions anomalib/core/model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ def pre_process(self, image: np.ndarray) -> Tensor:
Returns:
Tensor: pre-processed image.
"""
pre_processor = PreProcessor(config=self.config.transform, to_tensor=True)
config = self.config.transform if "transform" in self.config.keys() else None
image_size = tuple(self.config.dataset.image_size)
pre_processor = PreProcessor(config, image_size)
processed_image = pre_processor(image=image)["image"]

if len(processed_image) == 3:
Expand Down Expand Up @@ -236,7 +238,9 @@ def pre_process(self, image: np.ndarray) -> np.ndarray:
Returns:
np.ndarray: pre-processed image.
"""
pre_processor = PreProcessor(config=self.config.transform, to_tensor=False)
config = self.config.transform if "transform" in self.config.keys() else None
image_size = tuple(self.config.dataset.image_size)
pre_processor = PreProcessor(config, image_size)
processed_image = pre_processor(image=image)["image"]

if len(processed_image.shape) == 3:
Expand Down
71 changes: 60 additions & 11 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,14 @@


def split_normal_images_in_train_set(samples: DataFrame, split_ratio: float = 0.1, seed: int = 0) -> DataFrame:
"""This function splits the normal images in training set and assigns the values to the test set.
"""Split normal images in train set.
This is particularly useful especially when the test set does not contain any normal images.
This is important because when the test set doesn't have any normal images,
AUC computation fails due to having single class.
This function splits the normal images in training set and assigns the
values to the test set. This is particularly useful especially when the
test set does not contain any normal images.
This is important because when the test set doesn't have any normal images,
AUC computation fails due to having single class.
Args:
samples (DataFrame): Dataframe containing dataset info such as filenames, splits etc.
Expand All @@ -64,7 +67,9 @@ def split_normal_images_in_train_set(samples: DataFrame, split_ratio: float = 0.
Returns:
DataFrame: Output dataframe where the part of the training set is assigned to test set.
"""
random.seed(seed)

if seed > 0:
random.seed(seed)

normal_train_image_indices = samples.index[(samples.split == "train") & (samples.label == "good")].to_list()
num_normal_train_images = len(normal_train_image_indices)
Expand All @@ -76,7 +81,44 @@ def split_normal_images_in_train_set(samples: DataFrame, split_ratio: float = 0.
return samples


def make_mvtec_dataset(path: Path, split: str = "train", split_ratio: float = 0.1, seed: int = 0) -> DataFrame:
def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0) -> DataFrame:
"""Craete Validation Set from Test Set.
This function creates a validation set from test set by splitting both
normal and abnormal samples to two.
Args:
samples (DataFrame): Dataframe containing dataset info such as filenames, splits etc.
seed (int, optional): Random seed to ensure reproducibility. Defaults to 0.
"""

if seed > 0:
random.seed(seed)

# Split normal images.
normal_test_image_indices = samples.index[(samples.split == "test") & (samples.label == "good")].to_list()
num_normal_valid_images = len(normal_test_image_indices) // 2

indices_to_sample = random.sample(population=normal_test_image_indices, k=num_normal_valid_images)
samples.loc[indices_to_sample, "split"] = "val"

# Split abnormal images.
abnormal_test_image_indices = samples.index[(samples.split == "test") & (samples.label != "good")].to_list()
num_abnormal_valid_images = len(abnormal_test_image_indices) // 2

indices_to_sample = random.sample(population=abnormal_test_image_indices, k=num_abnormal_valid_images)
samples.loc[indices_to_sample, "split"] = "val"

return samples


def make_mvtec_dataset(
path: Path,
split: Optional[str] = None,
split_ratio: float = 0.1,
seed: int = 0,
create_validation_set: bool = False,
) -> DataFrame:
"""Create MVTec samples by parsing the MVTec data file structure.
The files are expected to follow the structure:
Expand All @@ -92,11 +134,14 @@ def make_mvtec_dataset(path: Path, split: str = "train", split_ratio: float = 0.
Args:
path (Path): Path to dataset
split (str, optional): Dataset split (ie., either train or test). Defaults to "train".
split (str, optional): Dataset split (ie., either train or test). Defaults to None.
split_ratio (float, optional): Ratio to split normal training images and add to the
test set in case test set doesn't contain any normal images.
Defaults to 0.1.
test set in case test set doesn't contain any normal images.
Defaults to 0.1.
seed (int, optional): Random seed to ensure reproducibility when splitting. Defaults to 0.
create_validation_set (bool, optional): Boolean to create a validation set from the test set.
MVTec dataset does not contain a validation set. Those wanting to create a validation set
could set this flag to ``True``.
Example:
The following example shows how to get training samples from MVTec bottle category:
Expand Down Expand Up @@ -153,9 +198,13 @@ def make_mvtec_dataset(path: Path, split: str = "train", split_ratio: float = 0.
samples.loc[(samples.label != "good"), "label_index"] = 1
samples.label_index = samples.label_index.astype(int)

if create_validation_set:
samples = create_validation_set_from_test_set(samples)

# Get the data frame for the split.
samples = samples[samples.split == split]
samples = samples.reset_index(drop=True)
if split is not None and split in ["train", "test"]:
samples = samples[samples.split == split]
samples = samples.reset_index(drop=True)

return samples

Expand Down
2 changes: 1 addition & 1 deletion anomalib/data/transforms/pre_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class PreProcessor:
def __init__(
self,
config: Optional[Union[str, A.Compose]] = None,
image_size: Optional[Union[int, Tuple[int, int]]] = None,
image_size: Optional[Union[int, Tuple]] = None,
to_tensor: bool = True,
) -> None:
self.config = config
Expand Down

0 comments on commit e0345ee

Please sign in to comment.