Skip to content

Commit

Permalink
Visualizer show classification and segmentation (#178)
Browse files Browse the repository at this point in the history
* Visualizer show classification and segmentation

* get enumrated output label

* styling

* styling

* fix tests

* put label on image, image order

* styling
  • Loading branch information
alexriedel1 committed Apr 8, 2022
1 parent 83b4d54 commit 548852f
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 22 deletions.
22 changes: 22 additions & 0 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pathlib import Path
from typing import Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -64,6 +65,27 @@ def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = No
self.axis[index].imshow(image, color_map, vmin=0, vmax=255)
self.axis[index].title.set_text(title)

def add_text(self, image: np.ndarray, text: str, font: int = cv2.FONT_HERSHEY_PLAIN):
"""Puts text on an image.
Args:
image (np.ndarray): Input image.
text (str): Text to add.
font (Optional[int]): cv2 font type. Defaults to 0.
Returns:
np.ndarray: Image with text.
"""
image = image.copy()
font_size = image.shape[1] // 256 + 1 # Text scale is calculated based on the reference size of 256

for i, line in enumerate(text.split("\n")):
(text_w, text_h), baseline = cv2.getTextSize(line.strip(), font, font_size, thickness=1)
offset = i * text_h
cv2.rectangle(image, (0, offset + baseline // 2), (0 + text_w, 0 + text_h + offset), (255, 255, 255), -1)
cv2.putText(image, line.strip(), (0, (baseline // 2 + text_h) + offset), font, font_size, (0, 0, 255))
return image

def show(self):
"""Show image on a matplotlib figure."""
self.figure.show()
Expand Down
6 changes: 5 additions & 1 deletion anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}")

if not config.project.log_images_to == []:
callbacks.append(VisualizerCallback(inputs_are_normalized=not config.model.normalization_method == "none"))
callbacks.append(
VisualizerCallback(
task=config.dataset.task, inputs_are_normalized=not config.model.normalization_method == "none"
)
)

if "optimization" in config.keys():
if "nncf" in config.optimization and config.optimization.nncf.apply:
Expand Down
55 changes: 36 additions & 19 deletions anomalib/utils/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class VisualizerCallback(Callback):
config.yaml file.
"""

def __init__(self, inputs_are_normalized: bool = True):
def __init__(self, task: str, inputs_are_normalized: bool = True):
"""Visualizer callback."""
self.task = task
self.inputs_are_normalized = inputs_are_normalized

def _add_images(
Expand Down Expand Up @@ -111,26 +112,42 @@ def on_test_batch_end(
normalize = True # raw anomaly maps. Still need to normalize
threshold = pl_module.pixel_metrics.F1.threshold

if isinstance(outputs, dict) and "mask" in outputs.keys():
for (filename, image, true_mask, anomaly_map) in zip(
outputs["image_path"], outputs["image"], outputs["mask"], outputs["anomaly_maps"]
):
image = Denormalize()(image.cpu())
true_mask = true_mask.cpu().numpy() * 255
anomaly_map = anomaly_map.cpu().numpy()
for i, (filename, image, anomaly_map, pred_score, gt_label) in enumerate(
zip(
outputs["image_path"],
outputs["image"],
outputs["anomaly_maps"],
outputs["pred_scores"],
outputs["label"],
)
):
image = Denormalize()(image.cpu())
anomaly_map = anomaly_map.cpu().numpy()
heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

num_cols = 6 if self.task == "segmentation" else 5
visualizer = Visualizer(num_rows=1, num_cols=num_cols, figure_size=(12, 3))
visualizer.add_image(image=image, title="Image")

if self.task == "segmentation":
true_mask = outputs["mask"][i].cpu().numpy() * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")

heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")
visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")

visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3))
visualizer.add_image(image=image, title="Image")
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")
self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()
image_classified = visualizer.add_text(
image=image,
text=f"""Pred: { "anomalous" if pred_score > threshold else "normal"}({pred_score:.3f}) \n
GT: {"anomalous" if bool(gt_label) else "normal"}""",
)
visualizer.add_image(image=image_classified, title="Classified Image")

self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()

def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Sync logs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__()
self.save_hyperparameters(hparams)
self.loss_fn = nn.NLLLoss()
self.callbacks = [VisualizerCallback()] # test if this is removed
self.callbacks = [VisualizerCallback(task="segmentation")] # test if this is removed

self.image_threshold = AdaptiveThreshold(hparams.model.threshold.image_default).cpu()
self.pixel_threshold = AdaptiveThreshold(hparams.model.threshold.pixel_default).cpu()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(hparams)
self.model = DummyModel()
self.task = "segmentation"
self.callbacks = [VisualizerCallback()] # test if this is removed
self.callbacks = [VisualizerCallback(task=self.task)] # test if this is removed

def test_step(self, batch, _):
"""Only used to trigger on_test_epoch_end."""
Expand Down

0 comments on commit 548852f

Please sign in to comment.