Skip to content

Commit

Permalink
Add segmentation mask to inference output (#242)
Browse files Browse the repository at this point in the history
* Add segmentation mask to inference output

* fix mypy issue

* Address PR comments

* remove kwd

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
ashwinvaidya17 and samet-akcay committed Apr 22, 2022
1 parent ce279f9 commit cab7aa2
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 10 deletions.
47 changes: 41 additions & 6 deletions anomalib/deploy/inferencers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
from pathlib import Path
from typing import Dict, Optional, Tuple, Union, cast

import cv2
import numpy as np
from omegaconf import DictConfig, OmegaConf
from skimage.morphology import dilation
from skimage.segmentation import find_boundaries
from torch import Tensor

from anomalib.data.utils import read_image
from anomalib.post_processing import superimpose_anomaly_map
from anomalib.post_processing import compute_mask, superimpose_anomaly_map
from anomalib.post_processing.normalization.cdf import normalize as normalize_cdf
from anomalib.post_processing.normalization.cdf import standardize
from anomalib.post_processing.normalization.min_max import (
Expand Down Expand Up @@ -60,7 +63,11 @@ def post_process(
raise NotImplementedError

def predict(
self, image: Union[str, np.ndarray, Path], superimpose: bool = True, meta_data: Optional[dict] = None
self,
image: Union[str, np.ndarray, Path],
superimpose: bool = True,
meta_data: Optional[dict] = None,
overlay_mask: bool = False,
) -> Tuple[np.ndarray, float]:
"""Perform a prediction for a given input image.
Expand All @@ -74,6 +81,8 @@ def predict(
will be superimposed onto the original image. If false, `predict`
method will return the raw heatmap.
overlay_mask (bool): If this is set to True, output segmentation mask on top of image.
Returns:
np.ndarray: Output predictions to be visualized.
"""
Expand All @@ -83,18 +92,44 @@ def predict(
else:
meta_data = {}
if isinstance(image, (str, Path)):
image = read_image(image)
meta_data["image_shape"] = image.shape[:2]
image_arr: np.ndarray = read_image(image)
else: # image is already a numpy array. Kept for mypy compatibility.
image_arr = image
meta_data["image_shape"] = image_arr.shape[:2]

processed_image = self.pre_process(image)
processed_image = self.pre_process(image_arr)
predictions = self.forward(processed_image)
anomaly_map, pred_scores = self.post_process(predictions, meta_data=meta_data)

# Overlay segmentation mask using raw predictions
if overlay_mask and meta_data is not None:
image_arr = self._superimpose_segmentation_mask(meta_data, anomaly_map, image_arr)

if superimpose is True:
anomaly_map = superimpose_anomaly_map(anomaly_map, image)
anomaly_map = superimpose_anomaly_map(anomaly_map, image_arr)

return anomaly_map, pred_scores

def _superimpose_segmentation_mask(self, meta_data: dict, anomaly_map: np.ndarray, image: np.ndarray):
"""Superimpose segmentation mask on top of image.
Args:
meta_data (dict): Metadata of the image which contains the image size.
anomaly_map (np.ndarray): Anomaly map which is used to extract segmentation mask.
image (np.ndarray): Image on which segmentation mask is to be superimposed.
Returns:
np.ndarray: Image with segmentation mask superimposed.
"""
pred_mask = compute_mask(anomaly_map, 0.5) # assumes predictions are normalized.
image_height = meta_data["image_shape"][0]
image_width = meta_data["image_shape"][1]
pred_mask = cv2.resize(pred_mask, (image_width, image_height))
boundaries = find_boundaries(pred_mask)
outlines = dilation(boundaries, np.ones((7, 7)))
image[outlines] = [255, 0, 0]
return image

def __call__(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
"""Call predict on the Image.
Expand Down
16 changes: 12 additions & 4 deletions tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def get_args() -> Namespace:
parser.add_argument("--image_path", type=Path, required=True, help="Path to an image to infer.")
parser.add_argument("--save_path", type=Path, required=False, help="Path to save the output image.")
parser.add_argument("--meta_data", type=Path, required=False, help="Path to JSON file containing the metadata.")
parser.add_argument(
"--overlay_mask",
type=bool,
required=False,
default=False,
help="Overlay the segmentation mask on the image. It assumes that the task is segmentation.",
)

args = parser.parse_args()
if args.model_config_path is not None:
Expand Down Expand Up @@ -114,29 +121,30 @@ def stream() -> None:
if image.is_file() and image.suffix in (".jpg", ".png", ".jpeg"):
# Here save_path is assumed to be a directory. Image subdirectories are appended to the save_path.
save_path = Path(args.save_path / image.relative_to(args.image_path).parent) if args.save_path else None
infer(image, inferencer, save_path)
infer(image, inferencer, save_path, args.overlay_mask)
elif args.image_path.suffix in (".jpg", ".png", ".jpeg"):
infer(args.image_path, inferencer, args.save_path)
infer(args.image_path, inferencer, args.save_path, args.overlay_mask)
else:
raise ValueError(
f"Image extension is not supported. Supported extensions are .jpg, .png, .jpeg."
f" Got {args.image_path.suffix}"
)


def infer(image_path: Path, inferencer: Inferencer, save_path: Optional[Path] = None) -> None:
def infer(image_path: Path, inferencer: Inferencer, save_path: Optional[Path] = None, overlay: bool = False) -> None:
"""Perform inference on a single image.
Args:
image_path (Path): Path to image/directory containing images.
inferencer (Inferencer): Inferencer to use.
save_path (Path, optional): Path to save the output image. If this is None, the output is visualized.
overlay (bool, optional): Overlay the segmentation mask on the image. It assumes that the task is segmentation.
"""
# Perform inference for the given image or image path. if image
# path is provided, `predict` method will read the image from
# file for convenience. We set the superimpose flag to True
# to overlay the predicted anomaly map on top of the input image.
output = inferencer.predict(image=image_path, superimpose=True)
output = inferencer.predict(image=image_path, superimpose=True, overlay_mask=overlay)

# Incase both anomaly map and scores are returned add scores to the image.
if isinstance(output, tuple):
Expand Down

0 comments on commit cab7aa2

Please sign in to comment.