TorchMetrics #7

Merged: 21 commits, merged on Dec 2, 2021
Changes from 16 commits
50 changes: 27 additions & 23 deletions anomalib/core/callbacks/visualizer_callback.py
@@ -1,16 +1,16 @@
 """Visualizer Callback."""
 from pathlib import Path
+from typing import Any, Optional
 from warnings import warn

-from pytorch_lightning import Callback, LightningModule, Trainer
+import pytorch_lightning as pl
+from pytorch_lightning import Callback
+from pytorch_lightning.utilities.types import STEP_OUTPUT
 from skimage.segmentation import mark_boundaries
-from tqdm import tqdm

 from anomalib import loggers
 from anomalib.core.model import AnomalyModule
-from anomalib.core.results import SegmentationResults
 from anomalib.datasets.utils import Denormalize
-from anomalib.utils.metrics import compute_threshold_and_f1_score
 from anomalib.utils.post_process import compute_mask, superimpose_anomaly_map
 from anomalib.utils.visualizer import Visualizer

@@ -57,33 +57,37 @@ def _add_images(
         if "local" in module.hparams.project.log_images_to:
             visualizer.save(Path(module.hparams.project.path) / "images" / filename.parent.name / filename.name)

-    def on_test_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None:
-        """Log images at the end of training.
+    def on_test_batch_end(
+        self,
+        _trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: Optional[STEP_OUTPUT],
+        _batch: Any,
+        _batch_idx: int,
+        _dataloader_idx: int,
+    ) -> None:
+        """Log images at the end of every batch.

         Args:
-            _trainer (Trainer): Pytorch lightning trainer object (unused)
+            _trainer (Trainer): Pytorch lightning trainer object (unused).
             pl_module (LightningModule): Lightning modules derived from BaseAnomalyLightning object as
                 currently only they support logging images.
+            outputs (Dict[str, Any]): Outputs of the current test step.
+            _batch (Any): Input batch of the current test step (unused).
+            _batch_idx (int): Index of the current test batch (unused).
+            _dataloader_idx (int): Index of the dataloader that yielded the current batch (unused).
         """
-        if isinstance(pl_module.results, SegmentationResults):
-            results = pl_module.results
-        else:
-            raise ValueError("Visualizer callback only supported for segmentation tasks.")
-
-        if results.images is None or results.true_masks is None or results.anomaly_maps is None:
-            raise ValueError("Result set cannot be empty!")
-
-        threshold, _ = compute_threshold_and_f1_score(results.true_masks, results.anomaly_maps)
+        assert outputs is not None

-        for (filename, image, true_mask, anomaly_map) in tqdm(
-            zip(results.filenames, results.images, results.true_masks, results.anomaly_maps),
-            desc="Saving Results",
-            total=len(results.filenames),
+        for (filename, image, true_mask, anomaly_map) in zip(
+            outputs["image_path"], outputs["image"], outputs["mask"], outputs["anomaly_maps"]
         ):
-            image = Denormalize()(image)
+            image = Denormalize()(image.cpu())
             true_mask = true_mask.cpu().numpy()
             anomaly_map = anomaly_map.cpu().numpy()

             heat_map = superimpose_anomaly_map(anomaly_map, image)
-            pred_mask = compute_mask(anomaly_map, threshold)
+            pred_mask = compute_mask(anomaly_map, pl_module.threshold.item())
             vis_img = mark_boundaries(image, pred_mask, color=(1, 0, 0), mode="thick")

             visualizer = Visualizer(num_rows=1, num_cols=5, figure_size=(12, 3))
@@ -92,5 +96,5 @@ def on_test_epoch_end(self, _trainer: Trainer, pl_module: LightningModule) -> None:
             visualizer.add_image(image=heat_map, title="Predicted Heat Map")
             visualizer.add_image(image=pred_mask, color_map="gray", title="Predicted Mask")
             visualizer.add_image(image=vis_img, title="Segmentation Result")
-            self._add_images(visualizer, pl_module, filename)
+            self._add_images(visualizer, pl_module, Path(filename))
             visualizer.close()
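
Note: the callback now reads everything it needs from the per-batch step outputs instead of a result set stored on the module. A minimal sketch of a test step whose return value satisfies the keys indexed above ("image_path", "image", "mask", "anomaly_maps"); the model attribute and batch layout are illustrative assumptions, not code from this PR:

def test_step(self, batch, _batch_idx):
    # Hypothetical test step; only the output keys matter for the callback.
    outputs = {
        "image_path": batch["image_path"],  # list of image file paths
        "image": batch["image"],            # input images, shape (B, C, H, W)
        "mask": batch["mask"],              # ground-truth masks, shape (B, H, W)
    }
    outputs["anomaly_maps"] = self.model(batch["image"])  # predicted anomaly maps
    return outputs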
5 changes: 5 additions & 0 deletions anomalib/core/metrics/__init__.py
@@ -0,0 +1,5 @@
"""Custom anomaly evaluation metrics."""
from .auroc import AUROC
from .optimal_f1 import OptimalF1

__all__ = ["AUROC", "OptimalF1"]
17 changes: 17 additions & 0 deletions anomalib/core/metrics/auroc.py
@@ -0,0 +1,17 @@
"""Implementation of AUROC metric based on TorchMetrics."""
from torch import Tensor
from torchmetrics import ROC
from torchmetrics.functional import auc


class AUROC(ROC):
"""Area under the ROC curve."""

def compute(self) -> Tensor:
"""First compute ROC curve, then compute area under the curve.

Returns:
Value of the AUROC metric
"""
fpr, tpr, _thresholds = super().compute()
return auc(fpr, tpr)
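
Since AUROC subclasses torchmetrics.ROC, it accumulates raw scores and integer labels via update() and only aggregates at compute() time. A hedged toy example (synthetic tensors, not from the PR):

import torch

metric = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
preds = torch.tensor([0.1, 0.4, 0.35, 0.8])  # anomaly scores
target = torch.tensor([0, 0, 1, 1])          # binary ground-truth labels
metric.update(preds, target)
print(metric.compute())  # tensor(0.7500): 3 of 4 positive/negative pairs ranked correctly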
42 changes: 42 additions & 0 deletions anomalib/core/metrics/optimal_f1.py
@@ -0,0 +1,42 @@
"""Implementation of Optimal F1 score based on TorchMetrics."""
import torch
from torchmetrics import Metric, PrecisionRecallCurve


class OptimalF1(Metric):
"""Optimal F1 Metric.

Compute the optimal F1 score at the adaptive threshold, based on the F1 metric of the true labels and the
predicted anomaly scores.
"""

def __init__(self, num_classes: int, **kwargs):
super().__init__(**kwargs)

self.precision_recall_curve = PrecisionRecallCurve(num_classes=num_classes, compute_on_step=False)

self.threshold: torch.Tensor

# pylint: disable=arguments-differ
def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: # type: ignore
"""Update the precision-recall curve metric."""
self.precision_recall_curve.update(preds, target)

def compute(self) -> torch.Tensor:
"""Compute the value of the optimal F1 score.

Compute the F1 scores while varying the threshold. Store the optimal
threshold as attribute and return the maximum value of the F1 score.

Returns:
Value of the F1 score at the optimal threshold.
"""
precision: torch.Tensor
recall: torch.Tensor
thresholds: torch.Tensor

precision, recall, thresholds = self.precision_recall_curve.compute()
f1_score = (2 * precision * recall) / (precision + recall + 1e-10)
self.threshold = thresholds[torch.argmax(f1_score)]
optimal_f1_score = torch.max(f1_score)
return optimal_f1_score
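
A hedged usage sketch (toy tensors, not from the PR): after compute(), the metric exposes the threshold that maximised F1, which is what AnomalyModule reads back to set its adaptive threshold.

import torch

metric = OptimalF1(num_classes=1)
metric.update(torch.tensor([0.2, 0.7, 0.6, 0.9]), torch.tensor([0, 0, 1, 1]))
best_f1 = metric.compute()         # maximum F1 over all candidate thresholds
print(best_f1, metric.threshold)   # metric.threshold is the score cut-off that achieved it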
50 changes: 27 additions & 23 deletions anomalib/core/model/anomaly_module.py
@@ -21,9 +21,9 @@
 from omegaconf import DictConfig, ListConfig
 from pytorch_lightning.callbacks.base import Callback
 from torch import nn
+from torchmetrics import F1, MetricCollection

-from anomalib.core.results import ClassificationResults, SegmentationResults
-from anomalib.utils.metrics import compute_threshold_and_f1_score
+from anomalib.core.metrics import AUROC, OptimalF1


 class AnomalyModule(pl.LightningModule):
@@ -41,18 +41,21 @@ def __init__(self, params: Union[DictConfig, ListConfig]):
         self.save_hyperparameters(params)
         self.loss: torch.Tensor
         self.callbacks: List[Callback]
-        self.register_buffer("threshold", torch.Tensor([params.model.threshold.default]))
+        self.register_buffer("threshold", torch.tensor(params.model.threshold.default))  # pylint: disable=not-callable
+        self.threshold: torch.Tensor

         self.model: nn.Module

-        self.results: Union[ClassificationResults, SegmentationResults]
-        if params.dataset.task == "classification":
-            self.results = ClassificationResults()
-        elif params.dataset.task == "segmentation":
-            self.results = SegmentationResults()
+        # metrics
+        self.image_metrics = MetricCollection(
+            [AUROC(num_classes=1, pos_label=1, compute_on_step=False)], prefix="image_"
+        )
+        if params.model.threshold.adaptive:
+            self.image_metrics.add_metrics([OptimalF1(num_classes=1)])
         else:
-            raise NotImplementedError("Only Classification and Segmentation tasks are supported in this version.")
+            self.image_metrics.add_metrics([F1(num_classes=1, compute_on_step=False, threshold=self.threshold.item())])
+        if self.hparams.dataset.task == "segmentation":
+            self.pixel_metrics = self.image_metrics.clone(prefix="pixel_")

     def forward(self, batch):  # pylint: disable=arguments-differ
         """Forward-pass input tensor to the module.
@@ -96,33 +99,33 @@ def test_step(self, batch, _):  # pylint: disable=arguments-differ

     def validation_step_end(self, val_step_outputs):  # pylint: disable=arguments-differ
         """Called at the end of each validation step."""
-        return self._post_process(val_step_outputs)
+        val_step_outputs = self._post_process(val_step_outputs)
+        self.image_metrics(val_step_outputs["pred_scores"], val_step_outputs["label"].int())
+        if self.hparams.dataset.task == "segmentation":
+            self.pixel_metrics(val_step_outputs["anomaly_maps"].flatten(), val_step_outputs["mask"].flatten().int())
+        return val_step_outputs

     def test_step_end(self, test_step_outputs):  # pylint: disable=arguments-differ
-        """Called at the end of each validation step."""
-        return self._post_process(test_step_outputs)
+        """Called at the end of each test step."""
+        return self.validation_step_end(test_step_outputs)

-    def validation_epoch_end(self, outputs):
-        """Compute image-level performance metrics.
+    def validation_epoch_end(self, _outputs):
+        """Compute threshold and performance metrics.

         Args:
             outputs: Batch of outputs from the validation step
         """
-        self.results.store_outputs(outputs)
         if self.hparams.model.threshold.adaptive:
-            threshold, _ = compute_threshold_and_f1_score(self.results.true_labels, self.results.pred_scores)
-            self.threshold = torch.Tensor([threshold])
-        self.results.evaluate(self.threshold.item())
+            self.image_metrics.compute()
+            self.threshold = self.image_metrics.OptimalF1.threshold
         self._log_metrics()

-    def test_epoch_end(self, outputs):
+    def test_epoch_end(self, _outputs):
         """Compute and save anomaly scores of the test set.

         Args:
             outputs: Batch of outputs from the validation step
         """
-        self.results.store_outputs(outputs)
-        self.results.evaluate(self.threshold.item())
         self._log_metrics()

@@ -137,5 +140,6 @@ def _post_process(self, outputs, predict_labels=False):

     def _log_metrics(self):
         """Log computed performance metrics."""
-        for name, value in self.results.performance.items():
-            self.log(name=name, value=value, on_epoch=True, prog_bar=True)
+        self.log_dict(self.image_metrics)
+        if self.hparams.dataset.task == "segmentation":
+            self.log_dict(self.pixel_metrics)
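
The wiring above leans on stock TorchMetrics behaviour: calling a MetricCollection forwards to update() on every contained metric, clone(prefix=...) yields an independent copy with renamed keys, and handing the collection to log_dict lets Lightning drive compute/reset at epoch boundaries. A standalone sketch of the same pattern under those assumptions (toy tensors, outside any LightningModule):

import torch
from torchmetrics import MetricCollection

from anomalib.core.metrics import AUROC, OptimalF1

image_metrics = MetricCollection(
    [AUROC(num_classes=1, pos_label=1, compute_on_step=False)], prefix="image_"
)
image_metrics.add_metrics([OptimalF1(num_classes=1)])
pixel_metrics = image_metrics.clone(prefix="pixel_")  # independent copy for pixel-level scores

scores = torch.tensor([0.1, 0.8, 0.6, 0.9])
labels = torch.tensor([0, 1, 0, 1])
image_metrics(scores, labels)   # one call updates AUROC and OptimalF1 together
print(image_metrics.compute())  # e.g. {'image_AUROC': ..., 'image_OptimalF1': ...}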
2 changes: 1 addition & 1 deletion anomalib/core/model/kde.py
@@ -49,7 +49,7 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
         """
         features = torch.matmul(features, self.bw_transform)

-        estimate = torch.zeros(features.shape[0])
+        estimate = torch.zeros(features.shape[0]).to(features.device)
         for i in range(features.shape[0]):
             embedding = ((self.dataset - features[i]) ** 2).sum(dim=1)
             embedding = torch.exp(-embedding / 2) * self.norm
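
The kde.py change is a device-placement fix: torch.zeros allocates on the CPU by default, so writing into the accumulator fails once features lives on the GPU. A minimal illustration of the pitfall and the fix, independent of anomalib:

import torch

def row_sums(features: torch.Tensor) -> torch.Tensor:
    # Allocating on features.device keeps the loop valid on CPU and GPU alike;
    # a bare torch.zeros(n) would live on the CPU and break for CUDA inputs.
    out = torch.zeros(features.shape[0], device=features.device)
    for i in range(features.shape[0]):
        out[i] = features[i].sum()
    return out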
19 changes: 0 additions & 19 deletions anomalib/core/results/__init__.py

This file was deleted.

110 changes: 0 additions & 110 deletions anomalib/core/results/results.py

This file was deleted.
