Skip to content

Commit

Permalink
Updated coreset subsampling method to improve accuracy (#73)
Browse files Browse the repository at this point in the history
* Updated coreset subsampling method to improve accuracy

* fixing tox issues

* Addressing review comments

* Address review comment - move coreset method to model class
  • Loading branch information
blakshma committed Jan 18, 2022
1 parent 2b40df7 commit 3613a6e
Show file tree
Hide file tree
Showing 9 changed files with 88 additions and 435 deletions.
70 changes: 35 additions & 35 deletions README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions anomalib/models/patchcore/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,22 @@ All results gathered with seed `42`.

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.819 | 0.947 | 0.722 | 0.997 | 0.982 | 0.988 | 0.972 | 0.810 | 0.586 | 0.981 | 0.631 | 0.780 | 0.482 | 0.827 | 0.733 | 0.844 |
| Wide ResNet-50 | 0.877 | 0.981 | 0.842 | 1.0 | 0.991 | 0.991 | 0.985 | 0.868 | 0.763 | 0.988 | 0.914 | 0.769 | 0.427 | 0.806 | 0.878 | 0.958 |
| Wide ResNet-50 | 0.980 | 0.984 | 0.959 | 1.000 | 1.000 | 0.989 | 1.000 | 0.990 | 0.982 | 1.000 | 0.994 | 0.924 | 0.960 | 0.933 | 1.000 | 0.982 |
| ResNet-18 | 0.973 | 0.970 | 0.947 | 1.000 | 0.997 | 0.997 | 1.000 | 0.986 | 0.965 | 1.000 | 0.991 | 0.916 | 0.943 | 0.931 | 0.996 | 0.953 |

### Pixel-Level AUC

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.935 | 0.979 | 0.843 | 0.989 | 0.934 | 0.925 | 0.956 | 0.923 | 0.942 | 0.967 | 0.913 | 0.931 | 0.924 | 0.958 | 0.881 | 0.954 |
| Wide ResNet-50 | 0.955 | 0.988 | 0.903 | 0.990 | 0.957 | 0.936 | 0.972 | 0.950 | 0.968 | 0.974 | 0.960 | 0.948 | 0.917 | 0.969 | 0.913 | 0.976 |
| Wide ResNet-50 | 0.980 | 0.988 | 0.968 | 0.991 | 0.961 | 0.934 | 0.984 | 0.988 | 0.988 | 0.987 | 0.989 | 0.980 | 0.989 | 0.988 | 0.981 | 0.983 |
| ResNet-18 | 0.976 | 0.986 | 0.955 | 0.990 | 0.943 | 0.933 | 0.981 | 0.984 | 0.986 | 0.986 | 0.986 | 0.974 | 0.991 | 0.988 | 0.974 | 0.983 |

### Image F1 Score

| | Avg | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal Nut | Pill | Screw | Toothbrush | Transistor | Zipper |
| -------------- | :---: | :----: | :---: | :-----: | :---: | :---: | :----: | :---: | :-----: | :------: | :-------: | :---: | :---: | :--------: | :--------: | :----: |
| ResNet-18 | 0.896 | 0.933 | 0.857 | 0.995 | 0.964 | 0.983 | 0.959 | 0.790 | 0.908 | 0.964 | 0.903 | 0.916 | 0.853 | 0.866 | 0.653 | 0.898 |
| Wide ResNet-50 | 0.923 | 0.961 | 0.875 | 1.0 | 0.989 | 0.975 | 0.984 | 0.832 | 0.908 | 0.972 | 0.920 | 0.922 | 0.853 | 0.862 | 0.842 | 0.953 |
| Wide ResNet-50 | 0.976 | 0.971 | 0.974 | 1.000 | 1.000 | 0.967 | 1.000 | 0.968 | 0.982 | 1.000 | 0.984 | 0.940 | 0.943 | 0.938 | 1.000 | 0.979 |
| ResNet-18 | 0.970 | 0.949 | 0.946 | 1.000 | 0.982 | 0.992 | 1.000 | 0.978 | 0.969 | 1.000 | 0.989 | 0.940 | 0.932 | 0.935 | 0.974 | 0.967 |

### Sample Results

Expand Down
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ model:
layers:
- layer2
- layer3
coreset_sampling_ratio: 0.001
coreset_sampling_ratio: 0.1
num_neighbors: 9
metric: auc
weight_file: weights/model.ckpt
Expand Down
74 changes: 46 additions & 28 deletions anomalib/models/patchcore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,13 @@
import torchvision
from kornia import gaussian_blur2d
from omegaconf import ListConfig
from sklearn import random_projection
from torch import Tensor, nn

from anomalib.core.model import AnomalyModule
from anomalib.core.model.dynamic_module import DynamicBufferModule
from anomalib.core.model.feature_extractor import FeatureExtractor
from anomalib.data.tiler import Tiler
from anomalib.models.patchcore.utils.sampling import (
KCenterGreedy,
NearestNeighbors,
SparseRandomProjection,
)


class AnomalyMapGenerator:
Expand Down Expand Up @@ -127,7 +123,6 @@ def __init__(

self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=self.layers)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.nn_search = NearestNeighbors(n_neighbors=9)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

if apply_tiling:
Expand Down Expand Up @@ -170,7 +165,8 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.training:
output = embedding
else:
patch_scores, _ = self.nn_search.kneighbors(embedding)
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=9, largest=False, dim=1)

anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
output = (anomaly_map, anomaly_score)
Expand Down Expand Up @@ -213,25 +209,48 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
return embedding

@staticmethod
def subsample_embedding(embedding: torch.Tensor, sampling_ratio: float) -> torch.Tensor:
"""Subsample embedding based on coreset sampling.
def create_coreset(
self,
embedding: Tensor,
sample_count: int = 500,
eps: float = 0.90,
):
"""Creates n subsampled coreset for given sample_set.
Args:
embedding (np.ndarray): Embedding tensor from the CNN
sampling_ratio (float): Coreset sampling ratio
Returns:
np.ndarray: Subsampled embedding whose dimensionality is reduced.
embedding (Tensor): (sample_count, d) tensor of patches.
sample_count (int): Number of patches to select.
eps (float): Parameter for spare projection aggression.
"""
# Random projection
random_projector = SparseRandomProjection(eps=0.9)
random_projector.fit(embedding)

# Coreset Subsampling
sampler = KCenterGreedy(model=random_projector, embedding=embedding, sampling_ratio=sampling_ratio)
coreset = sampler.sample_coreset()
return coreset
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print("Fitting random projections...")
try:
transformer = random_projection.SparseRandomProjection(eps=eps)
sample_set = torch.tensor(transformer.fit_transform(embedding.cpu())).to( # pylint: disable=not-callable
embedding.device
)
except ValueError:
# TODO: https://github.com/openvinotoolkit/anomalib/issues/54
# Replace print statement with logger.
print(" Error: could not project vectors. Please increase `eps` value.")

select_idx = 0
last_item = sample_set[select_idx : select_idx + 1]
coreset_idx = [torch.tensor(select_idx).to(embedding.device)] # pylint: disable=not-callable
min_distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True)

for _ in range(sample_count - 1):
distances = torch.linalg.norm(sample_set - last_item, dim=1, keepdims=True) # broadcast
min_distances = torch.minimum(distances, min_distances) # iterate
select_idx = torch.argmax(min_distances) # select

last_item = sample_set[select_idx : select_idx + 1]
min_distances[select_idx] = 0
coreset_idx.append(select_idx)

coreset_idx = torch.stack(coreset_idx)
self.memory_bank = embedding[coreset_idx]


class PatchcoreLightning(AnomalyModule):
Expand Down Expand Up @@ -292,12 +311,10 @@ def training_epoch_end(self, outputs):
outputs (List[Dict[str, np.ndarray]]): List of embedding vectors
"""
embedding = torch.vstack([output["embedding"] for output in outputs])
sampling_ratio = self.hparams.model.coreset_sampling_ratio

embedding = self.model.subsample_embedding(embedding, sampling_ratio)
sampling_ratio = self.hparams.model.coreset_sampling_ratio

self.model.nn_search.fit(embedding)
self.model.memory_bank = embedding
self.model.create_coreset(embedding=embedding, sample_count=int(sampling_ratio * embedding.shape[0]), eps=0.9)

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Get batch of anomaly maps from input image batch.
Expand All @@ -311,7 +328,8 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Dict[str, Any]: Image filenames, test images, GT and predicted label/masks
"""

anomaly_maps, _ = self.model(batch["image"])
anomaly_maps, anomaly_score = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)

return batch
1 change: 0 additions & 1 deletion anomalib/models/patchcore/utils/__init__.py

This file was deleted.

7 changes: 0 additions & 7 deletions anomalib/models/patchcore/utils/sampling/__init__.py

This file was deleted.

152 changes: 0 additions & 152 deletions anomalib/models/patchcore/utils/sampling/k_center_greedy.py

This file was deleted.

62 changes: 0 additions & 62 deletions anomalib/models/patchcore/utils/sampling/nearest_neighbors.py

This file was deleted.

Loading

0 comments on commit 3613a6e

Please sign in to comment.