def get_image_height_and_width(
    image_size: Optional[Union[int, Tuple, List]] = None
) -> Tuple[Optional[int], Optional[int]]:
    """Get image height and width from ``image_size`` variable.

    Args:
        image_size (Optional[Union[int, Tuple, List]], optional): Input image size.
            An ``int`` is used for both height and width. For a tuple or list,
            the first two entries are taken as ``(height, width)`` and any
            further entries (e.g. a channel count) are ignored. Defaults to None.

    Returns:
        Tuple[Optional[int], Optional[int]]: A tuple containing image height and width values.
            ``(None, None)`` when ``image_size`` is ``None``.

    Raises:
        ValueError: If ``image_size`` is not ``None``, an ``int`` or a tuple/list,
            if it is a ``bool``, or if the tuple/list has fewer than two entries.

    Examples:
        >>> get_image_height_and_width(image_size=256)
        (256, 256)

        >>> get_image_height_and_width(image_size=(256, 256))
        (256, 256)

        >>> get_image_height_and_width(image_size=(256, 256, 3))
        (256, 256)

        >>> get_image_height_and_width(image_size=[256, 128])
        (256, 128)

        >>> get_image_height_and_width(image_size=256.)
        Traceback (most recent call last):
            ...
        ValueError: ``image_size`` could be either int or Tuple[int, int]
    """
    if image_size is None:
        return None, None

    # ``bool`` is a subclass of ``int``; reject it explicitly so that ``True``/``False``
    # are not silently accepted as a pixel size.
    if isinstance(image_size, bool):
        raise ValueError("``image_size`` could be either int or Tuple[int, int]")

    if isinstance(image_size, int):
        return image_size, image_size

    if isinstance(image_size, (tuple, list)):
        # Fail with the documented exception type instead of an IndexError.
        if len(image_size) < 2:
            raise ValueError(
                "``image_size`` must contain at least two values (height, width), "
                f"got {len(image_size)}"
            )
        return int(image_size[0]), int(image_size[1])

    raise ValueError("``image_size`` could be either int or Tuple[int, int]")
def get_transforms(
    config: Optional[Union[str, A.Compose]] = None,
    image_size: Optional[Union[int, Tuple]] = None,
    to_tensor: bool = True,
) -> A.Compose:
    """Build the albumentations pipeline from a config or from an image size.

    Args:
        config (Optional[Union[str, A.Compose]], optional): Albumentations transforms.
            Either a path to a serialized yaml config or an albumentations
            ``Compose`` object. Defaults to None.
        image_size (Optional[Union[int, Tuple]], optional): Image size to resize to.
            Defaults to None.
        to_tensor (bool, optional): When ``False``, a trailing ``ToTensorV2`` step is
            removed from the pipeline. Defaults to True.

    Raises:
        ValueError: When both ``config`` and ``image_size`` are ``None``.
        ValueError: When ``config`` is not a ``str`` or ``A.Compose`` object.

    Returns:
        A.Compose: Albumentations ``Compose`` object containing the image transforms.

    Examples:
        >>> import skimage
        >>> image = skimage.data.astronaut()

        >>> transforms = get_transforms(image_size=256, to_tensor=False)
        >>> output = transforms(image=image)
        >>> output["image"].shape
        (256, 256, 3)

        >>> transforms = get_transforms(image_size=256, to_tensor=True)
        >>> output = transforms(image=image)
        >>> output["image"].shape
        torch.Size([3, 256, 256])

        Transforms could be read from an albumentations Compose object.

        >>> import albumentations as A
        >>> from albumentations.pytorch import ToTensorV2
        >>> config = A.Compose([A.Resize(512, 512), ToTensorV2()])
        >>> transforms = get_transforms(config=config, to_tensor=False)
        >>> output = transforms(image=image)
        >>> output["image"].shape
        (512, 512, 3)
        >>> type(output["image"])
        <class 'numpy.ndarray'>

        Transforms could be deserialized from a yaml file.

        >>> transforms = A.Compose([A.Resize(1024, 1024), ToTensorV2()])
        >>> A.save(transforms, "/tmp/transforms.yaml", data_format="yaml")
        >>> transforms = get_transforms(config="/tmp/transforms.yaml")
        >>> output = transforms(image=image)
        >>> output["image"].shape
        torch.Size([3, 1024, 1024])
    """
    pipeline: A.Compose

    if config is None:
        if image_size is None:
            raise ValueError(
                "Both config and image_size cannot be `None`. "
                "Provide either config file to de-serialize transforms "
                "or image_size to get the default transformations"
            )

        # No config given: fall back to resize + normalization + tensor conversion.
        logger.warning("Transform configs has not been provided. Images will be normalized using ImageNet statistics.")
        rows, cols = get_image_height_and_width(image_size)
        default_steps = [
            A.Resize(height=rows, width=cols, always_apply=True),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ToTensorV2(),
        ]
        pipeline = A.Compose(default_steps)
    elif isinstance(config, str):
        pipeline = A.load(filepath=config, data_format="yaml")
    elif isinstance(config, A.Compose):
        pipeline = config
    else:
        raise ValueError("config could be either ``str`` or ``A.Compose``")

    # Drop the final tensor conversion when the caller asked for non-tensor output.
    if not to_tensor and isinstance(pipeline[-1], ToTensorV2):
        pipeline = A.Compose(pipeline[:-1])

    # A requested image size is always honoured: prepend a resize step when the
    # pipeline has no top-level ``A.Resize`` of its own.
    contains_resize = any(isinstance(step, A.Resize) for step in pipeline)
    if image_size is not None and not contains_resize:
        rows, cols = get_image_height_and_width(image_size)
        leading_resize = A.Resize(height=rows, width=cols, always_apply=True)
        pipeline = A.Compose([leading_resize, pipeline])

    return pipeline
