Skip to content

Commit

Permalink
📏 Add PRO metric (#508)
Browse files Browse the repository at this point in the history
* remove cpu warning

* add pro metric test

* remove pylint ignore statements

* fix component labeling bug

* add more tests for pro metric and ccomp labeling

* use kornia for ccomp labeling

* fix aupro tests

* address PR comments
  • Loading branch information
djdameln committed Aug 31, 2022
1 parent a03e592 commit bd36919
Show file tree
Hide file tree
Showing 5 changed files with 207 additions and 13 deletions.
3 changes: 2 additions & 1 deletion anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from .collection import AnomalibMetricCollection
from .min_max import MinMax
from .optimal_f1 import OptimalF1
from .pro import PRO

__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]
__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
Expand Down
13 changes: 9 additions & 4 deletions anomalib/utils/metrics/aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from typing import Any, Callable, List, Optional, Tuple

import torch
from kornia.contrib import connected_components
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import auc, roc
from torchmetrics.utilities.data import dim_zero_cat

from anomalib.utils.metrics.pro import (
connected_components_cpu,
connected_components_gpu,
)

from .plotting_utils import plot_figure


Expand Down Expand Up @@ -80,9 +84,10 @@ def _compute(self) -> Tuple[Tensor, Tensor]:
)
target = target.unsqueeze(1) # kornia expects N1HW format
target = target.type(torch.float) # kornia expects FloatTensor
cca = connected_components(
target, num_iterations=1000
) # Need higher thresholds this to avoid oversegmentation.
if target.is_cuda:
cca = connected_components_gpu(target)
else:
cca = connected_components_cpu(target)

preds = preds.flatten()
cca = cca.flatten()
Expand Down
112 changes: 112 additions & 0 deletions anomalib/utils/metrics/pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Implementation of PRO metric based on TorchMetrics."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import List

import cv2
import numpy as np
import torch
from kornia.contrib import connected_components
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import recall
from torchmetrics.utilities.data import dim_zero_cat


class PRO(Metric):
"""Per-Region Overlap (PRO) Score."""

target: List[Tensor]
preds: List[Tensor]

def __init__(self, threshold: float = 0.5, **kwargs) -> None:
super().__init__(**kwargs)
self.threshold = threshold

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, predictions: Tensor, targets: Tensor) -> None:
"""Compute the PRO score for the current batch."""

self.target.append(targets)
self.preds.append(predictions)

def compute(self) -> Tensor:
"""Compute the macro average of the PRO score across all regions in all batches."""
target = dim_zero_cat(self.target)
preds = dim_zero_cat(self.preds)

if target.is_cuda:
comps = connected_components_gpu(target.unsqueeze(1))
else:
comps = connected_components_cpu(target.unsqueeze(1))
pro = pro_score(preds, comps, threshold=self.threshold)
return pro


def pro_score(predictions: Tensor, comps: Tensor, threshold: float = 0.5) -> Tensor:
"""Calculate the PRO score for a batch of predictions.
Args:
predictions (Tensor): Predicted anomaly masks (Bx1xHxW)
comps: (Tensor): Labeled connected components (BxHxW). The components should be labeled from 0 to N
threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions.
Returns:
Tensor: Scalar value representing the average PRO score for the input batch.
"""
if predictions.dtype == torch.float:
predictions = predictions > threshold

n_comps = len(comps.unique())

preds = comps.clone()
preds[~predictions] = 0
if n_comps == 1: # only background
return torch.Tensor([1.0])
pro = recall(preds.flatten(), comps.flatten(), num_classes=n_comps, average="macro", ignore_index=0)
return pro


def connected_components_gpu(binary_input: Tensor, num_iterations: int = 1000) -> Tensor:
"""Perform connected component labeling on GPU and remap the labels from 0 to N.
Args:
binary_input (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
num_iterations (int): Number of iterations used in the connected component computation.
Returns:
Tensor: Components labeled from 0 to N.
"""
components = connected_components(binary_input, num_iterations=num_iterations)

# remap component values from 0 to N
labels = components.unique()
for new_label, old_label in enumerate(labels):
components[components == old_label] = new_label

return components.int()


def connected_components_cpu(image: Tensor) -> Tensor:
"""Connected component labeling on CPU.
Args:
image (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
Returns:
Tensor: Components labeled from 0 to N.
"""
components = torch.zeros_like(image)
label_idx = 1
for i, mask in enumerate(image):
mask = mask.squeeze().numpy().astype(np.uint8)
_, comps = cv2.connectedComponents(mask)
# remap component values to make sure every component has a unique value when outputs are concatenated
for label in np.unique(comps)[1:]:
components[i, 0, ...][np.where(comps == label)] = label_idx
label_idx += 1
return components.int()
12 changes: 4 additions & 8 deletions tests/pre_merge/utils/metrics/test_aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,17 @@ def pytest_generate_tests(metafunc):
torch.tensor(
[
[
[
[0, 0, 0, 1, 0, 0, 0],
]
* 400,
[0, 0, 0, 1, 0, 0, 0],
]
* 400,
]
),
torch.tensor(
[
[
[
[0, 1, 0, 1, 0, 1, 0],
]
* 400,
[0, 1, 0, 1, 0, 1, 0],
]
* 400,
]
),
]
Expand Down
80 changes: 80 additions & 0 deletions tests/pre_merge/utils/metrics/test_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from torch import Tensor
from torchvision.transforms import RandomAffine

from anomalib.data.utils import random_2d_perlin
from anomalib.utils.metrics.pro import (
PRO,
connected_components_cpu,
connected_components_gpu,
)


def test_pro():
"""Checks if PRO metric computes the (macro) average of the per-region overlap."""

labels = Tensor(
[
[
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
]
]
)

preds = (torch.arange(10) / 10) + 0.05
preds = preds.unsqueeze(1).repeat(1, 5).view(1, 1, 10, 5)

thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
targets = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
for threshold, target in zip(thresholds, targets):
pro = PRO(threshold=threshold)
pro.update(preds, labels)
assert pro.compute() == target


def test_device_consistency():
"""Test if the pro metric yields the same results between cpu and gpu."""

transform = RandomAffine(5, None, (0.95, 1.05), 5)

batch = torch.zeros((32, 256, 256))
for i in range(batch.shape[0]):
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5

preds = transform(batch).unsqueeze(1)

pro_cpu = PRO()
pro_gpu = PRO()

pro_cpu.update(preds.cpu(), batch.cpu())
pro_gpu.update(preds.cuda(), batch.cuda())

assert torch.isclose(pro_cpu.compute(), pro_gpu.compute().cpu())


def test_connected_component_labeling():
"""Tests if the connected component labeling algorithms on cpu and gpu yield the same result."""

# generate batch of random binary images using perlin noise
batch = torch.zeros((32, 1, 256, 256))
for i in range(batch.shape[0]):
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5

# get connected component results on both cpu and gpu
cc_cpu = connected_components_cpu(batch.cpu())
cc_gpu = connected_components_gpu(batch.cuda())

# check if comps are ordered from 0 to N
assert len(cc_cpu.unique()) == cc_cpu.unique().max() + 1
assert len(cc_gpu.unique()) == cc_gpu.unique().max() + 1
# check if same number of comps found between cpu and gpu
assert len(cc_cpu.unique()) == len(cc_gpu.unique())

0 comments on commit bd36919

Please sign in to comment.