Revert "Updated coreset subsampling method to improve accuracy (#73)" (
Browse files Browse the repository at this point in the history
…#79)

This reverts commit 3613a6e.
samet-akcay committed Jan 18, 2022
1 parent 3613a6e commit 2575a51
Showing 9 changed files with 435 additions and 88 deletions.
70 changes: 35 additions & 35 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions anomalib/models/patchcore/README.md
@@ -28,22 +28,22 @@ All results gathered with seed `42`.

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 |
| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |
| ResNet-18 | 0.819 | 0.947 | 0.722 | 0.997 | 0.982 | 0.988 | 0.972 | 0.810 | 0.586 | 0.981 | 0.631 | 0.780 | 0.482 | 0.827 | 0.733 | 0.844 |
| Wide ResNet-50 | 0.877 | 0.981 | 0.842 | 1.0 | 0.991 | 0.991 | 0.985 | 0.868 | 0.763 | 0.988 | 0.914 | 0.769 | 0.427 | 0.806 | 0.878 | 0.958 |

### Pixel-Level AUC

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 |
| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
| ResNet-18 | 0.935 | 0.979 | 0.843 | 0.989 | 0.934 | 0.925 | 0.956 | 0.923 | 0.942 | 0.967 | 0.913 | 0.931 | 0.924 | 0.958 | 0.881 | 0.954 |
| Wide ResNet-50 | 0.955 | 0.988 | 0.903 | 0.990 | 0.957 | 0.936 | 0.972 | 0.950 | 0.968 | 0.974 | 0.960 | 0.948 | 0.917 | 0.969 | 0.913 | 0.976 |

### Image F1 Score

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 |
| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |
| ResNet-18 | 0.896 | 0.933 | 0.857 | 0.995 | 0.964 | 0.983 | 0.959 | 0.790 | 0.908 | 0.964 | 0.903 | 0.916 | 0.853 | 0.866 | 0.653 | 0.898 |
| Wide ResNet-50 | 0.923 | 0.961 | 0.875 | 1.0 | 0.989 | 0.975 | 0.984 | 0.832 | 0.908 | 0.972 | 0.920 | 0.922 | 0.853 | 0.862 | 0.842 | 0.953 |

### Sample Results

2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
@@ -24,7 +24,7 @@ model:
layers:
- layer2
- layer3
coreset_sampling_ratio: 0.1
coreset_sampling_ratio: 0.001
num_neighbors: 9
metric: auc
weight_file: weights/model.ckpt
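
For a sense of scale, the two `coreset_sampling_ratio` values in this hunk imply very different memory-bank sizes. A back-of-the-envelope sketch (the patch count is the hypothetical figure used in the `KCenterGreedy` docstring example further down; real counts depend on dataset and backbone):

```python
# Hypothetical patch count, borrowed from the KCenterGreedy docstring example below.
n_patches = 219_520

for ratio in (0.001, 0.1):
    print(f"coreset_sampling_ratio={ratio} -> {int(n_patches * ratio)} memory-bank vectors")
# coreset_sampling_ratio=0.001 -> 219 memory-bank vectors
# coreset_sampling_ratio=0.1 -> 21952 memory-bank vectors
```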
74 changes: 28 additions & 46 deletions anomalib/models/patchcore/model.py
@@ -24,13 +24,17 @@
import torchvision
from kornia import gaussian_blur2d
from omegaconf import ListConfig
from sklearn import random_projection
from torch import Tensor, nn

from anomalib.core.model import AnomalyModule
from anomalib.core.model.dynamic_module import DynamicBufferModule
from anomalib.core.model.feature_extractor import FeatureExtractor
from anomalib.data.tiler import Tiler
from anomalib.models.patchcore.utils.sampling import (
KCenterGreedy,
NearestNeighbors,
SparseRandomProjection,
)


class AnomalyMapGenerator:
@@ -123,6 +127,7 @@ def __init__(

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.nn_search = NearestNeighbors(n_neighbors=9)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

if apply_tiling:
@@ -165,8 +170,7 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.training:
output = embedding
else:
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=9, largest=False, dim=1)
patch_scores, _ = self.nn_search.kneighbors(embedding)

anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
output = (anomaly_map, anomaly_score)
@@ -209,48 +213,25 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
return embedding

def create_coreset(
self,
embedding: Tensor,
sample_count: int = 500,
eps: float = 0.90,
):
"""Creates n subsampled coreset for given sample_set.
@staticmethod
def subsample_embedding(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
"""Subsample embedding based on coreset sampling.
Args:
embedding (Tensor): (n, d) tensor of patch embeddings.
sample_count (int): Number of patches to select.
eps (float): Distortion tolerance for the sparse random projection; larger values give a more aggressive (lower-dimensional) projection.
embedding (Tensor): Embedding tensor extracted from the CNN
sampling_ratio (float): Coreset sampling ratio
Returns:
Tensor: Coreset-subsampled embedding.
"""
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print("Fitting random projections...")
try:
transformer = random_projection.SparseRandomProjection(eps=eps)
sample_set = torch.tensor(transformer.fit_transform(embedding.cpu())).to( # pylint: disable=not-callable
embedding.device
)
except ValueError:
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print(" Error: could not project vectors. Please increase `eps` value.")

select_idx = 0
last_item = sample_set[select_idx : select_idx + 1]
coreset_idx = [torch.tensor(select_idx).to(embedding.device)] # pylint: disable=not-callable
min_distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True)

for _ in range(sample_count - 1):
distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True) # broadcast
min_distances = torch.minimum(distances, min_distances) # iterate
select_idx = torch.argmax(min_distances) # select

last_item = sample_set[select_idx : select_idx + 1]
min_distances[select_idx] = 0
coreset_idx.append(select_idx)

coreset_idx = torch.stack(coreset_idx)
self.memory_bank = embedding[coreset_idx]
# Random projection
random_projector = SparseRandomProjection(eps=0.9)
random_projector.fit(embedding)

# Coreset Subsampling
sampler = KCenterGreedy(model=random_projector, embedding=embedding, sampling_ratio=sampling_ratio)
coreset = sampler.sample_coreset()
return coreset


class PatchcoreLightning(AnomalyModule):
@@ -311,10 +292,12 @@ def training_epoch_end(self, outputs):
outputs (List[Dict[str, np.ndarray]]): List of embedding vectors
"""
embedding = torch.vstack([output["embedding"] for output in outputs])

sampling_ratio = self.hparams.model.coreset_sampling_ratio

self.model.create_coreset(embedding=embedding, sample_count=int(sampling_ratio * embedding.shape[0]), eps=0.9)
embedding = self.model.subsample_embedding(embedding, sampling_ratio)

self.model.nn_search.fit(embedding)
self.model.memory_bank = embedding

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Get batch of anomaly maps from input image batch.
@@ -328,8 +311,7 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
"""

anomaly_maps, anomaly_score = self.model(batch["image"])
anomaly_maps, _ = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)

return batch
1 change: 1 addition & 0 deletions anomalib/models/patchcore/utils/__init__.py
@@ -0,0 +1 @@
"""Helper utilities for PatchCore model."""
7 changes: 7 additions & 0 deletions anomalib/models/patchcore/utils/sampling/__init__.py
@@ -0,0 +1,7 @@
"""Patchcore sampling utils."""

from .k_center_greedy import KCenterGreedy
from .nearest_neighbors import NearestNeighbors
from .random_projection import SparseRandomProjection

__all__ = ["KCenterGreedy", "NearestNeighbors", "SparseRandomProjection"]
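
Taken together, these exports compose roughly as in the sketch below, mirroring `subsample_embedding` and `training_epoch_end` elsewhere in this diff. The tensor shapes and the test batch are made up for illustration:

```python
import torch

from anomalib.models.patchcore.utils.sampling import (
    KCenterGreedy,
    NearestNeighbors,
    SparseRandomProjection,
)

embedding = torch.randn(10_000, 384)  # hypothetical (n_patches, n_features) training embedding

# Fit a sparse random projection so coreset selection can measure distances in a lower dimension.
projector = SparseRandomProjection(eps=0.9)
projector.fit(embedding)

# Greedily pick a coreset; int(10_000 * 0.001) = 10 vectors at this ratio.
sampler = KCenterGreedy(model=projector, embedding=embedding, sampling_ratio=0.001)
memory_bank = sampler.sample_coreset()

# At test time, brute-force k-NN against the memory bank gives per-patch anomaly scores.
nn_search = NearestNeighbors(n_neighbors=9)
nn_search.fit(memory_bank)
patch_scores, _ = nn_search.kneighbors(torch.randn(784, 384))  # hypothetical test-patch embedding
```

Note that the projection is only used while choosing coreset indices; `sample_coreset` returns rows of the original, unprojected embedding, so the memory bank keeps the full feature dimensionality.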
152 changes: 152 additions & 0 deletions anomalib/models/patchcore/utils/sampling/k_center_greedy.py
@@ -0,0 +1,152 @@
"""This module comprises PatchCore Sampling Methods for the embedding.
- k Center Greedy Method
Returns points that minimize the maximum distance of any point to a center.
Reference: https://arxiv.org/abs/1708.00489
"""

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from .random_projection import SparseRandomProjection


class KCenterGreedy:
"""Implements k-center-greedy method.
Args:
model (SparseRandomProjection): projection model with a scikit-learn-like API; its `transform` method is used to reduce the embedding before distance computations.
embedding (Tensor): Embedding vector extracted from a CNN
sampling_ratio (float): Ratio to choose coreset size from the embedding size.
Example:
>>> embedding.shape
torch.Size([219520, 1536])
>>> sampler = KCenterGreedy(model=SparseRandomProjection(eps=0.9), embedding=embedding, sampling_ratio=0.001)
>>> sampled_idxs = sampler.select_coreset_idxs()
>>> coreset = embedding[sampled_idxs]
>>> coreset.shape
torch.Size([219, 1536])
"""

def __init__(self, model: SparseRandomProjection, embedding: Tensor, sampling_ratio: float) -> None:
self.model = model
self.embedding = embedding
self.coreset_size = int(embedding.shape[0] * sampling_ratio)

self.features: Tensor
self.min_distances: Optional[Tensor] = None
self.n_observations = self.embedding.shape[0]
self.already_selected_idxs: List[int] = []

def reset_distances(self) -> None:
"""Reset minimum distances."""
self.min_distances = None

def get_new_cluster_centers(self, cluster_centers: List[int]) -> List[int]:
"""Get new cluster center indexes from the list of cluster indexes.
Args:
cluster_centers (List[int]): List of cluster center indexes.
Returns:
List[int]: List of new cluster center indexes.
"""
return [d for d in cluster_centers if d not in self.already_selected_idxs]

def update_distances(self, cluster_centers: List[int]) -> None:
"""Update min distances given cluster centers.
Args:
cluster_centers (List[int]): indices of cluster centers
"""

if cluster_centers:
cluster_centers = self.get_new_cluster_centers(cluster_centers)
centers = self.features[cluster_centers]

distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)

if self.min_distances is None:
self.min_distances = torch.min(distance, dim=1).values.reshape(-1, 1)
else:
self.min_distances = torch.minimum(self.min_distances, distance)

def get_new_idx(self) -> int:
"""Get index value of a sample.
Based on (i) the maximum of the minimum distances to the selected cluster centers, or (ii) a random draw when no centers have been selected yet.
Returns:
int: Sample index
"""

if self.already_selected_idxs is None or len(self.already_selected_idxs) == 0:
# Initialize centers with a randomly selected datapoint
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
else:
if isinstance(self.min_distances, Tensor):
idx = int(torch.argmax(self.min_distances).item())
else:
raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")

return idx

def select_coreset_idxs(self, selected_idxs: Optional[List[int]] = None) -> List[int]:
"""Greedily form a coreset that minimizes the maximum distance of any point to a cluster center.
Args:
selected_idxs: index of samples already selected. Defaults to an empty set.
Returns:
indices of samples selected to minimize distance to cluster centers
"""

if selected_idxs is None:
selected_idxs = []

if self.embedding.ndim == 2:
self.features = self.model.transform(self.embedding)
self.reset_distances()
else:
self.features = self.embedding.reshape(self.embedding.shape[0], -1)
self.update_distances(cluster_centers=selected_idxs)

selected_coreset_idxs: List[int] = []
for _ in range(self.coreset_size):
idx = self.get_new_idx()
if idx in selected_idxs:
raise ValueError("New indices should not be in selected indices.")

self.update_distances(cluster_centers=[idx])
selected_coreset_idxs.append(idx)

self.already_selected_idxs = selected_idxs

return selected_coreset_idxs

def sample_coreset(self, selected_idxs: Optional[List[int]] = None) -> Tensor:
"""Select coreset from the embedding.
Args:
selected_idxs: index of samples already selected. Defaults to an empty set.
Returns:
Tensor: Output coreset
Example:
>>> embedding.shape
torch.Size([219520, 1536])
>>> sampler = KCenterGreedy(...)
>>> coreset = sampler.sample_coreset()
>>> coreset.shape
torch.Size([219, 1536])
"""

idxs = self.select_coreset_idxs(selected_idxs)
coreset = self.embedding[idxs]

return coreset
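
The selection loop in `select_coreset_idxs` follows the classic farthest-point rule, the same rule the inline loop in the model.py hunk above implements: repeatedly pick the patch whose distance to its nearest already-selected center is largest. A minimal standalone sketch on toy data, independent of the anomalib classes:

```python
import torch


def greedy_coreset_indices(points: torch.Tensor, coreset_size: int) -> list:
    """Toy k-center-greedy selection over a (n, d) tensor."""
    selected = [0]  # start from an arbitrary point
    # Distance of every point to the single center selected so far.
    min_distances = torch.cdist(points, points[0:1]).squeeze(1)
    for _ in range(coreset_size - 1):
        new_idx = int(torch.argmax(min_distances))  # farthest point from all current centers
        selected.append(new_idx)
        distances = torch.cdist(points, points[new_idx : new_idx + 1]).squeeze(1)
        min_distances = torch.minimum(min_distances, distances)
    return selected


coreset_idx = greedy_coreset_indices(torch.randn(1_000, 16), coreset_size=10)
```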
62 changes: 62 additions & 0 deletions anomalib/models/patchcore/utils/sampling/nearest_neighbors.py
@@ -0,0 +1,62 @@
"""This module comprises PatchCore Sampling Methods for the embedding.
- Nearest Neighbours
"""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Tuple

import torch
from torch import Tensor

from anomalib.core.model.dynamic_module import DynamicBufferModule


class NearestNeighbors(DynamicBufferModule):
"""Nearest-neighbour search using the brute-force method and the Euclidean norm.
Args:
n_neighbors (int): Number of neighbors to look at
"""

def __init__(self, n_neighbors: int):
super().__init__()
self.n_neighbors = n_neighbors

self.register_buffer("_fit_x", Tensor())
self._fit_x: Tensor

def fit(self, train_features: Tensor):
"""Saves the train features for NN search later.
Args:
train_features (Tensor): Training data
"""
self._fit_x = train_features

def kneighbors(self, test_features: Tensor) -> Tuple[Tensor, Tensor]:
"""Return k-nearest neighbors.
Distances are computed by brute force.
Args:
test_features (Tensor): test data
Returns:
Tuple[Tensor, Tensor]: distances, indices
"""
distances = torch.cdist(test_features, self._fit_x, p=2.0) # euclidean norm
return distances.topk(k=self.n_neighbors, largest=False, dim=1)
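
As a quick sanity sketch on random data, `kneighbors` reproduces the inline `torch.cdist`/`topk` pattern that appears in the model.py hunk above (the shapes here are made up):

```python
import torch

from anomalib.models.patchcore.utils.sampling import NearestNeighbors

memory_bank = torch.randn(100, 32)  # hypothetical coreset
embedding = torch.randn(50, 32)     # hypothetical test-patch embedding

nn_search = NearestNeighbors(n_neighbors=9)
nn_search.fit(memory_bank)
patch_scores, locations = nn_search.kneighbors(embedding)  # both (50, 9)

# Equivalent inline computation.
distances = torch.cdist(embedding, memory_bank, p=2.0)  # euclidean norm
expected_scores, expected_locations = distances.topk(k=9, largest=False, dim=1)
assert torch.allclose(patch_scores, expected_scores)
```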