Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualizer show classification and segmentation #178

Merged
merged 11 commits into from
Apr 8, 2022
22 changes: 22 additions & 0 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pathlib import Path
from typing import Optional, Tuple

import cv2
import matplotlib.pyplot as plt
import numpy as np

Expand Down Expand Up @@ -64,6 +65,27 @@ def add_image(self, image: np.ndarray, title: str, color_map: Optional[str] = No
self.axis[index].imshow(image, color_map, vmin=0, vmax=255)
self.axis[index].title.set_text(title)

def add_text(self, image: np.ndarray, text: str, font: int = cv2.FONT_HERSHEY_PLAIN):
"""Puts text on an image.
Args:
image (np.ndarray): Input image.
text (str): Text to add.
font (Optional[int]): cv2 font type. Defaults to 0.
Returns:
np.ndarray: Image with text.
"""
image = image.copy()
font_size = image.shape[1] // 256 + 1 # Text scale is calculated based on the reference size of 256

for i, line in enumerate(text.split("\n")):
(text_w, text_h), baseline = cv2.getTextSize(line.strip(), font, font_size, thickness=1)
offset = i * text_h
cv2.rectangle(image, (0, offset + baseline // 2), (0 + text_w, 0 + text_h + offset), (255, 255, 255), -1)
cv2.putText(image, line.strip(), (0, (baseline // 2 + text_h) + offset), font, font_size, (0, 0, 255))
return image

def show(self):
"""Show image on a matplotlib figure."""
self.figure.show()
Expand Down
6 changes: 5 additions & 1 deletion anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
raise ValueError(f"Normalization method not recognized: {config.model.normalization_method}")

if not config.project.log_images_to == []:
callbacks.append(VisualizerCallback(inputs_are_normalized=not config.model.normalization_method == "none"))
callbacks.append(
VisualizerCallback(
task=config.dataset.task, inputs_are_normalized=not config.model.normalization_method == "none"
)
)

if "optimization" in config.keys():
if "nncf" in config.optimization and config.optimization.nncf.apply:
Expand Down
55 changes: 36 additions & 19 deletions anomalib/utils/callbacks/visualizer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ class VisualizerCallback(Callback):
config.yaml file.
"""

def __init__(self, inputs_are_normalized: bool = True):
def __init__(self, task: str, inputs_are_normalized: bool = True):
"""Visualizer callback."""
self.task = task
self.inputs_are_normalized = inputs_are_normalized

def _add_images(
Expand Down Expand Up @@ -111,26 +112,42 @@ def on_test_batch_end(
normalize = True # raw anomaly maps. Still need to normalize
threshold = pl_module.pixel_metrics.F1.threshold

if isinstance(outputs, dict) and "mask" in outputs.keys():
for (filename, image, true_mask, anomaly_map) in zip(
outputs["image_path"], outputs["image"], outputs["mask"], outputs["anomaly_maps"]
):
image = Denormalize()(image.cpu())
true_mask = true_mask.cpu().numpy() * 255
anomaly_map = anomaly_map.cpu().numpy()
for i, (filename, image, anomaly_map, pred_score, gt_label) in enumerate(
zip(
outputs["image_path"],
outputs["image"],
outputs["anomaly_maps"],
outputs["pred_scores"],
outputs["label"],
)
):
image = Denormalize()(image.cpu())
anomaly_map = anomaly_map.cpu().numpy()
heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

num_cols = 6 if self.task == "segmentation" else 5
visualizer = Visualizer(num_rows=1, num_cols=num_cols, figure_size=(12, 3))
visualizer.add_image(image=image, title="Image")

if self.task == "segmentation":
true_mask = outputs["mask"][i].cpu().numpy() * 255
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")

heat_map = superimpose_anomaly_map(anomaly_map, image, normalize=normalize)
pred_mask = compute_mask(anomaly_map, threshold)
vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")
visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")

visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3))
visualizer.add_image(image=image, title="Image")
visualizer.add_image(image=true_mask, color_map="gray", title="Ground Truth")
visualizer.add_image(image=heat_map, title="Predicted Heat Map")
visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
visualizer.add_image(image=vis_img, title="Segmentation Result")
self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()
image_classified = visualizer.add_text(
image=image,
text=f"""Pred: { "anomalous" if pred_score > threshold else "normal"}({pred_score:.3f}) \n
GT: {"anomalous" if bool(gt_label) else "normal"}""",
)
visualizer.add_image(image=image_classified, title="Classified Image")

self._add_images(visualizer, pl_module, Path(filename))
visualizer.close()

def on_test_end(self, _trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""Sync logs.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__()
self.save_hyperparameters(hparams)
self.loss_fn = nn.NLLLoss()
self.callbacks = [VisualizerCallback()] # test if this is removed
self.callbacks = [VisualizerCallback(task="segmentation")] # test if this is removed

self.image_threshold = AdaptiveThreshold(hparams.model.threshold.image_default).cpu()
self.pixel_threshold = AdaptiveThreshold(hparams.model.threshold.pixel_default).cpu()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(hparams)
self.model = DummyModel()
self.task = "segmentation"
self.callbacks = [VisualizerCallback()] # test if this is removed
self.callbacks = [VisualizerCallback(task=self.task)] # test if this is removed

def test_step(self, batch, _):
"""Only used to trigger on_test_epoch_end."""
Expand Down