Skip to content

Commit

Permalink
✨ Add metric visualizations (#429)
Browse files Browse the repository at this point in the history
* Refactor `VisualizerCallback`, fix Wandb bug

* `VisualizerCallback` is refactored into a base `VisualizerCallbackBase`
class which is used to construct `VisualizerCallbackImage` via inheritance.
* Fix bug in `VisualizerCallback.on_test_end`, where `pl_module.logger`
was accessed instead of iterating over `trainer.loggers`
(wandb errort silently before iirc)
* Skeleton for `VisualizerCallbackMetric` is added

* Log Figure with built-in SummaryWriter funtcion

PyTorch provides a function for logging Figure objects, so lets not
reinvent the wheel.

https://pytorch.org/docs/stable/tensorboard.html?highlight=add_figure#torch.utils.tensorboard.writer.SummaryWriter.add_figure

* Clear `AnomalibWandbLogger.image_list` on save

* Explicitly close figure in `ImageGrid.generate`

Otherwise, matplotlib starts to complain and objects aren't gced
properly

* Shift `Visualizer` to `VisualizerCallbackBase`

`Visualizer` is needed by all Callbacks for writing images to disk

* Add first metric visualization

Every metric that should be visualized needs to implement its own
`generate_figure` function, and the resulting plot is then saved by
`VisualizerCallbackMetric`

* Add AUPR metric

* Make AUPR accessible

* Add AUPRO metric and its vizualiation

* Needed to change signature of `AnomalyModule._collect_outputs` for
pixel-wise metrics as we need spatial dimensions for
connected-componenta-analysis in AUPRO
* Add AUPRO, which uses kornia and the fact that
per-region overlap == per-region tpr for fast AUPRO computation
* Updated docstrings for AUPR/AUROC

* Adjust tests, CLI and notebooks

Due to bugs in CLI, feature could not be tested well

* Directly access `AnomalyModule's` metrics

* Bugfix of AUPRO implementation

Since unique fpr/tpr curves are generated for each label,
we cannot use fpr_index across calls ro `roc`. Instead, we now
bilinearly resample generated pro curves at `fpr <= self.fpr_limit`
to fixed sampling points, and then aggregate over the resampled curve.

* Add types as required by mypy

* Improve variable naming/wording in AUPRO

We follow a scheme similar to the ROC curve, where the PRO curve is
composed of per-region TPR plotted against global FPR.

* Move visualizer callbacks to their own module

* Rename `VisualizerCallback`s

We now prepend the last part, i.e. `VisualizerCallbackBase` -> `BaseVisualizerCallback`

* Fix errors in visualizer and in ganomaly config

* Revert change in `max_epoch` default for `ganomaly`

* Also revert change in default model

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
ORippler and samet-akcay committed Jul 13, 2022
1 parent ddd1b50 commit ba27019
Show file tree
Hide file tree
Showing 20 changed files with 1,584 additions and 1,073 deletions.
2 changes: 1 addition & 1 deletion anomalib/models/components/base/anomaly_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _collect_outputs(self, image_metric, pixel_metric, outputs):
image_metric.update(output["pred_scores"], output["label"].int())
if "mask" in output.keys() and "anomaly_maps" in output.keys():
pixel_metric.cpu()
pixel_metric.update(output["anomaly_maps"].flatten(), output["mask"].flatten().int())
pixel_metric.update(output["anomaly_maps"], output["mask"].int())

def _post_process(self, outputs):
"""Compute labels based on model predictions."""
Expand Down
3 changes: 0 additions & 3 deletions anomalib/models/ganomaly/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ metrics:
image:
- F1Score
- AUROC
pixel:
- F1Score
- AUROC
threshold:
image_default: 0
adaptive: true
Expand Down
10 changes: 6 additions & 4 deletions anomalib/post_processing/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class ImageResult:

def __post_init__(self):
"""Generate heatmap overlay and segmentations, convert masks to images."""
self.heat_map = superimpose_anomaly_map(self.anomaly_map, self.image, normalize=False)
if self.anomaly_map is not None:
self.heat_map = superimpose_anomaly_map(self.anomaly_map, self.image, normalize=False)
if self.pred_mask is not None and np.max(self.pred_mask) <= 1.0:
self.pred_mask *= 255
self.segmentations = mark_boundaries(self.image, self.pred_mask, color=(1, 0, 0), mode="thick")
Expand Down Expand Up @@ -86,7 +87,7 @@ def visualize_batch(self, batch: Dict) -> Iterator[np.ndarray]:
image=Denormalize()(batch["image"][i].cpu()),
pred_score=batch["pred_scores"][i].cpu().numpy().item(),
pred_label=batch["pred_labels"][i].cpu().numpy().item(),
anomaly_map=batch["anomaly_maps"][i].cpu().numpy(),
anomaly_map=batch["anomaly_maps"][i].cpu().numpy() if "anomaly_maps" in batch else None,
pred_mask=batch["pred_masks"][i].squeeze().int().cpu().numpy() if "pred_masks" in batch else None,
gt_mask=batch["mask"][i].squeeze().int().cpu().numpy() if "mask" in batch else None,
)
Expand Down Expand Up @@ -132,9 +133,9 @@ def _visualize_full(self, image_result: ImageResult):
elif self.task == "classification":
visualization.add_image(image_result.image, title="Image")
if image_result.pred_label:
image_classified = add_anomalous_label(image_result.heat_map, image_result.pred_score)
image_classified = add_anomalous_label(image_result.image, image_result.pred_score)
else:
image_classified = add_normal_label(image_result.heat_map, 1 - image_result.pred_score)
image_classified = add_normal_label(image_result.image, 1 - image_result.pred_score)
visualization.add_image(image=image_classified, title="Prediction")

