Skip to content

Commit

Permalink
Visualizer improvements pt1 (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
djdameln committed May 5, 2022
1 parent 60d9c12 commit 6e518ac
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 66 deletions.
11 changes: 10 additions & 1 deletion anomalib/post_processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
# and limitations under the License.

from .post_process import (
add_anomalous_label,
add_normal_label,
anomaly_map_to_color_map,
compute_mask,
superimpose_anomaly_map,
)
from .visualizer import Visualizer

__all__ = ["anomaly_map_to_color_map", "superimpose_anomaly_map", "compute_mask", "Visualizer"]
__all__ = [
"add_anomalous_label",
"add_normal_label",
"anomaly_map_to_color_map",
"superimpose_anomaly_map",
"compute_mask",
"Visualizer",
]
64 changes: 64 additions & 0 deletions anomalib/post_processing/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,75 @@
# and limitations under the License.


import math
from typing import Optional, Tuple

import cv2
import numpy as np
from skimage import morphology


def add_label(
image: np.ndarray,
label_name: str,
color: Tuple[int, int, int],
confidence: Optional[float] = None,
font_scale: float = 5e-3,
thickness_scale=1e-3,
):
"""Adds a label to an image.
Args:
image (np.ndarray): Input image.
label_name (str): Name of the label that will be displayed on the image.
color (Tuple[int, int, int]): RGB values for background color of label.
confidence (Optional[float]): confidence score of the label.
font_scale (float): scale of the font size relative to image size. Increase for bigger font.
thickness_scale (float): scale of the font thickness. Increase for thicker font.
Returns:
np.ndarray: Image with label.
"""
image = image.copy()
img_height, img_width, _ = image.shape

font = cv2.FONT_HERSHEY_PLAIN
text = label_name if confidence is None else f"{label_name} ({confidence*100:.0f}%)"

# get font sizing
font_scale = min(img_width, img_height) * font_scale
thickness = math.ceil(min(img_width, img_height) * thickness_scale)
(width, height), baseline = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)

# create label
label_patch = np.zeros((height + baseline, width + baseline, 3), dtype=np.uint8)
label_patch[:, :] = color
cv2.putText(
label_patch,
text,
(0, baseline // 2 + height),
font,
fontScale=font_scale,
thickness=thickness,
color=0,
lineType=cv2.LINE_AA,
)

# add label to image
image[: baseline + height, : baseline + width] = label_patch
return image


def add_normal_label(image: np.ndarray, confidence: Optional[float] = None):
"""Adds the normal label to the image."""
return add_label(image, "normal", (225, 252, 134), confidence)


def add_anomalous_label(image: np.ndarray, confidence: Optional[float] = None):
"""Adds the anomalous label to the image."""
return add_label(image, "anomalous", (255, 100, 100), confidence)


def anomaly_map_to_color_map(anomaly_map: np.ndarray, normalize: bool = True) -> np.ndarray:
"""Compute anomaly color heatmap.
Expand Down
68 changes: 25 additions & 43 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
# and limitations under the License.

from pathlib import Path
from typing import Optional, Tuple
from typing import Dict, List, Optional

import cv2
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -29,65 +29,46 @@ class Visualizer:
either be logged by accessing the `figure` attribute or can be saved directly by calling `save()` method.
Example:
>>> visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3))
>>> visualizer = Visualizer()
>>> visualizer.add_image(image=image, title="Image")
>>> visualizer.close()
Args:
num_rows (int): Number of rows of images in the figure.
num_cols (int): Number of columns/images in each row.
figure_size (Tuple[int, int]): Size of output figure
"""

def __init__(self, num_rows: int, num_cols: int, figure_size: Tuple[int, int]):
self.figure_index: int = 0
def __init__(self):

self.figure, self.axis = plt.subplots(num_rows, num_cols, figsize=figure_size)
self.figure.subplots_adjust(right=0.9)
self.images: List[Dict] = []

for axis in self.axis:
axis.axes.xaxis.set_visible(False)
axis.axes.yaxis.set_visible(False)
self.figure: matplotlib.figure.Figure
self.axis: np.ndarray

def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = None, index: Optional[int] = None):
def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = None):
"""Add image to figure.
Args:
image (np.ndarray): Image which should be added to the figure.
title (str): Image title shown on the plot.
color_map (Optional[str]): Name of matplotlib color map used to map scalar data to colours. Defaults to None.
index (Optional[int]): Figure index. Defaults to None.
"""
if index is None:
index = self.figure_index
self.figure_index += 1

self.axis[index].imshow(image, color_map, vmin=0, vmax=255)
self.axis[index].title.set_text(title)

def add_text(self, image: np.ndarray, text: str, font: int = cv2.FONT_HERSHEY_PLAIN):
"""Puts text on an image.
Args:
image (np.ndarray): Input image.
text (str): Text to add.
font (Optional[int]): cv2 font type. Defaults to 0.
Returns:
np.ndarray: Image with text.
"""
image = image.copy()
font_size = image.shape[1] // 256 + 1 # Text scale is calculated based on the reference size of 256
image_data = dict(image=image, title=title, color_map=color_map)
self.images.append(image_data)

