Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PRO metric #508

Merged
merged 9 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion anomalib/utils/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
from .collection import AnomalibMetricCollection
from .min_max import MinMax
from .optimal_f1 import OptimalF1
from .pro import PRO

__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]
__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]


def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:
Expand Down
13 changes: 9 additions & 4 deletions anomalib/utils/metrics/aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from typing import Any, Callable, List, Optional, Tuple

import torch
from kornia.contrib import connected_components
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import auc, roc
from torchmetrics.utilities.data import dim_zero_cat

from anomalib.utils.metrics.pro import (
connected_components_cpu,
connected_components_gpu,
)

from .plotting_utils import plot_figure


Expand Down Expand Up @@ -80,9 +84,10 @@ def _compute(self) -> Tuple[Tensor, Tensor]:
)
target = target.unsqueeze(1) # kornia expects N1HW format
target = target.type(torch.float) # kornia expects FloatTensor
cca = connected_components(
target, num_iterations=1000
) # Need higher thresholds this to avoid oversegmentation.
if target.is_cuda:
cca = connected_components_gpu(target)
else:
cca = connected_components_cpu(target)

preds = preds.flatten()
cca = cca.flatten()
Expand Down
112 changes: 112 additions & 0 deletions anomalib/utils/metrics/pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Implementation of PRO metric based on TorchMetrics."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import List
djdameln marked this conversation as resolved.
Show resolved Hide resolved

import cv2
import numpy as np
import torch
from kornia.contrib import connected_components
from torch import Tensor
from torchmetrics import Metric
from torchmetrics.functional import recall
from torchmetrics.utilities.data import dim_zero_cat


class PRO(Metric):
"""Per-Region Overlap (PRO) Score."""

target: List[Tensor]
preds: List[Tensor]

def __init__(self, threshold: float = 0.5, **kwargs) -> None:
super().__init__(**kwargs)
self.threshold = threshold

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, predictions: Tensor, targets: Tensor) -> None:
"""Compute the PRO score for the current batch."""

self.target.append(targets)
self.preds.append(predictions)

def compute(self) -> Tensor:
"""Compute the macro average of the PRO score across all regions in all batches."""
target = dim_zero_cat(self.target)
preds = dim_zero_cat(self.preds)

if target.is_cuda:
comps = connected_components_gpu(target.unsqueeze(1))
else:
comps = connected_components_cpu(target.unsqueeze(1))
pro = pro_score(preds, comps, threshold=self.threshold)
return pro


def pro_score(predictions: Tensor, comps: Tensor, threshold: float = 0.5) -> Tensor:
"""Calculate the PRO score for a batch of predictions.

Args:
predictions (Tensor): Predicted anomaly masks (Bx1xHxW)
comps: (Tensor): Labeled connected components (BxHxW). The components should be labeled from 0 to N
threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions.

Returns:
Tensor: Scalar value representing the average PRO score for the input batch.
"""
if predictions.dtype == torch.float:
predictions = predictions > threshold

n_comps = len(comps.unique())

preds = comps.clone()
preds[~predictions] = 0
if n_comps == 1: # only background
return torch.Tensor([1.0])
pro = recall(preds.flatten(), comps.flatten(), num_classes=n_comps, average="macro", ignore_index=0)
return pro


def connected_components_gpu(binary_input: Tensor, num_iterations: int = 1000) -> Tensor:
"""Perform connected component labeling on GPU and remap the labels from 0 to N.

Args:
binary_input (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
num_iterations (int): Number of iterations used in the connected component computation.

Returns:
Tensor: Components labeled from 0 to N.
"""
components = connected_components(binary_input, num_iterations=num_iterations)

# remap component values from 0 to N
labels = components.unique()
for new_label, old_label in enumerate(labels):
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
components[components == old_label] = new_label

return components.int()


def connected_components_cpu(image: Tensor) -> Tensor:
"""Connected component labeling on CPU.

Args:
image (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)

Returns:
Tensor: Components labeled from 0 to N.
"""
components = torch.zeros_like(image)
label_idx = 1
for i, mask in enumerate(image):
mask = mask.squeeze().numpy().astype(np.uint8)
_, comps = cv2.connectedComponents(mask)
# remap component values to make sure every component has a unique value when outputs are concatenated
for label in np.unique(comps)[1:]:
components[i, 0, ...][np.where(comps == label)] = label_idx
label_idx += 1
return components.int()
12 changes: 4 additions & 8 deletions tests/pre_merge/utils/metrics/test_aupro.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,17 @@ def pytest_generate_tests(metafunc):
torch.tensor(
[
[
[
[0, 0, 0, 1, 0, 0, 0],
]
* 400,
[0, 0, 0, 1, 0, 0, 0],
]
* 400,
]
),
torch.tensor(
[
[
[
[0, 1, 0, 1, 0, 1, 0],
]
* 400,
[0, 1, 0, 1, 0, 1, 0],
]
* 400,
]
),
]
Expand Down
80 changes: 80 additions & 0 deletions tests/pre_merge/utils/metrics/test_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
from torch import Tensor
from torchvision.transforms import RandomAffine

from anomalib.data.utils import random_2d_perlin
from anomalib.utils.metrics.pro import (
PRO,
connected_components_cpu,
connected_components_gpu,
)


def test_pro():
"""Checks if PRO metric computes the (macro) average of the per-region overlap."""

labels = Tensor(
[
[
[0, 0, 0, 0, 0],
[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 0, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 0, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 0],
[0, 0, 0, 0, 0],
[1, 1, 1, 1, 1],
]
]
)

preds = (torch.arange(10) / 10) + 0.05
preds = preds.unsqueeze(1).repeat(1, 5).view(1, 1, 10, 5)

thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
targets = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
for threshold, target in zip(thresholds, targets):
pro = PRO(threshold=threshold)
pro.update(preds, labels)
assert pro.compute() == target


def test_device_consistency():
"""Test if the pro metric yields the same results between cpu and gpu."""

transform = RandomAffine(5, None, (0.95, 1.05), 5)

batch = torch.zeros((32, 256, 256))
for i in range(batch.shape[0]):
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5

preds = transform(batch).unsqueeze(1)

pro_cpu = PRO()
pro_gpu = PRO()

pro_cpu.update(preds.cpu(), batch.cpu())
pro_gpu.update(preds.cuda(), batch.cuda())

assert torch.isclose(pro_cpu.compute(), pro_gpu.compute().cpu())


def test_connected_component_labeling():
"""Tests if the connected component labeling algorithms on cpu and gpu yield the same result."""

# generate batch of random binary images using perlin noise
batch = torch.zeros((32, 1, 256, 256))
for i in range(batch.shape[0]):
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5

# get connected component results on both cpu and gpu
cc_cpu = connected_components_cpu(batch.cpu())
cc_gpu = connected_components_gpu(batch.cuda())

# check if comps are ordered from 0 to N
assert len(cc_cpu.unique()) == cc_cpu.unique().max() + 1
assert len(cc_gpu.unique()) == cc_gpu.unique().max() + 1
# check if same number of comps found between cpu and gpu
assert len(cc_cpu.unique()) == len(cc_gpu.unique())