return visualization.generate()
Expand Down Expand Up @@ -235,4 +236,5 @@ def generate(self) -> np.ndarray:
# convert canvas to numpy array to prepare for visualization with opencv
img = np.frombuffer(self.figure.canvas.tostring_rgb(), dtype=np.uint8)
img = img.reshape(self.figure.canvas.get_width_height()[::-1] + (3,))
plt.close(self.figure)
return img
26 changes: 14 additions & 12 deletions anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from .model_loader import LoadModelCallback
from .tiler_configuration import TilerConfigurationCallback
from .timer import TimerCallback
from .visualizer_callback import VisualizerCallback
from .visualizer import ImageVisualizerCallback, MetricVisualizerCallback

__all__ = [
"CdfNormalizationCallback",
Expand All @@ -40,7 +40,8 @@
"MinMaxNormalizationCallback",
"TilerConfigurationCallback",
"TimerCallback",
"VisualizerCallback",
"ImageVisualizerCallback",
"MetricVisualizerCallback",
]


Expand Down Expand Up @@ -174,14 +175,15 @@ def add_visualizer_callback(callbacks: List[Callback], config: Union[DictConfig,
if config.visualization.image_save_path
else config.project.path + "/images"
)
callbacks.append(
VisualizerCallback(
task=config.dataset.task,
mode=config.visualization.mode,
image_save_path=image_save_path,
inputs_are_normalized=not config.model.normalization_method == "none",
show_images=config.visualization.show_images,
log_images=config.visualization.log_images,
save_images=config.visualization.save_images,
for callback in (ImageVisualizerCallback, MetricVisualizerCallback):
callbacks.append(
callback(
task=config.dataset.task,
mode=config.visualization.mode,
image_save_path=image_save_path,
inputs_are_normalized=not config.model.normalization_method == "none",
show_images=config.visualization.show_images,
log_images=config.visualization.log_images,
save_images=config.visualization.save_images,
)
)
)
6 changes: 6 additions & 0 deletions anomalib/utils/callbacks/visualizer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Callbacks to visualize anomaly detection performance."""
from .visualizer_base import BaseVisualizerCallback
from .visualizer_image import ImageVisualizerCallback
from .visualizer_metric import MetricVisualizerCallback

__all__ = ["BaseVisualizerCallback", "ImageVisualizerCallback", "MetricVisualizerCallback"]
109 changes: 109 additions & 0 deletions anomalib/utils/callbacks/visualizer/visualizer_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Base Visualizer Callback."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from pathlib import Path
from typing import Union, cast

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import Callback

from anomalib.models.components import AnomalyModule
from anomalib.post_processing import Visualizer
from anomalib.utils.loggers import AnomalibWandbLogger
from anomalib.utils.loggers.base import ImageLoggerBase


class BaseVisualizerCallback(Callback):
"""Callback that visualizes the results of a model.
To save the images to the filesystem, add the 'local' keyword to the `project.log_images_to` parameter in the
config.yaml file.
"""

def __init__(
self,
task: str,
mode: str,
image_save_path: str,
inputs_are_normalized: bool = True,
show_images: bool = False,
log_images: bool = True,
save_images: bool = True,
):
"""Visualizer callback."""
if mode not in ["full", "simple"]:
raise ValueError(f"Unknown visualization mode: {mode}. Please choose one of ['full', 'simple']")
self.mode = mode
if task not in ["classification", "segmentation"]:
raise ValueError(f"Unknown task type: {mode}. Please choose one of ['classification', 'segmentation']")
self.task = task
self.inputs_are_normalized = inputs_are_normalized
self.show_images = show_images
self.log_images = log_images
self.save_images = save_images
self.image_save_path = Path(image_save_path)

self.visualizer = Visualizer(mode, task)

def _add_to_logger(
self,
image: np.ndarray,
module: AnomalyModule,
trainer: pl.Trainer,
filename: Union[Path, str],
):
"""Log image from a visualizer to each of the available loggers in the project.
Args:
image (np.ndarray): Image that should be added to the loggers.
module (AnomalyModule): Anomaly module.
trainer (Trainer): Pytorch Lightning trainer which holds reference to `logger`
filename (Path): Path of the input image. This name is used as name for the generated image.
"""
# Store names of logger and the logger in a dict
available_loggers = {
type(logger).__name__.lower().rstrip("logger").lstrip("anomalib"): logger for logger in trainer.loggers
}
# save image to respective logger
if self.log_images:
for log_to in available_loggers:
# check if logger object is same as the requested object
if isinstance(available_loggers[log_to], ImageLoggerBase):
logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to]) # placate mypy
if isinstance(filename, Path):
_name = filename.parent.name + "_" + filename.name
elif isinstance(filename, str):
_name = filename
logger.add_image(
image=image,
name=_name,
global_step=module.global_step,
)

def on_test_end(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Sync logs.
Currently only ``AnomalibWandbLogger.save`` is called from this method.
This is because logging as a single batch ensures that all images appear as part of the same step.
Args:
trainer (pl.Trainer): Pytorch Lightning trainer
pl_module (AnomalyModule): Anomaly module (unused)
"""
for logger in trainer.loggers:
if isinstance(logger, AnomalibWandbLogger):
logger.save()
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Visualizer Callback."""
"""Image Visualizer Callback."""

# Copyright (C) 2020 Intel Corporation
#
Expand All @@ -15,22 +15,19 @@
# and limitations under the License.

from pathlib import Path
from typing import Any, Optional, cast
from typing import Any, Optional

import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY
from pytorch_lightning.utilities.types import STEP_OUTPUT

from anomalib.models.components import AnomalyModule
from anomalib.post_processing import Visualizer
from anomalib.utils.loggers import AnomalibWandbLogger
from anomalib.utils.loggers.base import ImageLoggerBase

from .visualizer_base import BaseVisualizerCallback


@CALLBACK_REGISTRY
class VisualizerCallback(Callback):
class ImageVisualizerCallback(BaseVisualizerCallback):
"""Callback that visualizes the inference results of a model.
The callback generates a figure showing the original image, the ground truth segmentation mask,
Expand All @@ -40,62 +37,6 @@ class VisualizerCallback(Callback):
config.yaml file.
"""

def __init__(
self,
task: str,
mode: str,
image_save_path: str,
inputs_are_normalized: bool = True,
show_images: bool = False,
log_images: bool = True,
save_images: bool = True,
):
"""Visualizer callback."""
if mode not in ["full", "simple"]:
raise ValueError(f"Unknown visualization mode: {mode}. Please choose one of ['full', 'simple']")
self.mode = mode
if task not in ["classification", "segmentation"]:
raise ValueError(f"Unknown task type: {mode}. Please choose one of ['classification', 'segmentation']")
self.task = task
self.inputs_are_normalized = inputs_are_normalized
self.show_images = show_images
self.log_images = log_images
self.save_images = save_images
self.image_save_path = Path(image_save_path)

self.visualizer = Visualizer(mode, task)

def _add_to_logger(
self,
image: np.ndarray,
module: AnomalyModule,
trainer: pl.Trainer,
filename: Path,
):
"""Log image from a visualizer to each of the available loggers in the project.
Args:
image (np.ndarray): Image that should be added to the loggers.
module (AnomalyModule): Anomaly module.
trainer (Trainer): Pytorch Lightning trainer which holds reference to `logger`
filename (Path): Path of the input image. This name is used as name for the generated image.
"""
# Store names of logger and the logger in a dict
available_loggers = {
type(logger).__name__.lower().rstrip("logger").lstrip("anomalib"): logger for logger in trainer.loggers
}
# save image to respective logger
if self.log_images:
for log_to in available_loggers:
# check if logger object is same as the requested object
if isinstance(available_loggers[log_to], ImageLoggerBase):
logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to]) # placate mypy
logger.add_image(
image=image,
name=filename.parent.name + "_" + filename.name,
global_step=module.global_step,
)

def on_predict_batch_end(
self,
_trainer: pl.Trainer,
Expand Down Expand Up @@ -155,16 +96,3 @@ def on_test_batch_end(
self._add_to_logger(image, pl_module, trainer, filename)
if self.show_images:
self.visualizer.show(str(filename), image)

def on_test_end(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Sync logs.
Currently only ``AnomalibWandbLogger`` is called from this method. This is because logging as a single batch
ensures that all images appear as part of the same step.
Args:
_trainer (pl.Trainer): Pytorch Lightning trainer (unused)
pl_module (AnomalyModule): Anomaly module
"""
if pl_module.logger is not None and isinstance(pl_module.logger, AnomalibWandbLogger):
pl_module.logger.save()
65 changes: 65 additions & 0 deletions anomalib/utils/callbacks/visualizer/visualizer_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""Metric Visualizer Callback."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from pathlib import Path

import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY

from anomalib.models.components import AnomalyModule

from .visualizer_base import BaseVisualizerCallback


@CALLBACK_REGISTRY
class MetricVisualizerCallback(BaseVisualizerCallback):
"""Callback that visualizes the metric results of a model by plotting the corresponding curves.
To save the images to the filesystem, add the 'local' keyword to the `project.log_images_to` parameter in the
config.yaml file.
"""

def on_test_end(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
"""Log images of the metrics contained in pl_module.
In order to also plot custom metrics, they need to have implemented a `generate_figure` function that returns
Tuple[matplotlib.figure.Figure, str].
Args:
trainer (pl.Trainer): pytorch lightning trainer.
pl_module (AnomalyModule): pytorch lightning module.
"""

if self.save_images or self.log_images:
for metrics in (pl_module.image_metrics, pl_module.pixel_metrics):
for metric in metrics.values():
# `generate_figure` needs to be defined for every metric that should be plotted automatically
if hasattr(metric, "generate_figure"):
fig, log_name = metric.generate_figure()
file_name = f"{metrics.prefix}{log_name}"
if self.log_images:
self._add_to_logger(fig, pl_module, trainer, file_name)

if self.save_images:
fig.canvas.draw()
# convert figure to np.ndarray for saving via visualizer
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,))
self.visualizer.save(Path(self.image_save_path.joinpath(f"{file_name}.png")), img)
plt.close(fig)
super().on_test_end(trainer, pl_module)
Loading

0 comments on commit ba27019

Please sign in to comment.