Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: fix random val/test split issue #48

Merged
merged 2 commits into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]):
train_batch_size=config.dataset.train_batch_size,
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
)
else:
raise ValueError("Unknown dataset!")
Expand Down
76 changes: 42 additions & 34 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def make_mvtec_dataset(
samples.label_index = samples.label_index.astype(int)

if create_validation_set:
samples = create_validation_set_from_test_set(samples)
samples = create_validation_set_from_test_set(samples, seed=seed)

# Get the data frame for the split.
if split is not None and split in ["train", "test"]:
if split is not None and split in ["train", "val", "test"]:
samples = samples[samples.split == split]
samples = samples.reset_index(drop=True)

Expand All @@ -217,19 +217,23 @@ def __init__(
root: Union[Path, str],
category: str,
pre_process: PreProcessor,
split: str,
task: str = "segmentation",
is_train: bool = True,
download: bool = False,
seed: int = 0,
create_validation_set: bool = False,
) -> None:
"""Mvtec Dataset class.

Args:
root: Path to the MVTec dataset
category: Name of the MVTec category.
pre_process: Pre-processing object containing the albumentations compose.
split: 'train', 'val' or 'test'
task: ``classification`` or ``segmentation``
is_train: Boolean to check if the split is training
download: Boolean to download the MVTec dataset.
seed: seed used for the random subset splitting
create_validation_set: Create a validation subset in addition to the train and test subsets

Examples:
>>> from anomalib.data.mvtec import MVTec
Expand Down Expand Up @@ -264,15 +268,17 @@ def __init__(
super().__init__(root)
self.root = Path(root) if isinstance(root, str) else root
self.category: str = category
self.split = "train" if is_train else "test"
self.split = split
self.task = task

self.pre_process = pre_process

if download:
self._download()

self.samples = make_mvtec_dataset(path=self.root / category, split=self.split)
self.samples = make_mvtec_dataset(
path=self.root / category, split=self.split, seed=seed, create_validation_set=create_validation_set
)

def _download(self) -> None:
"""Download the MVTec dataset."""
Expand Down Expand Up @@ -327,8 +333,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
if self.split == "train" or self.task == "classification":
pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}

if self.split == "test":
elif self.split in ["val", "test"]:
label_index = self.samples.label_index[index]

item["image_path"] = image_path
Expand Down Expand Up @@ -366,6 +371,8 @@ def __init__(
test_batch_size: int = 32,
num_workers: int = 8,
transform_config: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
create_validation_set: bool = False,
) -> None:
"""Mvtec Lightning Data Module.

Expand All @@ -377,6 +384,8 @@ def __init__(
test_batch_size: Testing batch size.
num_workers: Number of workers.
transform_config: Config for pre-processing.
seed: seed used for the random subset splitting
create_validation_set: Create a validation subset in addition to the train and test subsets

Examples
>>> from anomalib.data import MVTecDataModule
Expand Down Expand Up @@ -415,47 +424,45 @@ def __init__(
self.test_batch_size = test_batch_size
self.num_workers = num_workers

self.train_data: Dataset
self.val_data: Dataset
self.create_validation_set = create_validation_set
self.seed = seed

def prepare_data(self):
"""Prepare MVTec Dataset."""
# Train
MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=True,
download=True,
)

# Test
MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=False,
download=True,
)
self.train_data: Dataset
self.test_data: Dataset
if create_validation_set:
self.val_data: Dataset
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

def setup(self, stage: Optional[str] = None) -> None:
"""Setup train, validation and test data.

Args:
stage: Optional[str]: Train/Val/Test stages. (Default value = None)
"""
self.val_data = MVTec(
if self.create_validation_set:
self.val_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
split="val",
seed=self.seed,
create_validation_set=self.create_validation_set,
)
self.test_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=False,
split="test",
seed=self.seed,
create_validation_set=self.create_validation_set,
)
if stage in (None, "fit"):
self.train_data = MVTec(
root=self.root,
category=self.category,
pre_process=self.pre_process,
is_train=True,
split="train",
seed=self.seed,
create_validation_set=self.create_validation_set,
)

def train_dataloader(self) -> DataLoader:
Expand All @@ -464,8 +471,9 @@ def train_dataloader(self) -> DataLoader:

def val_dataloader(self) -> DataLoader:
"""Get validation dataloader."""
return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
dataset = self.val_data if self.create_validation_set else self.test_data
return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)

def test_dataloader(self) -> DataLoader:
"""Get test dataloader."""
return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)