def generate(self):
"""Generate the image."""
num_cols = len(self.images)
figure_size = (num_cols * 3, 3)
self.figure, self.axis = plt.subplots(1, num_cols, figsize=figure_size)
self.figure.subplots_adjust(right=0.9)

for i, line in enumerate(text.split("\n")):
(text_w, text_h), baseline = cv2.getTextSize(line.strip(), font, font_size, thickness=1)
offset = i * text_h
cv2.rectangle(image, (0, offset + baseline // 2), (0 + text_w, 0 + text_h + offset), (255, 255, 255), -1)
cv2.putText(image, line.strip(), (0, (baseline // 2 + text_h) + offset), font, font_size, (0, 0, 255))
return image
axes = self.axis if len(self.images) > 1 else [self.axis]
for axis, image_dict in zip(axes, self.images):
axis.axes.xaxis.set_visible(False)
axis.axes.yaxis.set_visible(False)
axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255)
axis.title.set_text(image_dict["title"])

def show(self):
"""Show image on a matplotlib figure."""
self.generate()
self.figure.show()

def save(self, filename: Path):
Expand All @@ -96,6 +77,7 @@ def save(self, filename: Path):
Args:
filename (Path): Filename to save image
"""
self.generate()
filename.parent.mkdir(parents=True, exist_ok=True)
self.figure.savefig(filename, dpi=100)

Expand Down
45 changes: 26 additions & 19 deletions anomalib/utils/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,13 @@
from skimage.segmentation import mark_boundaries

from anomalib.models.components import AnomalyModule
from anomalib.post_processing import Visualizer, compute_mask, superimpose_anomaly_map
from anomalib.post_processing import (
Visualizer,
add_anomalous_label,
add_normal_label,
compute_mask,
superimpose_anomaly_map,
)
from anomalib.pre_processing.transforms import Denormalize
from anomalib.utils import loggers
from anomalib.utils.loggers import AnomalibWandbLogger
Expand Down Expand Up @@ -133,25 +139,26 @@ def on_test_batch_end(
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

num_cols = 6 if self.task == "segmentation" else 5
visualizer = Visualizer(num_rows=1, num_cols=num_cols, figure_size=(12, 3))
visualizer.add_image(image=image, title="Image")

if "mask" in outputs:
true_mask = outputs["mask"][i].cpu().numpy() * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")

visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")

image_classified = visualizer.add_text(
image=image,
text=f"""Pred: { "anomalous" if pred_score > threshold else "normal"}({pred_score:.3f}) \n
GT: {"anomalous" if bool(gt_label) else "normal"}""",
)
visualizer.add_image(image=image_classified, title="Classified Image")
visualizer = Visualizer()

if self.task == "segmentation":
visualizer.add_image(image=image, title="Image")
if "mask" in outputs:
true_mask = outputs["mask"][i].cpu().numpy() * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")
elif self.task == "classification":
gt_im = add_anomalous_label(image) if gt_label else add_normal_label(image)
visualizer.add_image(gt_im, title="Image/True label")
if pred_score >= threshold:
image_classified = add_anomalous_label(heat_map, pred_score)
else:
image_classified = add_normal_label(heat_map, 1 - pred_score)
visualizer.add_image(image=image_classified, title="Prediction")

visualizer.generate()
self._add_images(visualizer, pl_module, trainer, Path(filename))
visualizer.close()

Expand Down
5 changes: 3 additions & 2 deletions tests/pre_merge/post_processing/test_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ def test_visualize_fully_defected_masks():
"""Test if a fully defected anomaly mask results in a completely white image."""

# create visualizer and add fully defected mask
visualizer = Visualizer(num_rows=1, num_cols=2, figure_size=(3, 3))
visualizer = Visualizer()
mask = np.ones((256, 256)) * 255
visualizer.add_image(image=mask, color_map="gray", title="fully defected mask")
visualizer.generate()

# retrieve plotted image
canvas = FigureCanvas(visualizer.figure)
canvas.draw()
plotted_img = visualizer.axis[0].images[0].make_image(canvas.renderer)
plotted_img = visualizer.axis.images[0].make_image(canvas.renderer)

# assert that the plotted image is completely white
assert np.all(plotted_img[0][..., 0] == 255)
2 changes: 1 addition & 1 deletion tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def add_label(prediction: np.ndarray, scores: float, font: int = cv2.FONT_HERSHE
(width, height), baseline = cv2.getTextSize(text, font, font_size, thickness=font_size // 2)
label_patch = np.zeros((height + baseline, width + baseline, 3), dtype=np.uint8)
label_patch[:, :] = (225, 252, 134)
cv2.putText(label_patch, text, (0, baseline // 2 + height), font, font_size, 0)
cv2.putText(label_patch, text, (0, baseline // 2 + height), font, font_size, 0, lineType=cv2.LINE_AA)
prediction[: baseline + height, : baseline + width] = label_patch
return prediction

Expand Down

0 comments on commit 6e518ac

Please sign in to comment.