🧪 Add unit-tests for AUPRO metric (#444)
* Bugfix in AUPRO calculation

We shouldn't regard other regions as FPs/TNs during per-region
ROC construction (see the sketch below).

* Add tests for AUPRO, catch edge cases

* Bugfixes in AUPRO, test against the official reference

* Beautify code, add license for the AUPRO reference
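
Editor's illustration of the first fix (a minimal sketch, not code from this commit; it assumes a recent torchmetrics where the functional `roc` accepts `task="binary"`): when an ROC curve is built for one ground-truth region, pixels belonging to the other regions must be excluded, otherwise they are counted as false positives for the region under evaluation.

import torch
from torchmetrics.functional import roc

preds = torch.tensor([0.9, 0.8, 0.1, 0.7, 0.6])  # per-pixel anomaly scores
region_a = torch.tensor([1, 1, 0, 0, 0])  # ground-truth mask of region A
region_b = torch.tensor([0, 0, 0, 1, 1])  # ground-truth mask of region B
background = (region_a == 0) & (region_b == 0)

# Buggy variant: region-B pixels (scores 0.7 and 0.6) count as FPs for region A.
fpr_buggy, tpr_buggy, _ = roc(preds, region_a, task="binary")

# Fixed variant: evaluate region A only against itself and the true background.
keep = background | region_a.bool()
fpr_fixed, tpr_fixed, _ = roc(preds[keep], region_a[keep], task="binary")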
ORippler committed Jul 20, 2022
1 parent c826c71 commit a946140
Showing 5 changed files with 370 additions and 6 deletions.
40 changes: 34 additions & 6 deletions anomalib/utils/metrics/aupro.py
@@ -38,7 +38,7 @@ def __init__(
 
         self.add_state("preds", default=[], dist_reduce_fx="cat")  # pylint: disable=not-callable
         self.add_state("target", default=[], dist_reduce_fx="cat")  # pylint: disable=not-callable
-        self.fpr_limit = fpr_limit
+        self.register_buffer("fpr_limit", torch.tensor(fpr_limit))
 
     def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
         """Update state with new values.
@@ -76,7 +76,9 @@ def _compute(self) -> Tuple[Tensor, Tensor]:
             )
         target = target.unsqueeze(1)  # kornia expects N1HW format
         target = target.type(torch.float)  # kornia expects FloatTensor
-        cca = connected_components(target)
+        cca = connected_components(
+            target, num_iterations=1000
+        )  # Need a high number of iterations to avoid oversegmentation.
 
         preds = preds.flatten()
         cca = cca.flatten()
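
For context, kornia's `connected_components` labels regions with an iterative scheme, so too few iterations can leave a single ground-truth region split into several labels (oversegmentation). A usage sketch under that assumption:

import torch
from kornia.contrib import connected_components

target = torch.zeros(1, 1, 8, 8)  # kornia expects N1HW float input
target[0, 0, 1:4, 1:4] = 1.0      # a single square ground-truth region
labels = connected_components(target, num_iterations=1000)
# `labels` assigns one (float) label per connected region, 0 for background.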
@@ -89,24 +91,50 @@ def _compute(self) -> Tuple[Tensor, Tensor]:
         # compute the PRO curve by aggregating per-region tpr/fpr curves/values.
         tpr = torch.zeros(output_size, device=preds.device, dtype=torch.float)
         fpr = torch.zeros(output_size, device=preds.device, dtype=torch.float)
-        new_idx = torch.arange(0, output_size, device=preds.device)
+        new_idx = torch.arange(0, output_size, device=preds.device, dtype=torch.float)
 
         # Loop over the labels, computing per-region tpr/fpr curves, and aggregating them.
         # Note that, since the groundtruth is different for every call to `roc`, we also get
         # different/unique tpr/fpr curves (i.e. len(_fpr_idx) is different for every call).
         # We therefore need to resample per-region curves to a fixed sampling ratio (defined above).
         labels = cca.unique()[1:]  # 0 is background
+        background = cca == 0
+        _fpr: Tensor
+        _tpr: Tensor
         for label in labels:
+            interp: bool = False
+            new_idx[-1] = output_size - 1
             mask = cca == label
-            _fpr, _tpr = roc(preds, mask)[:-1]  # don't need threshs
-            _fpr_idx = torch.where(_fpr <= self.fpr_limit)[0]
+            # Need to calculate the label-wise roc on the union of background & mask, as otherwise we would wrongly
+            # consider the other labels in `labels` as FPs. We also don't need to return the thresholds.
+            _fpr, _tpr = roc(preds[background | mask], mask[background | mask])[:-1]
+
+            # catch the edge case where the ROC curve only has fpr values > self.fpr_limit
+            if _fpr[_fpr <= self.fpr_limit].max() == 0:
+                _fpr_limit = _fpr[_fpr > self.fpr_limit].min()
+            else:
+                _fpr_limit = self.fpr_limit
+
+            _fpr_idx = torch.where(_fpr <= _fpr_limit)[0]
+            # if the computed roc curve is not sampled sufficiently close to self.fpr_limit, we include the closest
+            # higher tpr/fpr pair and linearly interpolate the tpr/fpr point at self.fpr_limit
+            if not torch.allclose(_fpr[_fpr_idx].max(), self.fpr_limit):
+                _tmp_idx = torch.searchsorted(_fpr, self.fpr_limit)
+                _fpr_idx = torch.cat([_fpr_idx, _tmp_idx.unsqueeze_(0)])
+                _slope = 1 - ((_fpr[_tmp_idx] - self.fpr_limit) / (_fpr[_tmp_idx] - _fpr[_tmp_idx - 1]))
+                interp = True
 
             _fpr = _fpr[_fpr_idx]
             _tpr = _tpr[_fpr_idx]
 
             _fpr_idx = _fpr_idx.float()
             _fpr_idx /= _fpr_idx.max()
             _fpr_idx *= new_idx.max()
+
+            if interp:
+                # the last point will be sampled at self.fpr_limit
+                new_idx[-1] = _fpr_idx[-2] + ((_fpr_idx[-1] - _fpr_idx[-2]) * _slope)
+
             _tpr = self.interp1d(_fpr_idx, _tpr, new_idx)
             _fpr = self.interp1d(_fpr_idx, _fpr, new_idx)
             tpr += _tpr
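
Each region produces an ROC curve of its own length and fpr spacing, so every curve is resampled onto one shared grid before averaging. A stand-alone sketch of what the metric's `interp1d` helper is assumed to do (a hypothetical equivalent, not the class method itself):

import torch

def interp1d(x: torch.Tensor, y: torch.Tensor, x_new: torch.Tensor) -> torch.Tensor:
    """Linear interpolation of y(x) at x_new; x must be strictly increasing."""
    idx = torch.searchsorted(x, x_new).clamp(1, len(x) - 1)
    x0, x1 = x[idx - 1], x[idx]
    y0, y1 = y[idx - 1], y[idx]
    return y0 + (y1 - y0) * (x_new - x0) / (x1 - x0)

# Resampling every per-region (fpr, tpr) curve onto the same `new_idx` grid is
# what makes the running sums `tpr += _tpr` / `fpr += _fpr` well defined.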
@@ -139,7 +167,7 @@ def generate_figure(self) -> Tuple[Figure, str]:
         fpr, tpr = self._compute()
         aupro = self.compute()
 
-        xlim = (0.0, self.fpr_limit)
+        xlim = (0.0, self.fpr_limit.detach_().cpu().numpy())
         ylim = (0.0, 1.0)
         xlabel = "Global FPR"
         ylabel = "Averaged Per-Region TPR"
47 changes: 47 additions & 0 deletions tests/helpers/LICENSE
@@ -0,0 +1,47 @@
Copyright (c) 2022 Intel Corporation
SPDX-License-Identifier: Apache-2.0

The file `aupro_reference.py` in this folder is based on code from Eliahu Horwitz and Yedid Hoshen,
which itself is based on code from MVTec Software GmbH.

Original license:
----------------

MIT License

Copyright (c) 2022 Eliahu Horwitz and Yedid Hoshen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


Original license:
----------------

Copyright 2021 MVTec Software GmbH

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
221 changes: 221 additions & 0 deletions tests/helpers/aupro_reference.py
@@ -0,0 +1,221 @@
# REMARK: CODE WAS TAKEN FROM https://github.com/eliahuhorwitz/3D-ADS/blob/main/utils/au_pro_util.py

"""
Code based on the official MVTec 3D-AD evaluation code found at
https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz
Utility functions that compute a PRO curve and its definite integral, given
pairs of anomaly and ground truth maps.
The PRO curve can also be integrated up to a constant integration limit.
"""
from bisect import bisect

import numpy as np
from scipy.ndimage.measurements import label


class GroundTruthComponent:
    """
    Stores sorted anomaly scores of a single ground truth component.
    Used to efficiently compute the region overlap for many increasing thresholds.
    """

    def __init__(self, anomaly_scores):
        """
        Initialize the module.
        Args:
            anomaly_scores: List of all anomaly scores within the ground truth
                            component as numpy array.
        """
        # Keep a sorted list of all anomaly scores within the component.
        self.anomaly_scores = anomaly_scores.copy()
        self.anomaly_scores.sort()

        # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels.
        self.index = 0

        # The last evaluated threshold.
        self.last_threshold = None

    def compute_overlap(self, threshold):
        """
        Compute the region overlap for a specific threshold.
        Thresholds must be passed in increasing order.
        Args:
            threshold: Threshold to compute the region overlap.
        Returns:
            Region overlap for the specified threshold.
        """
        if self.last_threshold is not None:
            assert self.last_threshold <= threshold

        # Increase the index until it points to an anomaly score that is just above the specified threshold.
        while self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold:
            self.index += 1

        # Compute the fraction of component pixels that are correctly segmented as anomalous.
        return 1.0 - self.index / len(self.anomaly_scores)
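
A usage sketch with hypothetical values: because the scores are kept sorted and the index only moves forward, evaluating many increasing thresholds costs linear time in the component size overall.

import numpy as np

scores = np.array([0.2, 0.5, 0.9])     # anomaly scores inside one GT region
component = GroundTruthComponent(scores)
print(component.compute_overlap(0.1))  # 1.0: every pixel scores above 0.1
print(component.compute_overlap(0.5))  # 0.333...: only the 0.9 pixel remains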


def trapezoid(x, y, x_max=None):
    """
    This function calculates the definite integral of a curve given by x- and corresponding y-values.
    In contrast to, e.g., 'numpy.trapz()', this function allows one to define an upper bound to the integration
    range by setting a value x_max.
    Points that do not have a finite x or y value will be ignored with a warning.
    Args:
        x: Samples from the domain of the function to integrate. Need to be sorted in ascending order. May contain
           the same value multiple times. In that case, the order of the corresponding y values will affect the
           integration with the trapezoidal rule.
        y: Values of the function corresponding to x values.
        x_max: Upper limit of the integration. The y value at x_max will be determined by interpolating between its
               neighbors. Must not lie outside of the range of x.
    Returns:
        Area under the curve.
    """

    x = np.array(x)
    y = np.array(y)
    finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
    if not finite_mask.all():
        print(
            """WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values."""
        )
    x = x[finite_mask]
    y = y[finite_mask]

    # Introduce a correction term if x_max is not an element of x.
    correction = 0.0
    if x_max is not None:
        if x_max not in x:
            # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max).
            ins = bisect(x, x_max)
            # x_max must be between the minimum and the maximum, so the insertion point cannot be zero or len(x).
            assert 0 < ins < len(x)

            # Calculate the correction term, which is the integral between the last x[ins-1] and x_max. Since we do
            # not know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1].
            y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1]))
            correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])

        # Cut off at x_max.
        mask = x <= x_max
        x = x[mask]
        y = y[mask]

    # Return the area under the curve using the trapezoidal rule.
    return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
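
A quick usage sketch of the correction term (hypothetical numbers): integrating y = x up to x_max = 0.6, which falls between two samples, so the interpolated trapezoid between x = 0.5 and x_max is added on top of the regular trapezoidal sum.

import numpy as np

x = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
area = trapezoid(x, x, x_max=0.6)
print(area)  # 0.18, the exact integral of x dx from 0 to 0.6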


def collect_anomaly_scores(anomaly_maps, ground_truth_maps):
    """
    Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential
    false positive pixel from the anomaly maps.
    Args:
        anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
        ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
                           for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel
                           contains an anomaly.
    Returns:
        ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each
                                 component, a sorted list of its anomaly scores is stored.
        anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset. This
                                  list can be used to quickly select thresholds that fix a certain false positive
                                  rate.
    """
    # Make sure an anomaly map is present for each ground truth map.
    assert len(anomaly_maps) == len(ground_truth_maps)

    # Initialize ground truth components and scores of potential fp pixels.
    ground_truth_components = []
    anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size)

    # Structuring element for computing connected components.
    structure = np.ones((3, 3), dtype=int)

    # Collect anomaly scores within each ground truth region and for all potential fp pixels.
    ok_index = 0
    for gt_map, prediction in zip(ground_truth_maps, anomaly_maps):

        # Compute the connected components in the ground truth map.
        labeled, n_components = label(gt_map, structure)

        # Store all potential fp scores.
        num_ok_pixels = len(prediction[labeled == 0])
        anomaly_scores_ok_pixels[ok_index : ok_index + num_ok_pixels] = prediction[labeled == 0].copy()
        ok_index += num_ok_pixels

        # Fetch anomaly scores within each GT component.
        for k in range(n_components):
            component_scores = prediction[labeled == (k + 1)]
            ground_truth_components.append(GroundTruthComponent(component_scores))

    # Sort all potential false positive scores.
    anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index)
    anomaly_scores_ok_pixels.sort()

    return ground_truth_components, anomaly_scores_ok_pixels
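
A small usage sketch with hypothetical 2x2 maps: one anomalous region yields one component, and every anomaly-free pixel feeds the sorted OK-score list.

import numpy as np

gts = [np.array([[1, 1], [0, 0]]), np.zeros((2, 2), dtype=int)]
preds = [np.array([[0.9, 0.8], [0.1, 0.2]]), np.full((2, 2), 0.05)]
components, ok_scores = collect_anomaly_scores(preds, gts)
print(len(components))  # 1 connected component in total
print(ok_scores)        # [0.05 0.05 0.05 0.05 0.1 0.2]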


def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds):
    """
    Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground
    truth maps. The number of interpolation points can be set manually.
    Args:
        anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
        ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
                           for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel
                           contains an anomaly.
        num_thresholds: Number of thresholds at which to compute the PRO curve.
    Returns:
        fprs: List of false positive rates.
        pros: List of corresponding PRO values.
    """
    # Fetch sorted anomaly scores.
    ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps)

    # Select equidistant thresholds.
    threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int)

    fprs = [1.0]
    pros = [1.0]
    for pos in threshold_positions:
        threshold = anomaly_scores_ok_pixels[pos]

        # Compute the false positive rate for this threshold.
        fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels)

        # Compute the PRO value for this threshold.
        pro = 0.0
        for component in ground_truth_components:
            pro += component.compute_overlap(threshold)
        pro /= len(ground_truth_components)

        fprs.append(fpr)
        pros.append(pro)

    # Return (FPR, PRO) pairs in increasing FPR order.
    fprs = fprs[::-1]
    pros = pros[::-1]

    return fprs, pros


def calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=100):
    """
    Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images.
    Args:
        gts: List of tensors that contain the ground truth images for a single dataset object.
        predictions: List of tensors containing anomaly images for each ground truth image.
        integration_limit: Integration limit to use when computing the area under the PRO curve.
        num_thresholds: Number of thresholds to use to sample the area under the PRO curve.
    Returns:
        au_pro: Area under the PRO curve computed up to the given integration limit.
        pro_curve: PRO curve values for localization (fpr, pro).
    """
    # Compute the PRO curve.
    pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds)

    # Compute the area under the PRO curve.
    au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit)
    au_pro /= integration_limit

    # Return the evaluation metrics.
    return au_pro, pro_curve
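
An end-to-end usage sketch with synthetic data (illustrative values only; the tests added in this commit compare this reference against anomalib's AUPRO implementation):

import numpy as np

rng = np.random.default_rng(0)
gt = np.zeros((32, 32), dtype=int)
gt[8:16, 8:16] = 1                      # one square anomalous region
pred = rng.random((32, 32)) + 2.0 * gt  # clearly higher scores inside it
au_pro, (fprs, pros) = calculate_au_pro([gt], [pred], integration_limit=0.3)
print(f"AUPRO@0.3 = {au_pro:.3f}")      # close to 1.0 for this easy case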