Skip to content

Commit

Permalink
Fixed issue with k_greedy method (#80)
Browse files Browse the repository at this point in the history
* Fixed issue with k_greedy method - removed nearest_neighbor and sparse_projection methods

* reverting back to torch implementation of SparseRandomProjection

* annotation

* black

Co-authored-by: Samet Akcay <samet.akcay@intel.com>
  • Loading branch information
blakshma and samet-akcay committed Jan 24, 2022
1 parent f17540c commit 785feb2
Show file tree
Hide file tree
Showing 9 changed files with 82 additions and 171 deletions.
70 changes: 35 additions & 35 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,13 @@
import torch.nn.functional as F
from torch import Tensor

from .random_projection import SparseRandomProjection
from anomalib.core.model.random_projection import SparseRandomProjection


class KCenterGreedy:
"""Implements k-center-greedy method.
Args:
model: model with scikit-like API with decision_function. Defaults to SparseRandomProjection.
embedding (Tensor): Embedding vector extracted from a CNN
sampling_ratio (float): Ratio to choose coreset size from the embedding size.
Expand All @@ -32,31 +31,19 @@ class KCenterGreedy:
torch.Size([219, 1536])
"""

def __init__(self, model: SparseRandomProjection, embedding: Tensor, sampling_ratio: float) -> None:
self.model = model
def __init__(self, embedding: Tensor, sampling_ratio: float) -> None:
self.embedding = embedding
self.coreset_size = int(embedding.shape[0] * sampling_ratio)
self.model = SparseRandomProjection(eps=0.9)

self.features: Tensor
self.min_distances: Optional[Tensor] = None
self.min_distances: Tensor = None
self.n_observations = self.embedding.shape[0]
self.already_selected_idxs: List[int] = []

def reset_distances(self) -> None:
"""Reset minimum distances."""
self.min_distances = None

def get_new_cluster_centers(self, cluster_centers: List[int]) -> List[int]:
"""Get new cluster center indexes from the list of cluster indexes.
Args:
cluster_centers (List[int]): List of cluster center indexes.
Returns:
List[int]: List of new cluster center indexes.
"""
return [d for d in cluster_centers if d not in self.already_selected_idxs]

def update_distances(self, cluster_centers: List[int]) -> None:
"""Update min distances given cluster centers.
Expand All @@ -65,33 +52,28 @@ def update_distances(self, cluster_centers: List[int]) -> None:
"""

if cluster_centers:
cluster_centers = self.get_new_cluster_centers(cluster_centers)
centers = self.features[cluster_centers]

distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)

if self.min_distances is None:
self.min_distances = torch.min(distance, dim=1).values.reshape(-1, 1)
self.min_distances = distance
else:
self.min_distances = torch.minimum(self.min_distances, distance)

def get_new_idx(self) -> int:
"""Get index value of a sample.
Based on (i) either minimum distance of the cluster or (ii) random subsampling from the embedding.
Based on minimum distance of the cluster
Returns:
int: Sample index
"""

if self.already_selected_idxs is None or len(self.already_selected_idxs) == 0:
# Initialize centers with a randomly selected datapoint
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
if isinstance(self.min_distances, Tensor):
idx = int(torch.argmax(self.min_distances).item())
else:
if isinstance(self.min_distances, Tensor):
idx = int(torch.argmax(self.min_distances).item())
else:
raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")
raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")

return idx

Expand All @@ -109,23 +91,23 @@ def select_coreset_idxs(self, selected_idxs: Optional[List[int]] = None) -> List
selected_idxs = []

if self.embedding.ndim == 2:
self.model.fit(self.embedding)
self.features = self.model.transform(self.embedding)
self.reset_distances()
else:
self.features = self.embedding.reshape(self.embedding.shape[0], -1)
self.update_distances(cluster_centers=selected_idxs)

selected_coreset_idxs: List[int] = []
idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
for _ in range(self.coreset_size):
self.update_distances(cluster_centers=[idx])
idx = self.get_new_idx()
if idx in selected_idxs:
raise ValueError("New indices should not be in selected indices.")

self.update_distances(cluster_centers=[idx])
self.min_distances[idx] = 0
selected_coreset_idxs.append(idx)

self.already_selected_idxs = selected_idxs

return selected_coreset_idxs

def sample_coreset(self, selected_idxs: Optional[List[int]] = None) -> Tensor:
Expand Down
File renamed without changes.
12 changes: 6 additions & 6 deletions anomalib/models/patchcore/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@ All results gathered with seed `42`.

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.819 | 0.947 | 0.722 | 0.997 | 0.982 | 0.988 | 0.972 | 0.810 | 0.586 | 0.981 | 0.631 | 0.780 | 0.482 | 0.827 | 0.733 | 0.844 |
| Wide ResNet-50 | 0.877 | 0.981 | 0.842 | 1.0 | 0.991 | 0.991 | 0.985 | 0.868 | 0.763 | 0.988 | 0.914 | 0.769 | 0.427 | 0.806 | 0.878 | 0.958 |
| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 |
| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |

