Skip to content

Commit

Permalink
🛠 Fix PatchCore image-level score computation (#580)
Browse files Browse the repository at this point in the history
* fix patchcore image-level score computation

* docstring and comment

* remove default value for n_neighbors

* torch.Tensor -> Tensor
  • Loading branch information
djdameln committed Sep 26, 2022
1 parent 353d981 commit 6c59e1b
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 61 deletions.
57 changes: 11 additions & 46 deletions anomalib/models/patchcore/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import nn
from torch import Tensor, nn

from anomalib.models.components import GaussianBlur2d

Expand All @@ -26,67 +26,32 @@ def __init__(
kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1)

def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor:
def compute_anomaly_map(self, patch_scores: Tensor) -> torch.Tensor:
"""Pixel Level Anomaly Heatmap.
Args:
patch_scores (torch.Tensor): Patch-level anomaly scores
feature_map_shape (torch.Size): 2-D feature map shape (width, height)
patch_scores (Tensor): Patch-level anomaly scores
Returns:
torch.Tensor: Map of the pixel-level anomaly scores
"""
width, height = feature_map_shape
batch_size = len(patch_scores) // (width * height)

anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height))
anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1]))

anomaly_map = F.interpolate(patch_scores, size=(self.input_size[0], self.input_size[1]))
anomaly_map = self.blur(anomaly_map)

return anomaly_map

@staticmethod
def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor:
"""Compute Image-Level Anomaly Score.
Args:
patch_scores (torch.Tensor): Patch-level anomaly scores
Returns:
torch.Tensor: Image-level anomaly scores
"""
max_scores = torch.argmax(patch_scores[:, 0])
confidence = torch.index_select(patch_scores, 0, max_scores)
weights = 1 - torch.max(F.softmax(confidence, dim=-1))
score = weights * torch.max(patch_scores[:, 0])
return score

def forward(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, patch_scores: Tensor) -> Tensor:
"""Returns anomaly_map and anomaly_score.
Expects `patch_scores` keyword to be passed explicitly
Expects `feature_map_shape` keyword to be passed explicitly
Args:
patch_scores (Tensor): Patch-level anomaly scores
Example
>>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
>>> map, score = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape)
Raises:
ValueError: If `patch_scores` key is not found
>>> map = anomaly_map_generator(patch_scores=patch_scores)
Returns:
Tuple[torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score
Tensor: anomaly_map
"""

if "patch_scores" not in kwargs:
raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}")

if "feature_map_shape" not in kwargs:
raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}")

patch_scores = kwargs["patch_scores"]
feature_map_shape = kwargs["feature_map_shape"]

anomaly_map = self.compute_anomaly_map(patch_scores, feature_map_shape)
anomaly_score = self.compute_anomaly_score(patch_scores)
return anomaly_map, anomaly_score
anomaly_map = self.compute_anomaly_map(patch_scores)
return anomaly_map
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ

anomaly_maps, anomaly_score = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)
batch["pred_scores"] = anomaly_score

return batch

Expand Down
64 changes: 50 additions & 14 deletions anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def __init__(
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

self.register_buffer("memory_bank", torch.Tensor())
self.memory_bank: torch.Tensor
self.register_buffer("memory_bank", Tensor())
self.memory_bank: Tensor

def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
def forward(self, input_tensor: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Return Embedding during training, or a tuple of anomaly map and anomaly score during testing.
Steps performed:
Expand All @@ -56,7 +56,7 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
input_tensor (Tensor): Input tensor
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Embedding for training,
Union[Tensor, Tuple[Tensor, Tensor]]: Embedding for training,
anomaly map and anomaly score for testing.
"""
if self.tiler:
Expand All @@ -71,21 +71,29 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.tiler:
embedding = self.tiler.untile(embedding)

feature_map_shape = embedding.shape[-2:]
batch_size, _, width, height = embedding.shape
embedding = self.reshape_embedding(embedding)

if self.training:
output = embedding
else:
patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=self.num_neighbors)
anomaly_map, anomaly_score = self.anomaly_map_generator(
patch_scores=patch_scores, feature_map_shape=feature_map_shape
)
# apply nearest neighbor search
patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1)
# reshape to batch dimension
patch_scores = patch_scores.reshape((batch_size, -1))
locations = locations.reshape((batch_size, -1))
# compute anomaly score
anomaly_score = self.compute_anomaly_score(patch_scores, locations, embedding)
# reshape to w, h
patch_scores = patch_scores.reshape((batch_size, 1, width, height))
# get anomaly map
anomaly_map = self.anomaly_map_generator(patch_scores)

output = (anomaly_map, anomaly_score)

return output

def generate_embedding(self, features: Dict[str, Tensor]) -> torch.Tensor:
def generate_embedding(self, features: Dict[str, Tensor]) -> Tensor:
"""Generate embedding from hierarchical feature map.
Args:
Expand Down Expand Up @@ -121,7 +129,7 @@ def reshape_embedding(embedding: Tensor) -> Tensor:
embedding = embedding.permute(0, 2, 3, 1).reshape(-1, embedding_size)
return embedding

def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
def subsample_embedding(self, embedding: Tensor, sampling_ratio: float) -> None:
"""Subsample embedding based on coreset sampling and store to memory.
Args:
Expand All @@ -134,7 +142,7 @@ def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) ->
coreset = sampler.sample_coreset()
self.memory_bank = coreset

def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
def nearest_neighbors(self, embedding: Tensor, n_neighbors: int) -> Tuple[Tensor, Tensor]:
"""Nearest Neighbours using brute force method and euclidean norm.
Args:
Expand All @@ -143,7 +151,35 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
Returns:
Tensor: Patch scores.
Tensor: Locations of the nearest neighbor(s).
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores
patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores, locations

def compute_anomaly_score(self, patch_scores: Tensor, locations: Tensor, embedding: Tensor) -> Tensor:
"""Compute Image-Level Anomaly Score.
Args:
patch_scores (Tensor): Patch-level anomaly scores
locations: Memory bank locations of the nearest neighbor for each patch location
embedding: The feature embeddings that generated the patch scores
Returns:
Tensor: Image-level anomaly scores
"""

# 1. Find the patch with the largest distance to it's nearest neighbor in each image
max_patches = torch.argmax(patch_scores, dim=1) # (m^test,* in the paper)
# 2. Find the distance of the patch to it's nearest neighbor, and the location of the nn in the membank
score = patch_scores[torch.arange(len(patch_scores)), max_patches] # s in the paper
nn_index = locations[torch.arange(len(patch_scores)), max_patches] # m^* in the paper
# 3. Find the support samples of the nearest neighbor in the membank
nn_sample = self.memory_bank[nn_index, :]
_, support_samples = self.nearest_neighbors(nn_sample, n_neighbors=self.num_neighbors) # N_b(m^*) in the paper
# 4. Find the distance of the patch features to each of the support samples
distances = torch.cdist(embedding[max_patches].unsqueeze(1), self.memory_bank[support_samples], p=2.0)
# 5. Apply softmax to find the weights
weights = (1 - F.softmax(distances.squeeze()))[..., 0]
# 6. Apply the weight factor to the score
score = weights * score # S^* in the paper
return score

0 comments on commit 6c59e1b

Please sign in to comment.