Skip to content

Commit

Permalink
Bugfix: fix random val/test split issue (#48)
Browse files Browse the repository at this point in the history
* fix random val/test split issue

Co-authored-by: Samet <samet.akcay@intel.com>
  • Loading branch information
djdameln and samet-akcay committed Dec 23, 2021
1 parent 6eadef9 commit 0d5d26d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 34 deletions.
1 change: 1 addition & 0 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]):
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
)
else:
raise ValueError("Unknown dataset!")
Expand Down
76 changes: 42 additions & 34 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def make_mvtec_dataset(
samples.label_index = samples.label_index.astype(int)

if create_validation_set:
samples = create_validation_set_from_test_set(samples)
samples = create_validation_set_from_test_set(samples, seed=seed)

# Get the data frame for the split.
if split is not None and split in ["train", "test"]:
if split is not None and split in ["train", "val", "test"]:
samples = samples[samples.split == split]
samples = samples.reset_index(drop=True)

Expand All @@ -217,19 +217,23 @@ def __init__(
root: Union[Path, str],
category: str,
pre_process: PreProcessor,
split: str,
task: str = "segmentation",
is_train: bool = True,
download: bool = False,
seed: int = 0,
create_validation_set: bool = False,
) -> None:
"""Mvtec Dataset class.
Args:
root: Path to the MVTec dataset
category: Name of the MVTec category.
pre_process: List of pre_processing object containing albumentation compose.
split: 'train', 'val' or 'test'
task: ``classification`` or ``segmentation``
is_train: Boolean to check if the split is training
download: Boolean to download the MVTec dataset.
seed: seed used for the random subset splitting
create_validation_set: Create a validation subset in addition to the train and test subsets
Examples:
>>> from anomalib.data.mvtec import MVTec
Expand Down Expand Up @@ -264,15 +268,17 @@ def __init__(
super().__init__(root)
self.root = Path(root) if isinstance(root, str) else root
self.category: str = category
self.split = "train" if is_train else "test"
self.split = split
self.task = task

self.pre_process = pre_process

if download:
self._download()

self.samples = make_mvtec_dataset(path=self.root / category, split=self.split)
self.samples = make_mvtec_dataset(
path=self.root / category, split=self.split, seed=seed, create_validation_set=create_validation_set
)

def _download(self) -> None:
"""Download the MVTec dataset."""
Expand Down Expand Up @@ -327,8 +333,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
if self.split == "train" or self.task == "classification":
pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}

if self.split == "test":
elif self.split in ["val", "test"]:
label_index = self.samples.label_index[index]

item["image_path"] = image_path
Expand Down Expand Up @@ -366,6 +371,8 @@ def __init__(
test_batch_size: int = 32,
num_workers: int = 8,
transform_config: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
create_validation_set: bool = False,
) -> None:
"""Mvtec Lightning Data Module.
Expand All @@ -377,6 +384,8 @@ def __init__(
test_batch_size: Testing batch size.
num_workers: Number of workers.
transform_config: Config for pre-processing.
seed: seed used for the random subset splitting
create_validation_set: Create a validation subset in addition to the train and test subsets
Examples
>>> from anomalib.data import MVTecDataModule
Expand Down Expand Up @@ -415,47 +424,45 @@ def __init__(
self.test_batch_size = test_batch_size
self.num_workers = num_workers

self.train_data: Dataset
self.val_data: Dataset
self.create_validation_set = create_validation_set
self.seed = seed

def prepare_data(self):
"""Prepare MVTec Dataset."""
# Train
MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=True,
download=True,
)

# Test
MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=False,
download=True,
)
self.train_data: Dataset
self.test_data: Dataset
if create_validation_set:
self.val_data: Dataset

def setup(self, stage: Optional[str] = None) -> None:
"""Setup train, validation and test data.
Args:
stage: Optional[str]: Train/Val/Test stages. (Default value = None)
"""
self.val_data = MVTec(
if self.create_validation_set:
self.val_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
split="val",
seed=self.seed,
create_validation_set=self.create_validation_set,
)
self.test_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=False,
split="test",
seed=self.seed,
create_validation_set=self.create_validation_set,
)
if stage in (None, "fit"):
self.train_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=True,
split="train",
seed=self.seed,
create_validation_set=self.create_validation_set,
)

def train_dataloader(self) -> DataLoader:
Expand All @@ -464,8 +471,9 @@ def train_dataloader(self) -> DataLoader:

def val_dataloader(self) -> DataLoader:
"""Get validation dataloader."""
return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
dataset = self.val_data if self.create_validation_set else self.test_data
return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)

def test_dataloader(self) -> DataLoader:
"""Get test dataloader."""
return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)

0 comments on commit 0d5d26d

Please sign in to comment.