-
Notifications
You must be signed in to change notification settings - Fork 639
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Refactor `VisualizerCallback`, fix Wandb bug * `VisualizerCallback` is refactored into a base `VisualizerCallbackBase` class which is used to construct `VisualizerCallbackImage` via inheritance. * Fix bug in `VisualizerCallback.on_test_end`, where `pl_module.logger` was accessed instead of iterating over `trainer.loggers` (wandb errort silently before iirc) * Skeleton for `VisualizerCallbackMetric` is added * Log Figure with built-in SummaryWriter funtcion PyTorch provides a function for logging Figure objects, so lets not reinvent the wheel. https://pytorch.org/docs/stable/tensorboard.html?highlight=add_figure#torch.utils.tensorboard.writer.SummaryWriter.add_figure * Clear `AnomalibWandbLogger.image_list` on save * Explicitly close figure in `ImageGrid.generate` Otherwise, matplotlib starts to complain and objects aren't gced properly * Shift `Visualizer` to `VisualizerCallbackBase` `Visualizer` is needed by all Callbacks for writing images to disk * Add first metric visualization Every metric that should be visualized needs to implement its own `generate_figure` function, and the resulting plot is then saved by `VisualizerCallbackMetric` * Add AUPR metric * Make AUPR accessible * Add AUPRO metric and its vizualiation * Needed to change signature of `AnomalyModule._collect_outputs` for pixel-wise metrics as we need spatial dimensions for connected-componenta-analysis in AUPRO * Add AUPRO, which uses kornia and the fact that per-region overlap == per-region tpr for fast AUPRO computation * Updated docstrings for AUPR/AUROC * Adjust tests, CLI and notebooks Due to bugs in CLI, feature could not be tested well * Directly access `AnomalyModule's` metrics * Bugfix of AUPRO implementation Since unique fpr/tpr curves are generated for each label, we cannot use fpr_index across calls ro `roc`. Instead, we now bilinearly resample generated pro curves at `fpr <= self.fpr_limit` to fixed sampling points, and then aggregate over the resampled curve. * Add types as required by mypy * Improve variable naming/wording in AUPRO We follow a scheme similar to the ROC curve, where the PRO curve is composed of per-region TPR plotted against global FPR. * Move visualizer callbacks to their own module * Rename `VisualizerCallback`s We now prepend the last part, i.e. `VisualizerCallbackBase` -> `BaseVisualizerCallback` * Fix errors in visualizer and in ganomaly config * Revert change in `max_epoch` default for `ganomaly` * Also revert change in default model Co-authored-by: Samet Akcay <samet.akcay@intel.com>
- Loading branch information
1 parent
ddd1b50
commit ba27019
Showing
20 changed files
with
1,584 additions
and
1,073 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
"""Callbacks to visualize anomaly detection performance.""" | ||
from .visualizer_base import BaseVisualizerCallback | ||
from .visualizer_image import ImageVisualizerCallback | ||
from .visualizer_metric import MetricVisualizerCallback | ||
|
||
__all__ = ["BaseVisualizerCallback", "ImageVisualizerCallback", "MetricVisualizerCallback"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
"""Base Visualizer Callback.""" | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
|
||
from pathlib import Path | ||
from typing import Union, cast | ||
|
||
import numpy as np | ||
import pytorch_lightning as pl | ||
from pytorch_lightning import Callback | ||
|
||
from anomalib.models.components import AnomalyModule | ||
from anomalib.post_processing import Visualizer | ||
from anomalib.utils.loggers import AnomalibWandbLogger | ||
from anomalib.utils.loggers.base import ImageLoggerBase | ||
|
||
|
||
class BaseVisualizerCallback(Callback): | ||
"""Callback that visualizes the results of a model. | ||
To save the images to the filesystem, add the 'local' keyword to the `project.log_images_to` parameter in the | ||
config.yaml file. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
task: str, | ||
mode: str, | ||
image_save_path: str, | ||
inputs_are_normalized: bool = True, | ||
show_images: bool = False, | ||
log_images: bool = True, | ||
save_images: bool = True, | ||
): | ||
"""Visualizer callback.""" | ||
if mode not in ["full", "simple"]: | ||
raise ValueError(f"Unknown visualization mode: {mode}. Please choose one of ['full', 'simple']") | ||
self.mode = mode | ||
if task not in ["classification", "segmentation"]: | ||
raise ValueError(f"Unknown task type: {mode}. Please choose one of ['classification', 'segmentation']") | ||
self.task = task | ||
self.inputs_are_normalized = inputs_are_normalized | ||
self.show_images = show_images | ||
self.log_images = log_images | ||
self.save_images = save_images | ||
self.image_save_path = Path(image_save_path) | ||
|
||
self.visualizer = Visualizer(mode, task) | ||
|
||
def _add_to_logger( | ||
self, | ||
image: np.ndarray, | ||
module: AnomalyModule, | ||
trainer: pl.Trainer, | ||
filename: Union[Path, str], | ||
): | ||
"""Log image from a visualizer to each of the available loggers in the project. | ||
Args: | ||
image (np.ndarray): Image that should be added to the loggers. | ||
module (AnomalyModule): Anomaly module. | ||
trainer (Trainer): Pytorch Lightning trainer which holds reference to `logger` | ||
filename (Path): Path of the input image. This name is used as name for the generated image. | ||
""" | ||
# Store names of logger and the logger in a dict | ||
available_loggers = { | ||
type(logger).__name__.lower().rstrip("logger").lstrip("anomalib"): logger for logger in trainer.loggers | ||
} | ||
# save image to respective logger | ||
if self.log_images: | ||
for log_to in available_loggers: | ||
# check if logger object is same as the requested object | ||
if isinstance(available_loggers[log_to], ImageLoggerBase): | ||
logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to]) # placate mypy | ||
if isinstance(filename, Path): | ||
_name = filename.parent.name + "_" + filename.name | ||
elif isinstance(filename, str): | ||
_name = filename | ||
logger.add_image( | ||
image=image, | ||
name=_name, | ||
global_step=module.global_step, | ||
) | ||
|
||
def on_test_end(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None: | ||
"""Sync logs. | ||
Currently only ``AnomalibWandbLogger.save`` is called from this method. | ||
This is because logging as a single batch ensures that all images appear as part of the same step. | ||
Args: | ||
trainer (pl.Trainer): Pytorch Lightning trainer | ||
pl_module (AnomalyModule): Anomaly module (unused) | ||
""" | ||
for logger in trainer.loggers: | ||
if isinstance(logger, AnomalibWandbLogger): | ||
logger.save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
"""Metric Visualizer Callback.""" | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
|
||
from pathlib import Path | ||
|
||
import numpy as np | ||
import pytorch_lightning as pl | ||
from matplotlib import pyplot as plt | ||
from pytorch_lightning.utilities.cli import CALLBACK_REGISTRY | ||
|
||
from anomalib.models.components import AnomalyModule | ||
|
||
from .visualizer_base import BaseVisualizerCallback | ||
|
||
|
||
@CALLBACK_REGISTRY | ||
class MetricVisualizerCallback(BaseVisualizerCallback): | ||
"""Callback that visualizes the metric results of a model by plotting the corresponding curves. | ||
To save the images to the filesystem, add the 'local' keyword to the `project.log_images_to` parameter in the | ||
config.yaml file. | ||
""" | ||
|
||
def on_test_end(self, trainer: pl.Trainer, pl_module: AnomalyModule) -> None: | ||
"""Log images of the metrics contained in pl_module. | ||
In order to also plot custom metrics, they need to have implemented a `generate_figure` function that returns | ||
Tuple[matplotlib.figure.Figure, str]. | ||
Args: | ||
trainer (pl.Trainer): pytorch lightning trainer. | ||
pl_module (AnomalyModule): pytorch lightning module. | ||
""" | ||
|
||
if self.save_images or self.log_images: | ||
for metrics in (pl_module.image_metrics, pl_module.pixel_metrics): | ||
for metric in metrics.values(): | ||
# `generate_figure` needs to be defined for every metric that should be plotted automatically | ||
if hasattr(metric, "generate_figure"): | ||
fig, log_name = metric.generate_figure() | ||
file_name = f"{metrics.prefix}{log_name}" | ||
if self.log_images: | ||
self._add_to_logger(fig, pl_module, trainer, file_name) | ||
|
||
if self.save_images: | ||
fig.canvas.draw() | ||
# convert figure to np.ndarray for saving via visualizer | ||
img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | ||
img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | ||
self.visualizer.save(Path(self.image_save_path.joinpath(f"{file_name}.png")), img) | ||
plt.close(fig) | ||
super().on_test_end(trainer, pl_module) |
Oops, something went wrong.