Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix patchcore image-level score computation #580

Merged
merged 4 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 11 additions & 46 deletions anomalib/models/patchcore/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import nn
from torch import Tensor, nn

from anomalib.models.components import GaussianBlur2d

Expand All @@ -26,67 +26,32 @@ def __init__(
kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1)

def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor:
def compute_anomaly_map(self, patch_scores: Tensor) -> torch.Tensor:
"""Pixel Level Anomaly Heatmap.

Args:
patch_scores (torch.Tensor): Patch-level anomaly scores
feature_map_shape (torch.Size): 2-D feature map shape (width, height)
patch_scores (Tensor): Patch-level anomaly scores

Returns:
torch.Tensor: Map of the pixel-level anomaly scores
"""
width, height = feature_map_shape
batch_size = len(patch_scores) // (width * height)

anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height))
anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1]))

anomaly_map = F.interpolate(patch_scores, size=(self.input_size[0], self.input_size[1]))
anomaly_map = self.blur(anomaly_map)

return anomaly_map

@staticmethod
def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor:
"""Compute Image-Level Anomaly Score.

Args:
patch_scores (torch.Tensor): Patch-level anomaly scores
Returns:
torch.Tensor: Image-level anomaly scores
"""
max_scores = torch.argmax(patch_scores[:, 0])
confidence = torch.index_select(patch_scores, 0, max_scores)
weights = 1 - torch.max(F.softmax(confidence, dim=-1))
score = weights * torch.max(patch_scores[:, 0])
return score

def forward(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, patch_scores: Tensor) -> Tensor:
"""Returns anomaly_map and anomaly_score.

Expects `patch_scores` keyword to be passed explicitly
Expects `feature_map_shape` keyword to be passed explicitly
Args:
patch_scores (Tensor): Patch-level anomaly scores

Example
>>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
>>> map, score = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape)

Raises:
ValueError: If `patch_scores` key is not found
>>> map = anomaly_map_generator(patch_scores=patch_scores)

Returns:
Tuple[torch.Tensor, torch.Tensor]: anomaly_map, anomaly_score
Tensor: anomaly_map
"""

if "patch_scores" not in kwargs:
raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}")

if "feature_map_shape" not in kwargs:
raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}")

patch_scores = kwargs["patch_scores"]
feature_map_shape = kwargs["feature_map_shape"]

anomaly_map = self.compute_anomaly_map(patch_scores, feature_map_shape)
anomaly_score = self.compute_anomaly_score(patch_scores)
return anomaly_map, anomaly_score
anomaly_map = self.compute_anomaly_map(patch_scores)
return anomaly_map
2 changes: 1 addition & 1 deletion anomalib/models/patchcore/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ

anomaly_maps, anomaly_score = self.model(batch["image"])
batch["anomaly_maps"] = anomaly_maps
batch["pred_scores"] = anomaly_score.unsqueeze(0)
batch["pred_scores"] = anomaly_score

return batch

Expand Down
52 changes: 44 additions & 8 deletions anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,16 +71,24 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.tiler:
embedding = self.tiler.untile(embedding)

feature_map_shape = embedding.shape[-2:]
batch_size, _, width, height = embedding.shape
embedding = self.reshape_embedding(embedding)

if self.training:
output = embedding
else:
patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=self.num_neighbors)
anomaly_map, anomaly_score = self.anomaly_map_generator(
patch_scores=patch_scores, feature_map_shape=feature_map_shape
)
# apply nearest neighbor search
patch_scores, locations = self.nearest_neighbors(embedding=embedding, n_neighbors=1)
# reshape to batch dimension
patch_scores = patch_scores.reshape((batch_size, -1))
locations = locations.reshape((batch_size, -1))
# compute anomaly score
anomaly_score = self.compute_anomaly_score(patch_scores, locations, embedding)
# reshape to w, h
patch_scores = patch_scores.reshape((batch_size, 1, width, height))
# get anomaly map
anomaly_map = self.anomaly_map_generator(patch_scores)

output = (anomaly_map, anomaly_score)

return output
Expand Down Expand Up @@ -134,7 +142,7 @@ def subsample_embedding(self, embedding: torch.Tensor, sampling_ratio: float) ->
coreset = sampler.sample_coreset()
self.memory_bank = coreset

def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:
def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tuple[Tensor, Tensor]:
djdameln marked this conversation as resolved.
Show resolved Hide resolved
"""Nearest Neighbours using brute force method and euclidean norm.

Args:
Expand All @@ -143,7 +151,35 @@ def nearest_neighbors(self, embedding: Tensor, n_neighbors: int = 9) -> Tensor:

Returns:
Tensor: Patch scores.
Tensor: Locations of the nearest neighbor(s).
"""
distances = torch.cdist(embedding, self.memory_bank, p=2.0) # euclidean norm
patch_scores, _ = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores
patch_scores, locations = distances.topk(k=n_neighbors, largest=False, dim=1)
return patch_scores, locations

def compute_anomaly_score(self, patch_scores: torch.Tensor, locations: Tensor, embedding: Tensor) -> torch.Tensor:
djdameln marked this conversation as resolved.
Show resolved Hide resolved
"""Compute Image-Level Anomaly Score.

Args:
patch_scores (torch.Tensor): Patch-level anomaly scores
locations: Memory bank locations of the nearest neighbor for each patch location
embedding: The feature embeddings that generated the patch scores
Returns:
torch.Tensor: Image-level anomaly scores
"""

# 1. Find the patch with the largest distance to it's nearest neighbor in each image
max_patches = torch.argmax(patch_scores, dim=1) # (m^test,* in the paper)
# 2. Find the distance of the patch to it's nearest neighbor, and the location of the nn in the membank
score = patch_scores[torch.arange(len(patch_scores)), max_patches] # s in the paper
nn_index = locations[torch.arange(len(patch_scores)), max_patches] # m^* in the paper
# 3. Find the support samples of the nearest neighbor in the membank
nn_sample = self.memory_bank[nn_index, :]
_, support_samples = self.nearest_neighbors(nn_sample, n_neighbors=self.num_neighbors) # N_b(m^*) in the paper
# 4. Find the distance of the patch features to each of the support samples
distances = torch.cdist(embedding[max_patches].unsqueeze(1), self.memory_bank[support_samples], p=2.0)
# 5. Apply softmax to find the weights
weights = (1 - F.softmax(distances.squeeze()))[..., 0]
# 6. Apply the weight factor to the score
score = weights * score # S^* in the paper
return score