Revert "Updated coreset subsampling method to improve accuracy" #79

Merged
merged 1 commit on Jan 18, 2022
70 changes: 35 additions & 35 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions anomalib/models/patchcore/README.md
@@ -28,22 +28,22 @@ All results gathered with seed `42`.

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 |
| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |
| ResNet-18 | 0.819 | 0.947 | 0.722 | 0.997 | 0.982 | 0.988 | 0.972 | 0.810 | 0.586 | 0.981 | 0.631 | 0.780 | 0.482 | 0.827 | 0.733 | 0.844 |
| Wide ResNet-50 | 0.877 | 0.981 | 0.842 | 1.0 | 0.991 | 0.991 | 0.985 | 0.868 | 0.763 | 0.988 | 0.914 | 0.769 | 0.427 | 0.806 | 0.878 | 0.958 |

### Pixel-Level AUC

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 |
| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
| ResNet-18 | 0.935 | 0.979 | 0.843 | 0.989 | 0.934 | 0.925 | 0.956 | 0.923 | 0.942 | 0.967 | 0.913 | 0.931 | 0.924 | 0.958 | 0.881 | 0.954 |
| Wide ResNet-50 | 0.955 | 0.988 | 0.903 | 0.990 | 0.957 | 0.936 | 0.972 | 0.950 | 0.968 | 0.974 | 0.960 | 0.948 | 0.917 | 0.969 | 0.913 | 0.976 |

### Image F1 Score

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 |
| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |
| ResNet-18 | 0.896 | 0.933 | 0.857 | 0.995 | 0.964 | 0.983 | 0.959 | 0.790 | 0.908 | 0.964 | 0.903 | 0.916 | 0.853 | 0.866 | 0.653 | 0.898 |
| Wide ResNet-50 | 0.923 | 0.961 | 0.875 | 1.0 | 0.989 | 0.975 | 0.984 | 0.832 | 0.908 | 0.972 | 0.920 | 0.922 | 0.853 | 0.862 | 0.842 | 0.953 |

### Sample Results

2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
@@ -24,7 +24,7 @@ model:
layers:
- layer2
- layer3
coreset_sampling_ratio: 0.1
coreset_sampling_ratio: 0.001
num_neighbors: 9
metric: auc
weight_file: weights/model.ckpt
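
The only functional change in this file is the sampling ratio, which sets what fraction of the training patch embeddings is kept as the memory bank. A rough, purely illustrative sketch of what that means in numbers, using the 219,520-patch figure from the `KCenterGreedy` docstring below and the same `int(n * ratio)` rule the sampler applies:

```python
# Illustrative only: how coreset_sampling_ratio maps to a memory-bank size.
# 219_520 is the patch-embedding count taken from the KCenterGreedy docstring example.
num_patch_embeddings = 219_520

for sampling_ratio in (0.1, 0.001):
    coreset_size = int(num_patch_embeddings * sampling_ratio)  # rule used by KCenterGreedy
    print(f"ratio={sampling_ratio}: {coreset_size} embeddings kept in the memory bank")
# ratio=0.1:   21952 embeddings kept in the memory bank
# ratio=0.001: 219 embeddings kept in the memory bank
```
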
74 changes: 28 additions & 46 deletions anomalib/models/patchcore/model.py
@@ -24,13 +24,17 @@
import torchvision
from kornia import gaussian_blur2d
from omegaconf import ListConfig
from sklearn import random_projection
from torch import Tensor, nn

from anomalib.core.model import AnomalyModule
from anomalib.core.model.dynamic_module import DynamicBufferModule
from anomalib.core.model.feature_extractor import FeatureExtractor
from anomalib.data.tiler import Tiler
from anomalib.models.patchcore.utils.sampling import (
KCenterGreedy,
NearestNeighbors,
SparseRandomProjection,
)


class AnomalyMapGenerator:
@@ -123,6 +127,7 @@ def __init__(

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.nn_search = NearestNeighbors(n_neighbors=9)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

if apply_tiling:
@@ -165,8 +170,7 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.training:
output = embedding
else:
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=9, largest=False, dim=1)
patch_scores, _ = self.nn_search.kneighbors(embedding)

anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
output = (anomaly_map, anomaly_score)
@@ -209,48 +213,25 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
return embedding

def create_coreset(
self,
embedding: Tensor,
sample_count: int = 500,
eps: float = 0.90,
):
"""Creates n subsampled coreset for given sample_set.
@staticmethod
def subsample_embedding(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
"""Subsample embedding based on coreset sampling.

Args:
embedding (Tensor): (sample_count, d) tensor of patches.
sample_count (int): Number of patches to select.
eps (float): Parameter for sparse projection aggressiveness.
embedding (Tensor): Embedding tensor extracted from the CNN
sampling_ratio (float): Coreset sampling ratio

Returns:
Tensor: Subsampled embedding (coreset) with fewer rows than the input.
"""
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print("Fitting random projections...")
try:
transformer = random_projection.SparseRandomProjection(eps=eps)
sample_set = torch.tensor(transformer.fit_transform(embedding.cpu())).to( # pylint: disable=not-callable
embedding.device
)
except ValueError:
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print(" Error: could not project vectors. Please increase `eps` value.")

select_idx = 0
last_item = sample_set[select_idx : select_idx + 1]
coreset_idx = [torch.tensor(select_idx).to(embedding.device)] # pylint: disable=not-callable
min_distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True)

for _ in range(sample_count - 1):
distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True) # broadcast
min_distances = torch.minimum(distances, min_distances) # iterate
select_idx = torch.argmax(min_distances) # select

last_item = sample_set[select_idx : select_idx + 1]
min_distances[select_idx] = 0
coreset_idx.append(select_idx)

coreset_idx = torch.stack(coreset_idx)
self.memory_bank = embedding[coreset_idx]
# Random projection
random_projector = SparseRandomProjection(eps=0.9)
random_projector.fit(embedding)

# Coreset Subsampling
sampler = KCenterGreedy(model=random_projector, embedding=embedding, sampling_ratio=sampling_ratio)
coreset = sampler.sample_coreset()
return coreset


class PatchcoreLightning(AnomalyModule):
@@ -311,10 +292,12 @@ def training_epoch_end(self, outputs):
outputs (List[Dict[str, np.ndarray]]): List of embedding vectors
"""
embedding = torch.vstack([output["embedding"] for output in outputs])

sampling_ratio = self.hparams.model.coreset_sampling_ratio

self.model.create_coreset(embedding=embedding, sample_count=int(sampling_ratio * embedding.shape[0]), eps=0.9)
embedding = self.model.subsample_embedding(embedding, sampling_ratio)

self.model.nn_search.fit(embedding)
self.model.memory_bank = embedding

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Get batch of anomaly maps from input image batch.
@@ -328,8 +311,7 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
"""

anomaly_maps, anomaly_score = self.model(batch["image"])
anomaly_maps, _ = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)

return batch
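
For orientation, the restored end-to-end flow (coreset subsampling at the end of training, then nearest-neighbour search at inference) can be sketched as below. This is a simplified outline with made-up shapes and the config's 0.001 sampling ratio, not the exact module code; it assumes the `anomalib.models.patchcore.utils.sampling` package added back by this revert.

```python
import torch

from anomalib.models.patchcore.utils.sampling import (
    KCenterGreedy,
    NearestNeighbors,
    SparseRandomProjection,
)

# Stand-in for the stacked patch embeddings collected during training (shapes are assumptions).
embedding = torch.randn(10_000, 384)

# training_epoch_end: project, greedily subsample, and index the coreset.
projector = SparseRandomProjection(eps=0.9)
projector.fit(embedding)
sampler = KCenterGreedy(model=projector, embedding=embedding, sampling_ratio=0.001)
memory_bank = sampler.sample_coreset()       # what subsample_embedding() returns

nn_search = NearestNeighbors(n_neighbors=9)
nn_search.fit(memory_bank)

# validation / inference: patch scores are distances to the k nearest coreset entries.
test_embedding = torch.randn(784, 384)       # patch embeddings of one test image
patch_scores, _ = nn_search.kneighbors(test_embedding)
```
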
1 change: 1 addition & 0 deletions anomalib/models/patchcore/utils/__init__.py
@@ -0,0 +1 @@
"""Helper utilities for PatchCore model."""
7 changes: 7 additions & 0 deletions anomalib/models/patchcore/utils/sampling/__init__.py
@@ -0,0 +1,7 @@
"""Patchcore sampling utils."""

from .k_center_greedy import KCenterGreedy
from .nearest_neighbors import NearestNeighbors
from .random_projection import SparseRandomProjection

__all__ = ["KCenterGreedy", "NearestNeighbors", "SparseRandomProjection"]
152 changes: 152 additions & 0 deletions anomalib/models/patchcore/utils/sampling/k_center_greedy.py
@@ -0,0 +1,152 @@
"""This module comprises PatchCore Sampling Methods for the embedding.

- k Center Greedy Method
Returns points that minimize the maximum distance of any point to a center.
Reference: https://arxiv.org/abs/1708.00489
"""

from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from .random_projection import SparseRandomProjection


class KCenterGreedy:
"""Implements k-center-greedy method.

Args:
model (SparseRandomProjection): Model with a scikit-learn-like API; its `transform` method projects the embedding before distances are computed.
embedding (Tensor): Embedding vector extracted from a CNN
sampling_ratio (float): Ratio to choose coreset size from the embedding size.

Example:
>>> embedding.shape
torch.Size([219520, 1536])
>>> sampler = KCenterGreedy(model=SparseRandomProjection(eps=0.9), embedding=embedding, sampling_ratio=0.001)
>>> sampled_idxs = sampler.select_coreset_idxs()
>>> coreset = embedding[sampled_idxs]
>>> coreset.shape
torch.Size([219, 1536])
"""

def __init__(self, model: SparseRandomProjection, embedding: Tensor, sampling_ratio: float) -> None:
self.model = model
self.embedding = embedding
self.coreset_size = int(embedding.shape[0] * sampling_ratio)

self.features: Tensor
self.min_distances: Optional[Tensor] = None
self.n_observations = self.embedding.shape[0]
self.already_selected_idxs: List[int] = []

def reset_distances(self) -> None:
"""Reset minimum distances."""
self.min_distances = None

def get_new_cluster_centers(self, cluster_centers: List[int]) -> List[int]:
"""Get new cluster center indexes from the list of cluster indexes.

Args:
cluster_centers (List[int]): List of cluster center indexes.

Returns:
List[int]: List of new cluster center indexes.
"""
return [d for d in cluster_centers if d not in self.already_selected_idxs]

def update_distances(self, cluster_centers: List[int]) -> None:
"""Update min distances given cluster centers.

Args:
cluster_centers (List[int]): indices of cluster centers
"""

if cluster_centers:
cluster_centers = self.get_new_cluster_centers(cluster_centers)
centers = self.features[cluster_centers]

distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)

if self.min_distances is None:
self.min_distances = torch.min(distance, dim=1).values.reshape(-1, 1)
else:
self.min_distances = torch.minimum(self.min_distances, distance)

def get_new_idx(self) -> int:
"""Get index value of a sample.

Returns (i) the sample with the largest minimum distance to the selected cluster centers, or (ii) a random sample from the embedding when no centers have been selected yet.

Returns:
int: Sample index
"""

if self.already_selected_idxs is None or len(self.already_selected_idxs) == 0:
# Initialize centers with a randomly selected datapoint
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
else:
if isinstance(self.min_distances, Tensor):
idx = int(torch.argmax(self.min_distances).item())
else:
raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")

return idx

def select_coreset_idxs(self, selected_idxs: Optional[List[int]] = None) -> List[int]:
"""Greedily form a coreset to minimize the maximum distance of a cluster.

Args:
selected_idxs: index of samples already selected. Defaults to an empty set.

Returns:
indices of samples selected to minimize distance to cluster centers
"""

if selected_idxs is None:
selected_idxs = []

if self.embedding.ndim == 2:
self.features = self.model.transform(self.embedding)
self.reset_distances()
else:
self.features = self.embedding.reshape(self.embedding.shape[0], -1)
self.update_distances(cluster_centers=selected_idxs)

selected_coreset_idxs: List[int] = []
for _ in range(self.coreset_size):
idx = self.get_new_idx()
if idx in selected_idxs:
raise ValueError("New indices should not be in selected indices.")

self.update_distances(cluster_centers=[idx])
selected_coreset_idxs.append(idx)

self.already_selected_idxs = selected_idxs

return selected_coreset_idxs

def sample_coreset(self, selected_idxs: Optional[List[int]] = None) -> Tensor:
"""Select coreset from the embedding.

Args:
selected_idxs: index of samples already selected. Defaults to an empty set.

Returns:
Tensor: Output coreset

Example:
>>> embedding.shape
torch.Size([219520, 1536])
>>> sampler = KCenterGreedy(...)
>>> coreset = sampler.sample_coreset()
>>> coreset.shape
torch.Size([219, 1536])
"""

idxs = self.select_coreset_idxs(selected_idxs)
coreset = self.embedding[idxs]

return coreset
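
For intuition about the greedy rule itself (each step picks the point farthest from its nearest already-chosen center), here is a compact, self-contained sketch of the same idea on a toy tensor; it is an illustration only, not the class above, and it skips the random projection:

```python
import torch


def k_center_greedy_toy(points: torch.Tensor, coreset_size: int) -> list:
    """Toy k-center greedy: repeatedly add the point with the largest min-distance to the centers."""
    n_points = points.shape[0]
    idx = int(torch.randint(high=n_points, size=(1,)).item())  # random first center, as in get_new_idx()
    selected = [idx]
    min_distances = torch.cdist(points, points[idx : idx + 1]).squeeze(1)
    for _ in range(coreset_size - 1):
        idx = int(torch.argmax(min_distances).item())          # farthest point becomes the next center
        selected.append(idx)
        new_distances = torch.cdist(points, points[idx : idx + 1]).squeeze(1)
        min_distances = torch.minimum(min_distances, new_distances)
    return selected


coreset_idxs = k_center_greedy_toy(torch.randn(1_000, 16), coreset_size=10)
```
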
62 changes: 62 additions & 0 deletions anomalib/models/patchcore/utils/sampling/nearest_neighbors.py
@@ -0,0 +1,62 @@
"""This module comprises PatchCore Sampling Methods for the embedding.

- Nearest Neighbours
"""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import Tuple

import torch
from torch import Tensor

from anomalib.core.model.dynamic_module import DynamicBufferModule


class NearestNeighbors(DynamicBufferModule):
"""Nearest Neighbours using brute force method and euclidean norm.

Args:
n_neighbors (int): Number of neighbors to look at
"""

def __init__(self, n_neighbors: int):
super().__init__()
self.n_neighbors = n_neighbors

self.register_buffer("_fit_x", Tensor())
self._fit_x: Tensor

def fit(self, train_features: Tensor):
"""Saves the train features for NN search later.

Args:
train_features (Tensor): Training data
"""
self._fit_x = train_features

def kneighbors(self, test_features: Tensor) -> Tuple[Tensor, Tensor]:
"""Return k-nearest neighbors.

It is calculated with a brute-force search.

Args:
test_features (Tensor): test data

Returns:
Tuple[Tensor, Tensor]: distances, indices
"""
distances = torch.cdist(test_features, self._fit_x, p=2.0) # euclidean norm
return distances.topk(k=self.n_neighbors, largest=False, dim=1)
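
A minimal usage sketch of this module (tensor shapes are assumptions; in PatchCore the memory bank is the subsampled coreset and the query is the patch embedding of a test image):

```python
import torch

from anomalib.models.patchcore.utils.sampling import NearestNeighbors

memory_bank = torch.randn(219, 1536)      # assumed coreset
test_patches = torch.randn(784, 1536)     # assumed patch embeddings of one test image

nn_search = NearestNeighbors(n_neighbors=9)
nn_search.fit(memory_bank)
distances, indices = nn_search.kneighbors(test_patches)
print(distances.shape, indices.shape)     # torch.Size([784, 9]) torch.Size([784, 9])
```
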