Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualizer improvements pt1 #293

Merged
merged 8 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 84 additions & 21 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import math
from pathlib import Path
from typing import Optional, Tuple

import cv2
import matplotlib.figure
import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -29,41 +31,28 @@ class Visualizer:
either be logged by accessing the `figure` attribute or can be saved directly by calling `save()` method.

Example:
>>> visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3))
>>> visualizer = Visualizer()
>>> visualizer.add_image(image=image, title="Image")
>>> visualizer.close()

Args:
num_rows (int): Number of rows of images in the figure.
num_cols (int): Number of columns/images in each row.
figure_size (Tuple[int, int]): Size of output figure
"""

def __init__(self, num_rows: int, num_cols: int, figure_size: Tuple[int, int]):
self.figure_index: int = 0
def __init__(self):

self.figure, self.axis = plt.subplots(num_rows, num_cols, figsize=figure_size)
self.figure.subplots_adjust(right=0.9)
self.images = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add annotation to self.images


for axis in self.axis:
axis.axes.xaxis.set_visible(False)
axis.axes.yaxis.set_visible(False)
self.figure: matplotlib.figure.Figure
self.axis: np.ndarray

def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = None, index: Optional[int] = None):
def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = None):
"""Add image to figure.

Args:
image (np.ndarray): Image which should be added to the figure.
title (str): Image title shown on the plot.
color_map (Optional[str]): Name of matplotlib color map used to map scalar data to colours. Defaults to None.
index (Optional[int]): Figure index. Defaults to None.
"""
if index is None:
index = self.figure_index
self.figure_index += 1

self.axis[index].imshow(image, color_map, vmin=0, vmax=255)
self.axis[index].title.set_text(title)
image_data = dict(image=image, title=title, color_map=color_map)
self.images.append(image_data)

def add_text(self, image: np.ndarray, text: str, font: int = cv2.FONT_HERSHEY_PLAIN):
"""Puts text on an image.
Expand All @@ -86,8 +75,81 @@ def add_text(self, image: np.ndarray, text: str, font: int = cv2.FONT_HERSHEY_PL
cv2.putText(image, line.strip(), (0, (baseline // 2 + text_h) + offset), font, font_size, (0, 0, 255))
return image

@staticmethod
def add_label(
image: np.ndarray,
label_name: str,
color: Tuple[int, int, int],
confidence: Optional[float] = None,
font_scale: float = 5e-3,
thickness_scale=1e-3,
):
"""Adds a label to an image.

Args:
image (np.ndarray): Input image.
label_name (str): Name of the label that will be displayed on the image.
color (Tuple[int, int, int]): RGB values for background color of label.
confidence (Optional[float]): confidence score of the label.
font_scale (float): scale of the font size relative to image size. Increase for bigger font.
thickness_scale (float): scale of the font thickness. Increase for thicker font.

Returns:
np.ndarray: Image with label.
"""
image = image.copy()
img_height, img_width, _ = image.shape

font = cv2.FONT_HERSHEY_PLAIN
text = label_name if confidence is None else f"{label_name}: {confidence:.2}"

# get font sizing
font_scale = min(img_width, img_height) * font_scale
thickness = math.ceil(min(img_width, img_height) * thickness_scale)
(width, height), baseline = cv2.getTextSize(text, font, fontScale=font_scale, thickness=thickness)

# create label
label_patch = np.zeros((height + baseline, width + baseline, 3), dtype=np.uint8)
label_patch[:, :] = color
cv2.putText(
label_patch,
text,
(0, baseline // 2 + height),
font,
fontScale=font_scale,
thickness=thickness,
color=0,
lineType=cv2.LINE_AA,
)

# add label to image
image[: baseline + height, : baseline + width] = label_patch
return image

def add_normal_label(self, image: np.ndarray, confidence: Optional[float] = None):
"""Adds the normal label to the image."""
return self.add_label(image, "normal", (225, 252, 134), confidence)

def add_anomalous_label(self, image: np.ndarray, confidence: Optional[float] = None):
"""Adds the anomalous label to the image."""
return self.add_label(image, "anomalous", (255, 100, 100), confidence)

def generate(self):
"""Generate the image."""
num_cols = len(self.images)
figure_size = (num_cols * 3, 3)
self.figure, self.axis = plt.subplots(1, num_cols, figsize=figure_size)
self.figure.subplots_adjust(right=0.9)

for axis, image_dict in zip(self.axis, self.images):
axis.axes.xaxis.set_visible(False)
axis.axes.yaxis.set_visible(False)
axis.imshow(image_dict["image"], image_dict["color_map"], vmin=0, vmax=255)
axis.title.set_text(image_dict["title"])

def show(self):
"""Show image on a matplotlib figure."""
self.generate()
self.figure.show()

def save(self, filename: Path):
Expand All @@ -96,6 +158,7 @@ def save(self, filename: Path):
Args:
filename (Path): Filename to save image
"""
self.generate()
filename.parent.mkdir(parents=True, exist_ok=True)
self.figure.savefig(filename, dpi=100)

Expand Down
36 changes: 18 additions & 18 deletions anomalib/utils/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,24 @@ def on_test_batch_end(
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

num_cols = 6 if self.task == "segmentation" else 5
visualizer = Visualizer(num_rows=1, num_cols=num_cols, figure_size=(12, 3))
visualizer.add_image(image=image, title="Image")

if "mask" in outputs:
true_mask = outputs["mask"][i].cpu().numpy() * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")

visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")

image_classified = visualizer.add_text(
image=image,
text=f"""Pred: { "anomalous" if pred_score > threshold else "normal"}({pred_score:.3f}) \n
GT: {"anomalous" if bool(gt_label) else "normal"}""",
)
visualizer.add_image(image=image_classified, title="Classified Image")
visualizer = Visualizer()

if self.task == "segmentation":
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
visualizer.add_image(image=image, title="Image")
if "mask" in outputs:
true_mask = outputs["mask"][i].cpu().numpy() * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")
elif self.task == "classification":
gt_im = visualizer.add_anomalous_label(image) if gt_label else visualizer.add_normal_label(image)
visualizer.add_image(gt_im, title="Image/True label")
if pred_score >= threshold:
image_classified = visualizer.add_anomalous_label(heat_map, pred_score)
else:
image_classified = visualizer.add_normal_label(heat_map, 1 - pred_score)
visualizer.add_image(image=image_classified, title="Prediction")

self._add_images(visualizer, pl_module, trainer, Path(filename))
visualizer.close()
Expand Down
2 changes: 1 addition & 1 deletion tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def add_label(prediction: np.ndarray, scores: float, font: int = cv2.FONT_HERSHE
(width, height), baseline = cv2.getTextSize(text, font, font_size, thickness=font_size // 2)
label_patch = np.zeros((height + baseline, width + baseline, 3), dtype=np.uint8)
label_patch[:, :] = (225, 252, 134)
cv2.putText(label_patch, text, (0, baseline // 2 + height), font, font_size, 0)
cv2.putText(label_patch, text, (0, baseline // 2 + height), font, font_size, 0, lineType=cv2.LINE_AA)
prediction[: baseline + height, : baseline + width] = label_patch
return prediction

Expand Down