Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support null seed #437

Merged
merged 3 commits into from
Jul 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import logging
import shutil
import warnings
import zipfile
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
Expand Down Expand Up @@ -56,7 +57,7 @@ def make_btech_dataset(
path: Path,
split: Optional[str] = None,
split_ratio: float = 0.1,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> DataFrame:
"""Create BTech samples by parsing the BTech data file structure.
Expand Down Expand Up @@ -152,7 +153,7 @@ def __init__(
pre_process: PreProcessor,
split: str,
task: str = "segmentation",
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Btech Dataset class.
Expand Down Expand Up @@ -197,6 +198,14 @@ def __init__(
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
"""
super().__init__(root)

if seed is None:
warnings.warn(
"seed is None."
" When seed is not set, images from the normal directory are split between training and test dir."
" This will lead to inconsistency between runs."
)

self.root = Path(root) if isinstance(root, str) else root
self.category: str = category
self.split = split
Expand Down Expand Up @@ -274,7 +283,7 @@ def __init__(
task: str = "segmentation",
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Instantiate BTech Lightning Data Module.
Expand Down
13 changes: 10 additions & 3 deletions anomalib/data/folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def make_dataset(
mask_dir: Optional[Union[str, Path]] = None,
split: Optional[str] = None,
split_ratio: float = 0.2,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = True,
extensions: Optional[Tuple[str, ...]] = None,
):
Expand Down Expand Up @@ -191,7 +191,7 @@ def __init__(
mask_dir: Optional[Union[Path, str]] = None,
extensions: Optional[Tuple[str, ...]] = None,
task: Optional[str] = None,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Create Folder Dataset.
Expand Down Expand Up @@ -316,7 +316,7 @@ def __init__(
mask_dir: Optional[Union[Path, str]] = None,
extensions: Optional[Tuple[str, ...]] = None,
split_ratio: float = 0.2,
seed: int = 0,
seed: Optional[int] = None,
image_size: Optional[Union[int, Tuple[int, int]]] = None,
train_batch_size: int = 32,
test_batch_size: int = 32,
Expand Down Expand Up @@ -425,6 +425,13 @@ def __init__(
"""
super().__init__()

if seed is None and normal_test_dir is None:
raise ValueError(
"Both seed and normal_test_dir cannot be None."
" When seed is not set, images from the normal directory are split between training and test dir."
" This will lead to inconsistency between runs."
)

self.root = _check_and_convert_path(root)
self.normal_dir = self.root / normal_dir
self.abnormal_dir = self.root / abnormal_dir
Expand Down
15 changes: 12 additions & 3 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

import logging
import tarfile
import warnings
from pathlib import Path
from typing import Dict, Optional, Tuple, Union
from urllib.request import urlretrieve
Expand Down Expand Up @@ -72,7 +73,7 @@ def make_mvtec_dataset(
path: Path,
split: Optional[str] = None,
split_ratio: float = 0.1,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> DataFrame:
"""Create MVTec AD samples by parsing the MVTec AD data file structure.
Expand Down Expand Up @@ -175,7 +176,7 @@ def __init__(
pre_process: PreProcessor,
split: str,
task: str = "segmentation",
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Mvtec AD Dataset class.
Expand Down Expand Up @@ -220,6 +221,14 @@ def __init__(
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
"""
super().__init__(root)

if seed is None:
warnings.warn(
"seed is None."
" When seed is not set, images from the normal directory are split between training and test dir."
" This will lead to inconsistency between runs."
)

self.root = Path(root) if isinstance(root, str) else root
self.category: str = category
self.split = split
Expand Down Expand Up @@ -297,7 +306,7 @@ def __init__(
task: str = "segmentation",
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
seed: Optional[int] = None,
create_validation_set: bool = False,
) -> None:
"""Mvtec AD Lightning Data Module.
Expand Down
11 changes: 7 additions & 4 deletions anomalib/data/utils/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
# and limitations under the License.

import random
from typing import Optional

from pandas.core.frame import DataFrame


def split_normal_images_in_train_set(
samples: DataFrame, split_ratio: float = 0.1, seed: int = 0, normal_label: str = "good"
samples: DataFrame, split_ratio: float = 0.1, seed: Optional[int] = None, normal_label: str = "good"
) -> DataFrame:
"""Split normal images in train set.

Expand All @@ -49,7 +50,7 @@ def split_normal_images_in_train_set(
DataFrame: Output dataframe where the part of the training set is assigned to test set.
"""

if seed >= 0:
if seed:
random.seed(seed)

normal_train_image_indices = samples.index[(samples.split == "train") & (samples.label == normal_label)].to_list()
Expand All @@ -62,7 +63,9 @@ def split_normal_images_in_train_set(
return samples


def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, normal_label: str = "good") -> DataFrame:
def create_validation_set_from_test_set(
samples: DataFrame, seed: Optional[int] = None, normal_label: str = "good"
) -> DataFrame:
"""Create Validation Set from Test Set.

This function creates a validation set from test set by splitting both
Expand All @@ -74,7 +77,7 @@ def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, norma
normal_label (str): Name of the normal label. For MVTec AD, for instance, this is "good".
"""

if seed >= 0:
if seed:
random.seed(seed)

# Split normal images.
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def __init__(
max_size: Optional[int] = 10,
train_shapes: List[str] = ["triangle", "rectangle"],
test_shapes: List[str] = ["star", "hexagon"],
seed: int = 0,
seed: Optional[int] = None,
) -> None:
self.root_dir = mkdtemp()
self.num_train = num_train
Expand Down Expand Up @@ -244,7 +244,7 @@ def _generate_dataset(self):

def __enter__(self):
"""Creates the dataset in temp folder."""
if self.seed > 0:
if self.seed:
np.random.seed(self.seed)
self._generate_dataset()
return self.root_dir
Expand Down
33 changes: 33 additions & 0 deletions tests/pre_merge/datasets/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, mvtec_data_module):
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())

def test_non_overlapping_splits(self, mvtec_data_module):
    """Verify that no image path is shared between the MVTec train and test splits."""
    # Collect the unique image paths of each split, test split first.
    test_paths = set(mvtec_data_module.test_data.samples["image_path"].values)
    train_paths = set(mvtec_data_module.train_data.samples["image_path"].values)

    # Any path present in both sets means a training image leaked into the test set.
    shared_paths = test_paths & train_paths
    assert not shared_paths, "Found train and test split contamination"


class TestBTechDataModule:
"""Test BTech Data Module."""
Expand All @@ -111,6 +122,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, btech_data_module):
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())

def test_non_overlapping_splits(self, btech_data_module):
    """Verify that no image path is shared between the BTech train and test splits."""
    # Collect the unique image paths of each split, test split first.
    test_paths = set(btech_data_module.test_data.samples["image_path"].values)
    train_paths = set(btech_data_module.train_data.samples["image_path"].values)

    # Any path present in both sets means a training image leaked into the test set.
    shared_paths = test_paths & train_paths
    assert not shared_paths, "Found train and test split contamination"


class TestFolderDataModule:
"""Test Folder Data Module."""
Expand All @@ -130,6 +152,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, folder_data_module):
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())

def test_non_overlapping_splits(self, folder_data_module):
    """Verify that no image path is shared between the Folder train and test splits."""
    # Collect the unique image paths of each split, test split first.
    test_paths = set(folder_data_module.test_data.samples["image_path"].values)
    train_paths = set(folder_data_module.train_data.samples["image_path"].values)

    # Any path present in both sets means a training image leaked into the test set.
    shared_paths = test_paths & train_paths
    assert not shared_paths, "Found train and test split contamination"


class TestDenormalize:
"""Test Denormalize Util."""
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def train():
warnings.filterwarnings("ignore")

config = get_configurable_parameters(model_name=args.model, config_path=args.config)
if config.project.seed != 0:
if config.project.seed:
seed_everything(config.project.seed)

datamodule = get_datamodule(config)
Expand Down