Fixed issue with k_greedy method #80

Merged · 5 commits · Jan 24, 2022

Changes from 1 commit
70 changes: 35 additions & 35 deletions README.md

Large diffs are not rendered by default.

anomalib/core/model/k_center_greedy.py

@@ -9,16 +9,14 @@
 import torch
 import torch.nn.functional as F
+from sklearn import random_projection
 from torch import Tensor
 
-from .random_projection import SparseRandomProjection
-
 
 class KCenterGreedy:
     """Implements k-center-greedy method.
 
     Args:
-        model: model with scikit-like API with decision_function. Defaults to SparseRandomProjection.
         embedding (Tensor): Embedding vector extracted from a CNN
         sampling_ratio (float): Ratio to choose coreset size from the embedding size.
@@ -32,31 +30,19 @@ class KCenterGreedy:
         torch.Size([219, 1536])
     """
 
-    def __init__(self, model: SparseRandomProjection, embedding: Tensor, sampling_ratio: float) -> None:
-        self.model = model
+    def __init__(self, embedding: Tensor, sampling_ratio: float) -> None:
         self.embedding = embedding
         self.coreset_size = int(embedding.shape[0] * sampling_ratio)
+        self.model = random_projection.SparseRandomProjection(eps=0.9)
 
         self.features: Tensor
-        self.min_distances: Optional[Tensor] = None
+        self.min_distances: Tensor = None
        self.n_observations = self.embedding.shape[0]
-        self.already_selected_idxs: List[int] = []
 
     def reset_distances(self) -> None:
         """Reset minimum distances."""
         self.min_distances = None
 
-    def get_new_cluster_centers(self, cluster_centers: List[int]) -> List[int]:
-        """Get new cluster center indexes from the list of cluster indexes.
-
-        Args:
-            cluster_centers (List[int]): List of cluster center indexes.
-
-        Returns:
-            List[int]: List of new cluster center indexes.
-        """
-        return [d for d in cluster_centers if d not in self.already_selected_idxs]
-
     def update_distances(self, cluster_centers: List[int]) -> None:
         """Update min distances given cluster centers.
@@ -65,33 +51,28 @@ def update_distances(self, cluster_centers: List[int]) -> None:
         """
 
         if cluster_centers:
-            cluster_centers = self.get_new_cluster_centers(cluster_centers)
             centers = self.features[cluster_centers]
 
             distance = F.pairwise_distance(self.features, centers, p=2).reshape(-1, 1)
 
             if self.min_distances is None:
-                self.min_distances = torch.min(distance, dim=1).values.reshape(-1, 1)
+                self.min_distances = distance
             else:
                 self.min_distances = torch.minimum(self.min_distances, distance)
 
     def get_new_idx(self) -> int:
         """Get index value of a sample.
 
-        Based on (i) either minimum distance of the cluster or (ii) random subsampling from the embedding.
+        Based on minimum distance of the cluster
 
         Returns:
             int: Sample index
         """
 
-        if self.already_selected_idxs is None or len(self.already_selected_idxs) == 0:
-            # Initialize centers with a randomly selected datapoint
-            idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
+        if isinstance(self.min_distances, Tensor):
+            idx = int(torch.argmax(self.min_distances).item())
         else:
-            if isinstance(self.min_distances, Tensor):
-                idx = int(torch.argmax(self.min_distances).item())
-            else:
-                raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")
+            raise ValueError(f"self.min_distances must be of type Tensor. Got {type(self.min_distances)}")
 
         return idx
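An aside, not part of the diff: `update_distances` and `get_new_idx` together implement the greedy k-center rule, which keeps, for every sample, the distance to its nearest already-chosen center and repeatedly picks the sample whose nearest center is farthest away. A minimal standalone sketch of that loop, with illustrative shapes and variable names:

```python
# Standalone sketch of greedy k-center selection (mirrors update_distances/get_new_idx above).
import torch
import torch.nn.functional as F

features = torch.randn(219, 128)   # projected embeddings; shape is illustrative
coreset_size = 20
min_distances = None
idx = int(torch.randint(high=features.shape[0], size=(1,)).item())  # random initial center
selected = []

for _ in range(coreset_size):
    # Distance of every sample to the most recently added center.
    distance = F.pairwise_distance(features, features[[idx]], p=2).reshape(-1, 1)
    min_distances = distance if min_distances is None else torch.minimum(min_distances, distance)
    idx = int(torch.argmax(min_distances).item())  # sample farthest from all chosen centers
    min_distances[idx] = 0                         # prevent re-selection
    selected.append(idx)
```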

@@ -109,23 +90,24 @@ def select_coreset_idxs(self, selected_idxs: Optional[List[int]] = None) -> List
             selected_idxs = []
 
         if self.embedding.ndim == 2:
-            self.features = self.model.transform(self.embedding)
+            self.features = torch.tensor(  # pylint: disable=not-callable
+                self.model.fit_transform(self.embedding.cpu())
+            ).to(self.embedding.device)
             self.reset_distances()
         else:
             self.features = self.embedding.reshape(self.embedding.shape[0], -1)
             self.update_distances(cluster_centers=selected_idxs)
 
         selected_coreset_idxs: List[int] = []
+        idx = int(torch.randint(high=self.n_observations, size=(1,)).item())
         for _ in range(self.coreset_size):
+            self.update_distances(cluster_centers=[idx])
             idx = self.get_new_idx()
             if idx in selected_idxs:
                 raise ValueError("New indices should not be in selected indices.")
 
-            self.update_distances(cluster_centers=[idx])
             self.min_distances[idx] = 0
             selected_coreset_idxs.append(idx)
 
-        self.already_selected_idxs = selected_idxs
 
         return selected_coreset_idxs
 
     def sample_coreset(self, selected_idxs: Optional[List[int]] = None) -> Tensor:
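An aside outside the diff: `select_coreset_idxs` now routes the projection through scikit-learn, which operates on NumPy arrays, hence the `.cpu()` call and the cast back to a tensor on the embedding's original device. A hedged sketch of just that round-trip, with an illustrative embedding size (eps=0.9 is the value set in `__init__` above):

```python
# Sketch of the sklearn random-projection round-trip used in select_coreset_idxs.
import torch
from sklearn import random_projection

embedding = torch.randn(219, 1536)  # patch embeddings; .cpu() matters when they live on GPU
projector = random_projection.SparseRandomProjection(eps=0.9)

projected = projector.fit_transform(embedding.cpu())      # NumPy array; n_components picked from eps
features = torch.tensor(projected).to(embedding.device)   # back to a tensor on the original device
print(features.shape)                                     # (219, k) with k well below 1536
```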
12 changes: 6 additions & 6 deletions anomalib/models/patchcore/README.md
@@ -28,22 +28,22 @@ All results gathered with seed `42`.

 | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
 | -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
-| ResNet-18 | 0.819 | 0.947 | 0.722 | 0.997 | 0.982 | 0.988 | 0.972 | 0.810 | 0.586 | 0.981 | 0.631 | 0.780 | 0.482 | 0.827 | 0.733 | 0.844 |
-| Wide ResNet-50 | 0.877 | 0.981 | 0.842 | 1.0 | 0.991 | 0.991 | 0.985 | 0.868 | 0.763 | 0.988 | 0.914 | 0.769 | 0.427 | 0.806 | 0.878 | 0.958 |
+| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 |
+| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |
 
 ### Pixel-Level AUC
 
 | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
 | -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
-| ResNet-18 | 0.935 | 0.979 | 0.843 | 0.989 | 0.934 | 0.925 | 0.956 | 0.923 | 0.942 | 0.967 | 0.913 | 0.931 | 0.924 | 0.958 | 0.881 | 0.954 |
-| Wide ResNet-50 | 0.955 | 0.988 | 0.903 | 0.990 | 0.957 | 0.936 | 0.972 | 0.950 | 0.968 | 0.974 | 0.960 | 0.948 | 0.917 | 0.969 | 0.913 | 0.976 |
+| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 |
+| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
 
 ### Image F1 Score
 
 | | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
 | -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
-| ResNet-18 | 0.896 | 0.933 | 0.857 | 0.995 | 0.964 | 0.983 | 0.959 | 0.790 | 0.908 | 0.964 | 0.903 | 0.916 | 0.853 | 0.866 | 0.653 | 0.898 |
-| Wide ResNet-50 | 0.923 | 0.961 | 0.875 | 1.0 | 0.989 | 0.975 | 0.984 | 0.832 | 0.908 | 0.972 | 0.920 | 0.922 | 0.853 | 0.862 | 0.842 | 0.953 |
+| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 |
+| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |
 
 ### Sample Results
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
@@ -24,7 +24,7 @@ model:
   layers:
     - layer2
     - layer3
-  coreset_sampling_ratio: 0.001
+  coreset_sampling_ratio: 0.1
   num_neighbors: 9
   metric: auc
   weight_file: weights/model.ckpt
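For context (not part of the diff): the ratio feeds directly into the coreset size computed in `KCenterGreedy.__init__` as `int(n_embeddings * sampling_ratio)`, so raising it from 0.001 to 0.1 keeps 100 times more patch embeddings in the memory bank. A toy illustration with a hypothetical embedding count:

```python
# Illustrative only: how the sampling ratio scales the memory bank.
n_patch_embeddings = 200_000  # hypothetical total number of patch embeddings for one category
print(int(n_patch_embeddings * 0.001))  # 200    (old ratio)
print(int(n_patch_embeddings * 0.1))    # 20000  (new ratio)
```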
33 changes: 10 additions & 23 deletions anomalib/models/patchcore/model.py
@@ -29,12 +29,8 @@
 from anomalib.core.model import AnomalyModule
 from anomalib.core.model.dynamic_module import DynamicBufferModule
 from anomalib.core.model.feature_extractor import FeatureExtractor
+from anomalib.core.model.k_center_greedy import KCenterGreedy
 from anomalib.data.tiler import Tiler
-from anomalib.models.patchcore.utils.sampling import (
-    KCenterGreedy,
-    NearestNeighbors,
-    SparseRandomProjection,
-)
 
 
 class AnomalyMapGenerator:
@@ -127,7 +123,6 @@ def __init__(
 
         self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
         self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
-        self.nn_search = NearestNeighbors(n_neighbors=9)
         self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
 
         if apply_tiling:
@@ -170,7 +165,8 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
         if self.training:
             output = embedding
         else:
-            patch_scores, _ = self.nn_search.kneighbors(embedding)
+            distances = torch.cdist(embedding, self.memory_bank, p=2.0)  # euclidean norm
+            patch_scores, _ = distances.topk(k=9, largest=False, dim=1)
 
             anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
             output = (anomaly_map, anomaly_score)
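An aside, not part of the diff: the deleted NearestNeighbors helper is replaced here by a brute-force lookup against the stored memory bank. A minimal sketch of that lookup with illustrative shapes; k=9 matches the num_neighbors value in the config above:

```python
# Sketch of the memory-bank lookup now done with torch.cdist + topk.
import torch

embedding = torch.randn(4096, 384)     # patch embeddings for one test image (illustrative)
memory_bank = torch.randn(20000, 384)  # coreset stored at the end of training (illustrative)

distances = torch.cdist(embedding, memory_bank, p=2.0)        # (4096, 20000) euclidean distances
patch_scores, _ = distances.topk(k=9, largest=False, dim=1)   # 9 nearest coreset entries per patch
print(patch_scores.shape)  # torch.Size([4096, 9])
```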
@@ -213,25 +209,18 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
         embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
         return embedding
 
-    @staticmethod
-    def subsample_embedding(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
-        """Subsample embedding based on coreset sampling.
+    def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float):
+        """Subsample embedding based on coreset sampling and store to memory.
 
         Args:
             embedding (np.ndarray): Embedding tensor from the CNN
             sampling_ratio (float): Coreset sampling ratio
-
-        Returns:
-            np.ndarray: Subsampled embedding whose dimensionality is reduced.
         """
-        # Random projection
-        random_projector = SparseRandomProjection(eps=0.9)
-        random_projector.fit(embedding)
-
-        # Coreset Subsampling
-        sampler = KCenterGreedy(model=random_projector, embedding=embedding, sampling_ratio=sampling_ratio)
+        sampler = KCenterGreedy(embedding=embedding, sampling_ratio=sampling_ratio)
         coreset = sampler.sample_coreset()
-        return coreset
+        self.memory_bank = coreset
 
 
 class PatchcoreLightning(AnomalyModule):
@@ -294,10 +283,7 @@ def training_epoch_end(self, outputs):
         embedding = torch.vstack([output["embedding"] for output in outputs])
         sampling_ratio = self.hparams.model.coreset_sampling_ratio
 
-        embedding = self.model.subsample_embedding(embedding, sampling_ratio)
-
-        self.model.nn_search.fit(embedding)
-        self.model.memory_bank = embedding
+        self.model.subsample_embedding(embedding, sampling_ratio)
 
     def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         """Get batch of anomaly maps from input image batch.
@@ -311,7 +297,8 @@ def validation_step(self, batch, _):  # pylint: disable=arguments-differ
             Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
         """
 
-        anomaly_maps, _ = self.model(batch["image"])
+        anomaly_maps, anomaly_score = self.model(batch["image"])
         batch["anomaly_maps"] = anomaly_maps
+        batch["pred_scores"] = anomaly_score.unsqueeze(0)
 
         return batch
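Outside the diff: validation now consumes both outputs of forward(), and unsqueeze(0) turns the scalar image-level score into a length-1 tensor so there is one score per image in the batch. A small sketch of the resulting batch contract, assuming batch size 1 and illustrative values:

```python
# Illustrative contract of the updated validation_step outputs (batch size 1).
import torch

anomaly_map = torch.rand(1, 224, 224)   # pixel-level anomaly map (illustrative size)
anomaly_score = torch.tensor(2.7)       # scalar image-level score from the model

batch = {"image": torch.rand(1, 3, 224, 224)}
batch["anomaly_maps"] = anomaly_map
batch["pred_scores"] = anomaly_score.unsqueeze(0)  # shape (1,)
print(batch["pred_scores"].shape)  # torch.Size([1])
```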
1 change: 0 additions & 1 deletion anomalib/models/patchcore/utils/__init__.py

This file was deleted.

7 changes: 0 additions & 7 deletions anomalib/models/patchcore/utils/sampling/__init__.py

This file was deleted.

62 changes: 0 additions & 62 deletions anomalib/models/patchcore/utils/sampling/nearest_neighbors.py

This file was deleted.
