
Commit

Support null seed (#437)
* Support null seed

* Set optional value to None

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
ashwinvaidya17 and Ashwin Vaidya committed Jul 15, 2022
1 parent 8f437b7 commit 2f9fa13
Showing 7 changed files with 77 additions and 16 deletions.
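
The net effect of the commit is that seed becomes optional across the data modules: an integer still gives a reproducible train/test split of the normal images, while None (the new default) is accepted and only triggers a warning that the split will differ between runs. The snippet below is a hypothetical usage sketch; the MVTec datamodule class name and constructor arguments (root, category, image_size) are assumptions inferred from the hunks that follow, and a local copy of the dataset is assumed.

# Hypothetical usage sketch -- class and argument names are assumptions, not taken verbatim from this commit.
from anomalib.data.mvtec import MVTec

# Reproducible: an integer seed fixes the random split of normal images.
seeded = MVTec(root="./datasets/MVTec", category="bottle", image_size=256, seed=42)

# Accepted after this commit: seed=None. When the dataset is built during setup(),
# a warning is emitted because the normal images are split differently on every run.
unseeded = MVTec(root="./datasets/MVTec", category="bottle", image_size=256, seed=None)
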
15 changes: 12 additions & 3 deletions anomalib/data/btech.py
@@ -22,6 +22,7 @@

import logging
import shutil
import warnings
import zipfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
@@ -56,7 +57,7 @@ def make_btech_dataset(
path: Path,
split: Optional[str] = None,
split_ratio: float = 0.1,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> DataFrame:
"""Create BTech samples by parsing the BTech data file structure.
@@ -152,7 +153,7 @@ def __init__(
pre_process: PreProcessor,
split: str,
task: str = "segmentation",
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Btech Dataset class.
@@ -197,6 +198,14 @@ def __init__(
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
"""
super().__init__(root)

if seed is None:
warnings.warn(
"seed is None."
" When seed is not set, images from the normal directory are split between training and test dir."
" This will lead to inconsistency between runs."
)

self.root = Path(root) if isinstance(root, str) else root
self.category: str = category
self.split = split
@@ -274,7 +283,7 @@ def __init__(
task: str = "segmentation",
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Instantiate BTech Lightning Data Module.
13 changes: 10 additions & 3 deletions anomalib/data/folder.py
@@ -95,7 +95,7 @@ def make_dataset(
mask_dir: Optional[Union[str, Path]] = None,
split: Optional[str] = None,
split_ratio: float = 0.2,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = True,
extensions: Optional[Tuple[str, ...]] = None,
):
@@ -191,7 +191,7 @@ def __init__(
mask_dir: Optional[Union[Path, str]] = None,
extensions: Optional[Tuple[str, ...]] = None,
task: Optional[str] = None,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Create Folder Folder Dataset.
@@ -316,7 +316,7 @@ def __init__(
mask_dir: Optional[Union[Path, str]] = None,
extensions: Optional[Tuple[str, ...]] = None,
split_ratio: float = 0.2,
seed: int = 0,
seed: Optional[int] = None,
image_size: Optional[Union[int, Tuple[int, int]]] = None,
train_batch_size: int = 32,
test_batch_size: int = 32,
@@ -425,6 +425,13 @@ def __init__(
"""
super().__init__()

if seed is None and normal_test_dir is None:
raise ValueError(
"Both seed and normal_test_dir cannot be None."
" When seed is not set, images from the normal directory are split between training and test dir."
" This will lead to inconsistency between runs."
)

self.root = _check_and_convert_path(root)
self.normal_dir = self.root / normal_dir
self.abnormal_dir = self.root / abnormal_dir
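
For the Folder datamodule the commit adds a guard rather than a warning: when no normal_test_dir is supplied, the normal images must be split randomly, so an unset seed would make the test set irreproducible and a ValueError is raised. The sketch below illustrates the accepted combinations; the class name Folder, the extra arguments, and the paths are assumptions used only for illustration.

# Illustration of the new constraint; class name, argument names, and paths are assumptions.
from anomalib.data.folder import Folder

# Fine: a fixed seed makes the random train/test split of normal images reproducible.
Folder(root="./datasets/toy", normal_dir="good", abnormal_dir="broken", seed=42)

# Fine: a dedicated normal test directory removes the need for a random split.
Folder(root="./datasets/toy", normal_dir="good", abnormal_dir="broken",
       normal_test_dir="good_test", seed=None)

# Raises ValueError after this commit: seed=None and no normal_test_dir.
Folder(root="./datasets/toy", normal_dir="good", abnormal_dir="broken", seed=None)
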
15 changes: 12 additions & 3 deletions anomalib/data/mvtec.py
@@ -40,6 +40,7 @@

import logging
import tarfile
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from urllib.request import urlretrieve
@@ -72,7 +73,7 @@ def make_mvtec_dataset(
path: Path,
split: Optional[str] = None,
split_ratio: float = 0.1,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> DataFrame:
"""Create MVTec AD samples by parsing the MVTec AD data file structure.
@@ -175,7 +176,7 @@ def __init__(
pre_process: PreProcessor,
split: str,
task: str = "segmentation",
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Mvtec AD Dataset class.
@@ -220,6 +221,14 @@ def __init__(
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
"""
super().__init__(root)

if seed is None:
warnings.warn(
"seed is None."
" When seed is not set, images from the normal directory are split between training and test dir."
" This will lead to inconsistency between runs."
)

self.root = Path(root) if isinstance(root, str) else root
self.category: str = category
self.split = split
@@ -297,7 +306,7 @@ def __init__(
task: str = "segmentation",
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Mvtec AD Lightning Data Module.
11 changes: 7 additions & 4 deletions anomalib/data/utils/split.py
@@ -23,12 +23,13 @@
# and limitations under the License.

import random
from typing import Optional

from pandas.core.frame import DataFrame


def split_normal_images_in_train_set(
samples: DataFrame, split_ratio: float = 0.1, seed: int = 0, normal_label: str = "good"
samples: DataFrame, split_ratio: float = 0.1, seed: Optional[int] = None, normal_label: str = "good"
) -> DataFrame:
"""Split normal images in train set.
@@ -49,7 +50,7 @@ def split_normal_images_in_train_set(
DataFrame: Output dataframe where part of the training set is assigned to the test set.
"""

if seed >= 0:
if seed:
random.seed(seed)

normal_train_image_indices = samples.index[(samples.split == "train") & (samples.label == normal_label)].to_list()
@@ -62,7 +63,7 @@ def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, norma
return samples


def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, normal_label: str = "good") -> DataFrame:
def create_validation_set_from_test_set(
samples: DataFrame, seed: Optional[int] = None, normal_label: str = "good"
) -> DataFrame:
"""Craete Validation Set from Test Set.
This function creates a validation set from test set by splitting both
@@ -74,7 +77,7 @@ def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, norma
normal_label (str): Name of the normal label. For MVTec AD, for instance, this is "good".
"""

if seed >= 0:
if seed:
random.seed(seed)

# Split normal images.
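
The helpers in split.py now only seed Python's random module when a seed is supplied, so with seed=None the subset of normal training images moved into the test split changes on every run. The following self-contained sketch mirrors that behaviour on a toy dataframe; it is not the library function itself, and the split/label column layout is an assumption modelled on the samples dataframe used above.

# Self-contained sketch of the seed-guarded split; not the library code itself.
import random
import pandas as pd

def split_normals(samples: pd.DataFrame, split_ratio: float = 0.1, seed=None) -> pd.DataFrame:
    if seed:  # note: a seed of 0 is falsy and therefore also leaves the RNG unseeded
        random.seed(seed)
    train_normals = samples.index[(samples.split == "train") & (samples.label == "good")].to_list()
    num_moved = int(len(train_normals) * split_ratio)
    moved = random.sample(train_normals, num_moved)
    samples.loc[moved, "split"] = "test"  # reassign a random subset of normals to the test split
    return samples

samples = pd.DataFrame({"split": ["train"] * 20, "label": ["good"] * 20})
print(split_normals(samples.copy(), seed=42).query("split == 'test'").index.tolist())    # same every run
print(split_normals(samples.copy(), seed=None).query("split == 'test'").index.tolist())  # may differ between runs
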
4 changes: 2 additions & 2 deletions tests/helpers/dataset.py
@@ -177,7 +177,7 @@ def __init__(
max_size: Optional[int] = 10,
train_shapes: List[str] = ["triangle", "rectangle"],
test_shapes: List[str] = ["star", "hexagon"],
seed: int = 0,
seed: Optional[int] = None,
) -> None:
self.root_dir = mkdtemp()
self.num_train = num_train
@@ -244,7 +244,7 @@ def _generate_dataset(self):

def __enter__(self):
"""Creates the dataset in temp folder."""
if self.seed > 0:
if self.seed:
np.random.seed(self.seed)
self._generate_dataset()
return self.root_dir
33 changes: 33 additions & 0 deletions tests/pre_merge/datasets/test_dataset.py
@@ -92,6 +92,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, mvtec_data_module):
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())

def test_non_overlapping_splits(self, mvtec_data_module):
"""This test ensures that the train and test splits generated are non-overlapping."""
assert (
len(
set(mvtec_data_module.test_data.samples["image_path"].values).intersection(
set(mvtec_data_module.train_data.samples["image_path"].values)
)
)
== 0
), "Found train and test split contamination"


class TestBTechDataModule:
"""Test BTech Data Module."""
@@ -111,6 +122,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, btech_data_module):
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())

def test_non_overlapping_splits(self, btech_data_module):
"""This test ensures that the train and test splits generated are non-overlapping."""
assert (
len(
set(btech_data_module.test_data.samples["image_path"].values).intersection(
set(btech_data_module.train_data.samples["image_path"].values)
)
)
== 0
), "Found train and test split contamination"


class TestFolderDataModule:
"""Test Folder Data Module."""
@@ -130,6 +152,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, folder_data_module):
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())

def test_non_overlapping_splits(self, folder_data_module):
"""This test ensures that the train and test splits generated are non-overlapping."""
assert (
len(
set(folder_data_module.test_data.samples["image_path"].values).intersection(
set(folder_data_module.train_data.samples["image_path"].values)
)
)
== 0
), "Found train and test split contamination"


class TestDenormalize:
"""Test Denormalize Util."""
2 changes: 1 addition & 1 deletion tools/train.py
@@ -58,7 +58,7 @@ def train():
warnings.filterwarnings("ignore")

config = get_configurable_parameters(model_name=args.model, config_path=args.config)
if config.project.seed != 0:
if config.project.seed:
seed_everything(config.project.seed)

datamodule = get_datamodule(config)
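
tools/train.py now treats a missing or falsy seed as "do not seed": seed_everything is only called when config.project.seed is truthy, so a config with seed set to null (or 0) leaves the global RNGs untouched. Below is a small hedged sketch of that guard; the OmegaConf-based config is an assumption used only to mimic the project.seed structure shown above.

# Hedged sketch of the seeding guard in train(); the config contents are hypothetical.
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything

config = OmegaConf.create("project:\n  seed: null\n")  # YAML null loads as None

if config.project.seed:  # skipped for None and for 0
    seed_everything(config.project.seed)
else:
    print("seed not set; training runs are not globally seeded")
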
