Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated coreset subsampling method to improve accuracy #73

Merged
merged 4 commits into from
Jan 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 35 additions & 35 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions anomalib/models/patchcore/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@ All results gathered with seed `42`.

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.819 | 0.947 | 0.722 | 0.997 | 0.982 | 0.988 | 0.972 | 0.810 | 0.586 | 0.981 | 0.631 | 0.780 | 0.482 | 0.827 | 0.733 | 0.844 |
| Wide ResNet-50 | 0.877 | 0.981 | 0.842 | 1.0 | 0.991 | 0.991 | 0.985 | 0.868 | 0.763 | 0.988 | 0.914 | 0.769 | 0.427 | 0.806 | 0.878 | 0.958 |
| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 |
| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

### Pixel-Level AUC

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.935 | 0.979 | 0.843 | 0.989 | 0.934 | 0.925 | 0.956 | 0.923 | 0.942 | 0.967 | 0.913 | 0.931 | 0.924 | 0.958 | 0.881 | 0.954 |
| Wide ResNet-50 | 0.955 | 0.988 | 0.903 | 0.990 | 0.957 | 0.936 | 0.972 | 0.950 | 0.968 | 0.974 | 0.960 | 0.948 | 0.917 | 0.969 | 0.913 | 0.976 |
| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 |
| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

### Image F1 Score

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.896 | 0.933 | 0.857 | 0.995 | 0.964 | 0.983 | 0.959 | 0.790 | 0.908 | 0.964 | 0.903 | 0.916 | 0.853 | 0.866 | 0.653 | 0.898 |
| Wide ResNet-50 | 0.923 | 0.961 | 0.875 | 1.0 | 0.989 | 0.975 | 0.984 | 0.832 | 0.908 | 0.972 | 0.920 | 0.922 | 0.853 | 0.862 | 0.842 | 0.953 |
| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 |
| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

### Sample Results

Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ model:
layers:
- layer2
- layer3
coreset_sampling_ratio: 0.001
coreset_sampling_ratio: 0.1
num_neighbors: 9
metric: auc
weight_file: weights/model.ckpt
Expand Down
74 changes: 46 additions & 28 deletions anomalib/models/patchcore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,13 @@
import torchvision
from kornia import gaussian_blur2d
from omegaconf import ListConfig
from sklearn import random_projection
from torch import Tensor, nn

from anomalib.core.model import AnomalyModule
from anomalib.core.model.dynamic_module import DynamicBufferModule
from anomalib.core.model.feature_extractor import FeatureExtractor
from anomalib.data.tiler import Tiler
from anomalib.models.patchcore.utils.sampling import (
KCenterGreedy,
NearestNeighbors,
SparseRandomProjection,
)


class AnomalyMapGenerator:
Expand Down Expand Up @@ -127,7 +123,6 @@ def __init__(

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.nn_search = NearestNeighbors(n_neighbors=9)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

if apply_tiling:
Expand Down Expand Up @@ -170,7 +165,8 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.training:
output = embedding
else:
patch_scores, _ = self.nn_search.kneighbors(embedding)
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=9, largest=False, dim=1)

anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
output = (anomaly_map, anomaly_score)
Expand Down Expand Up @@ -213,25 +209,48 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
return embedding

@staticmethod
def subsample_embedding(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
"""Subsample embedding based on coreset sampling.
def create_coreset(
self,
embedding: Tensor,
sample_count: int = 500,
eps: float = 0.90,
):
"""Creates n subsampled coreset for given sample_set.

Args:
embedding (np.ndarray): Embedding tensor from the CNN
sampling_ratio (float): Coreset sampling ratio

Returns:
np.ndarray: Subsampled embedding whose dimensionality is reduced.
embedding (Tensor): (sample_count, d) tensor of patches.
sample_count (int): Number of patches to select.
eps (float): Parameter for spare projection aggression.
"""
# Random projection
random_projector = SparseRandomProjection(eps=0.9)
random_projector.fit(embedding)

# Coreset Subsampling
sampler = KCenterGreedy(model=random_projector, embedding=embedding, sampling_ratio=sampling_ratio)
coreset = sampler.sample_coreset()
return coreset
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print("Fitting random projections...")
try:
transformer = random_projection.SparseRandomProjection(eps=eps)
sample_set = torch.tensor(transformer.fit_transform(embedding.cpu())).to( # pylint: disable=not-callable
embedding.device
)
except ValueError:
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print(" Error: could not project vectors. Please increase `eps` value.")

select_idx = 0
last_item = sample_set[select_idx : select_idx + 1]
coreset_idx = [torch.tensor(select_idx).to(embedding.device)] # pylint: disable=not-callable
min_distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True)

for _ in range(sample_count - 1):
distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True) # broadcast
min_distances = torch.minimum(distances, min_distances) # iterate
select_idx = torch.argmax(min_distances) # select

last_item = sample_set[select_idx : select_idx + 1]
min_distances[select_idx] = 0
coreset_idx.append(select_idx)

coreset_idx = torch.stack(coreset_idx)
self.memory_bank = embedding[coreset_idx]


class PatchcoreLightning(AnomalyModule):
Expand Down Expand Up @@ -292,12 +311,10 @@ def training_epoch_end(self, outputs):
outputs (List[Dict[str, np.ndarray]]): List of embedding vectors
"""
embedding = torch.vstack([output["embedding"] for output in outputs])
sampling_ratio = self.hparams.model.coreset_sampling_ratio

embedding = self.model.subsample_embedding(embedding, sampling_ratio)
sampling_ratio = self.hparams.model.coreset_sampling_ratio

self.model.nn_search.fit(embedding)
self.model.memory_bank = embedding
self.model.create_coreset(embedding=embedding, sample_count=int(sampling_ratio * embedding.shape[0]), eps=0.9)

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Get batch of anomaly maps from input image batch.
Expand All @@ -311,7 +328,8 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
"""

anomaly_maps, _ = self.model(batch["image"])
anomaly_maps, anomaly_score = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)

return batch
1 change: 0 additions & 1 deletion anomalib/models/patchcore/utils/__init__.py

This file was deleted.

7 changes: 0 additions & 7 deletions anomalib/models/patchcore/utils/sampling/__init__.py

This file was deleted.

152 changes: 0 additions & 152 deletions anomalib/models/patchcore/utils/sampling/k_center_greedy.py

This file was deleted.

62 changes: 0 additions & 62 deletions anomalib/models/patchcore/utils/sampling/nearest_neighbors.py

This file was deleted.

Loading