Skip to content

Commit

Permalink
Return only image, path and label for classification tasks in `…
Browse files Browse the repository at this point in the history
…Mvtec` and `Btech` datasets. (#196)

* Add task option to btech dataset.

* Add task option to mvtec dataset.

* Add task option to btech mvtec datamodules.

* ✏️ Change default task from classification to segmentation in Mvtec
Btech datamodules.

* ✏️ Fix typo. Mvtec ➡️ Btech

* ✏️ Fix typo. Btech ➡️ BTech
  • Loading branch information
samet-akcay committed Apr 7, 2022
1 parent 97734af commit 9ae6946
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
2 changes: 2 additions & 0 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
task=config.dataset.task,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
Expand All @@ -60,6 +61,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
test_batch_size=config.dataset.test_batch_size,
num_workers=config.dataset.num_workers,
seed=config.project.seed,
task=config.dataset.task,
transform_config_train=config.dataset.transform_config.train,
transform_config_val=config.dataset.transform_config.val,
create_validation_set=config.dataset.create_validation_set,
Expand Down
14 changes: 10 additions & 4 deletions anomalib/data/btech.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,10 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
image_path = self.samples.image_path[index]
image = read_image(image_path)

if self.split == "train" or self.task == "classification":
pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}
elif self.split in ["val", "test"]:
pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}

if self.split in ["val", "test"]:
label_index = self.samples.label_index[index]

item["image_path"] = image_path
Expand Down Expand Up @@ -270,6 +270,7 @@ def __init__(
train_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 8,
task: str = "segmentation",
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
Expand All @@ -284,6 +285,7 @@ def __init__(
train_batch_size: Training batch size.
test_batch_size: Testing batch size.
num_workers: Number of workers.
task: ``classification`` or ``segmentation``
transform_config_train: Config for pre-processing during training.
transform_config_val: Config for pre-processing during validation.
seed: seed used for the random subset splitting
Expand Down Expand Up @@ -335,6 +337,7 @@ def __init__(
self.num_workers = num_workers

self.create_validation_set = create_validation_set
self.task = task
self.seed = seed

self.train_data: Dataset
Expand Down Expand Up @@ -399,6 +402,7 @@ def setup(self, stage: Optional[str] = None) -> None:
category=self.category,
pre_process=self.pre_process_train,
split="train",
task=self.task,
seed=self.seed,
create_validation_set=self.create_validation_set,
)
Expand All @@ -409,6 +413,7 @@ def setup(self, stage: Optional[str] = None) -> None:
category=self.category,
pre_process=self.pre_process_val,
split="val",
task=self.task,
seed=self.seed,
create_validation_set=self.create_validation_set,
)
Expand All @@ -418,6 +423,7 @@ def setup(self, stage: Optional[str] = None) -> None:
category=self.category,
pre_process=self.pre_process_val,
split="test",
task=self.task,
seed=self.seed,
create_validation_set=self.create_validation_set,
)
Expand Down
14 changes: 10 additions & 4 deletions anomalib/data/mvtec.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,10 +253,10 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
image_path = self.samples.image_path[index]
image = read_image(image_path)

if self.split == "train" or self.task == "classification":
pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}
elif self.split in ["val", "test"]:
pre_processed = self.pre_process(image=image)
item = {"image": pre_processed["image"]}

if self.split in ["val", "test"]:
label_index = self.samples.label_index[index]

item["image_path"] = image_path
Expand Down Expand Up @@ -293,6 +293,7 @@ def __init__(
train_batch_size: int = 32,
test_batch_size: int = 32,
num_workers: int = 8,
task: str = "segmentation",
transform_config_train: Optional[Union[str, A.Compose]] = None,
transform_config_val: Optional[Union[str, A.Compose]] = None,
seed: int = 0,
Expand All @@ -307,6 +308,7 @@ def __init__(
train_batch_size: Training batch size.
test_batch_size: Testing batch size.
num_workers: Number of workers.
task: ``classification`` or ``segmentation``
transform_config_train: Config for pre-processing during training.
transform_config_val: Config for pre-processing during validation.
seed: seed used for the random subset splitting
Expand Down Expand Up @@ -358,6 +360,7 @@ def __init__(
self.num_workers = num_workers

self.create_validation_set = create_validation_set
self.task = task
self.seed = seed

self.train_data: Dataset
Expand Down Expand Up @@ -402,6 +405,7 @@ def setup(self, stage: Optional[str] = None) -> None:
category=self.category,
pre_process=self.pre_process_train,
split="train",
task=self.task,
seed=self.seed,
create_validation_set=self.create_validation_set,
)
Expand All @@ -412,6 +416,7 @@ def setup(self, stage: Optional[str] = None) -> None:
category=self.category,
pre_process=self.pre_process_val,
split="val",
task=self.task,
seed=self.seed,
create_validation_set=self.create_validation_set,
)
Expand All @@ -421,6 +426,7 @@ def setup(self, stage: Optional[str] = None) -> None:
category=self.category,
pre_process=self.pre_process_val,
split="test",
task=self.task,
seed=self.seed,
create_validation_set=self.create_validation_set,
)
Expand Down

0 comments on commit 9ae6946

Please sign in to comment.