### Pixel-Level AUC

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.935 | 0.979 | 0.843 | 0.989 | 0.934 | 0.925 | 0.956 | 0.923 | 0.942 | 0.967 | 0.913 | 0.931 | 0.924 | 0.958 | 0.881 | 0.954 |
| Wide ResNet-50 | 0.955 | 0.988 | 0.903 | 0.990 | 0.957 | 0.936 | 0.972 | 0.950 | 0.968 | 0.974 | 0.960 | 0.948 | 0.917 | 0.969 | 0.913 | 0.976 |
| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 |
| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |

### Image F1 Score

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.896 | 0.933 | 0.857 | 0.995 | 0.964 | 0.983 | 0.959 | 0.790 | 0.908 | 0.964 | 0.903 | 0.916 | 0.853 | 0.866 | 0.653 | 0.898 |
| Wide ResNet-50 | 0.923 | 0.961 | 0.875 | 1.0 | 0.989 | 0.975 | 0.984 | 0.832 | 0.908 | 0.972 | 0.920 | 0.922 | 0.853 | 0.862 | 0.842 | 0.953 |
| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 |
| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |

### Sample Results

Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ model:
layers:
- layer2
- layer3
coreset_sampling_ratio: 0.001
coreset_sampling_ratio: 0.1
num_neighbors: 9
metric: auc
weight_file: weights/model.ckpt
Expand Down
55 changes: 27 additions & 28 deletions anomalib/models/patchcore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@
from anomalib.core.model import AnomalyModule
from anomalib.core.model.dynamic_module import DynamicBufferModule
from anomalib.core.model.feature_extractor import FeatureExtractor
from anomalib.core.model.k_center_greedy import KCenterGreedy
from anomalib.data.tiler import Tiler
from anomalib.models.patchcore.utils.sampling import (
KCenterGreedy,
NearestNeighbors,
SparseRandomProjection,
)


class AnomalyMapGenerator:
Expand All @@ -44,7 +40,7 @@ def __init__(
self,
input_size: Union[ListConfig, Tuple],
sigma: int = 4,
):
) -> None:
self.input_size = input_size
self.sigma = sigma

Expand Down Expand Up @@ -117,7 +113,7 @@ def __init__(
apply_tiling: bool = False,
tile_size: Optional[Tuple[int, int]] = None,
tile_stride: Optional[int] = None,
):
) -> None:
super().__init__()

self.backbone = getattr(torchvision.models, backbone)
Expand All @@ -127,7 +123,6 @@ def __init__(

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.nn_search = NearestNeighbors(n_neighbors=9)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

if apply_tiling:
Expand Down Expand Up @@ -170,8 +165,7 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.training:
output = embedding
else:
patch_scores, _ = self.nn_search.kneighbors(embedding)

patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=9)
anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
output = (anomaly_map, anomaly_score)

Expand Down Expand Up @@ -213,25 +207,32 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
return embedding

@staticmethod
def subsample_embedding(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
"""Subsample embedding based on coreset sampling.
def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
"""Subsample embedding based on coreset sampling and store to memory.
Args:
embedding (torch.Tensor): Embedding tensor from the CNN
sampling_ratio (float): Coreset sampling ratio
Returns:
torch.Tensor: Subsampled embedding whose dimensionality is reduced.
"""
# Random projection
random_projector = SparseRandomProjection(eps=0.9)
random_projector.fit(embedding)

# Coreset Subsampling
sampler = KCenterGreedy(model=random_projector, embedding=embedding, sampling_ratio=sampling_ratio)
sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio)
coreset = sampler.sample_coreset()
return coreset
self.memory_bank = coreset

def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
"""Nearest Neighbours using brute force method and euclidean norm.
Args:
embedding (Tensor): Features to compare the distance with the memory bank.
n_neighbors (int): Number of neighbors to look at
Returns:
Tensor: Patch scores.
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores


class PatchcoreLightning(AnomalyModule):
Expand All @@ -246,7 +247,7 @@ class PatchcoreLightning(AnomalyModule):
apply_tiling (bool, optional): Apply tiling. Defaults to False.
"""

def __init__(self, hparams):
def __init__(self, hparams) -> None:
super().__init__(hparams)

self.model = PatchcoreModel(
Expand All @@ -259,7 +260,7 @@ def __init__(self, hparams):
)
self.automatic_optimization = False

def configure_optimizers(self):
def configure_optimizers(self) -> None:
"""Configure optimizers.
Returns:
Expand Down Expand Up @@ -294,10 +295,7 @@ def training_epoch_end(self, outputs):
embedding = torch.vstack([output["embedding"] for output in outputs])
sampling_ratio = self.hparams.model.coreset_sampling_ratio

embedding = self.model.subsample_embedding(embedding, sampling_ratio)

self.model.nn_search.fit(embedding)
self.model.memory_bank = embedding
self.model.subsample_embedding(embedding, sampling_ratio)

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Get batch of anomaly maps from input image batch.
Expand All @@ -311,7 +309,8 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
"""

anomaly_maps, _ = self.model(batch["image"])
anomaly_maps, anomaly_score = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)

return batch
1 change: 0 additions & 1 deletion anomalib/models/patchcore/utils/__init__.py

This file was deleted.

7 changes: 0 additions & 7 deletions anomalib/models/patchcore/utils/sampling/__init__.py

This file was deleted.

62 changes: 0 additions & 62 deletions anomalib/models/patchcore/utils/sampling/nearest_neighbors.py

This file was deleted.

0 comments on commit 785feb2

Please sign in to comment.