- - Returns: - A.Compose: List of albumentation transformations to apply to the - input image. - """ - if self.config is None and self.image_size is None: - raise ValueError( - "Both config and image_size cannot be `None`. " - "Provide either config file to de-serialize transforms " - "or image_size to get the default transformations" - ) - - transforms: A.Compose - - if self.config is None and self.image_size is not None: - height, width = self._get_height_and_width() - transforms = A.Compose( - [ - A.Resize(height=height, width=width, always_apply=True), - A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ToTensorV2(), - ] - ) - - if self.config is not None: - if isinstance(self.config, str): - transforms = A.load(filepath=self.config, data_format="yaml") - elif isinstance(self.config, A.Compose): - transforms = self.config - else: - raise ValueError("config could be either ``str`` or ``A.Compose``") - - if not self.to_tensor: - if isinstance(transforms[-1], ToTensorV2): - transforms = A.Compose(transforms[:-1]) - - # always resize to specified image size - if not any(isinstance(transform, A.Resize) for transform in transforms) and self.image_size is not None: - height, width = self._get_height_and_width() - transforms = A.Compose([A.Resize(height=height, width=width, always_apply=True), transforms]) - - return transforms + self.transforms = get_transforms(config, image_size, to_tensor) def __call__(self, *args, **kwargs): """Return transformed arguments.""" return self.transforms(*args, **kwargs) - - def _get_height_and_width(self) -> Tuple[Optional[int], Optional[int]]: - """Extract height and width from image size attribute.""" - if isinstance(self.image_size, int): - return self.image_size, self.image_size - if isinstance(self.image_size, tuple): - return int(self.image_size[0]), int(self.image_size[1]) - if self.image_size is None: - return None, None - raise ValueError("``image_size`` could be either int or Tuple[int, int]") diff --git 
a/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py b/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py index 6072852fbe..8644871661 100644 --- a/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py +++ b/tests/pre_merge/utils/callbacks/visualizer_callback/dummy_lightning_model.py @@ -11,6 +11,7 @@ from anomalib.models.components import AnomalyModule from anomalib.utils.callbacks import ImageVisualizerCallback from anomalib.utils.metrics import get_metrics +from tests.helpers.dataset import get_dataset_path class DummyDataset(Dataset): @@ -68,7 +69,7 @@ def test_step(self, batch, _): """Only used to trigger on_test_epoch_end.""" self.log(name="loss", value=0.0, prog_bar=True) outputs = dict( - image_path=[Path("test1.jpg")], + image_path=[Path(get_dataset_path("bottle")) / "broken_large/000.png"], image=torch.rand((1, 3, 100, 100)), mask=torch.zeros((1, 100, 100)), anomaly_maps=torch.ones((1, 100, 100)), diff --git a/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py b/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py index 6ecd92c6f0..35fdb2e129 100644 --- a/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py +++ b/tests/pre_merge/utils/callbacks/visualizer_callback/test_visualizer.py @@ -1,7 +1,7 @@ import glob import os import tempfile -from unittest import mock +from pathlib import Path import pytest import pytorch_lightning as pl @@ -42,7 +42,7 @@ def test_add_images(dataset): ) trainer.test(model=model, datamodule=DummyDataModule()) # test if images are logged - if len(glob.glob(os.path.join(dir_loc, "images", "*.jpg"))) != 1: + if len(list(Path(dir_loc).glob("**/*.png"))) != 1: raise Exception("Failed to save to local path") # test if tensorboard logs are created