Skip to content

Commit

Permalink
Calculate feature map shape patchcore (#148)
Browse files Browse the repository at this point in the history
* calculate feature map shape patchcore

* passed to wrong function

* unpack values from torch tensor directly

* get feature map shape before reshaping

* updating docstrings in patchcore

* fix line length

* w, h (too short) -> width, height

* typing torch.Size of feature_map_shape

* removing trailing whitespaces
  • Loading branch information
alexriedel1 committed Mar 22, 2022
1 parent b66e5e3 commit b57e025
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions anomalib/models/patchcore/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,20 @@ def __init__(
self.input_size = input_size
self.sigma = sigma

def compute_anomaly_map(self, patch_scores: torch.Tensor) -> torch.Tensor:
def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor:
"""Pixel Level Anomaly Heatmap.
Args:
patch_scores (torch.Tensor): Patch-level anomaly scores
feature_map_shape (torch.Size): 2-D feature map shape (width, height)
Returns:
torch.Tensor: Map of the pixel-level anomaly scores
"""
# TODO: https://github.com/openvinotoolkit/anomalib/issues/40
batch_size = len(patch_scores) // (28 * 28)
width, height = feature_map_shape
batch_size = len(patch_scores) // (width * height)

anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, 28, 28))
anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height))
anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1]))

kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1
Expand All @@ -84,10 +86,11 @@ def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns anomaly_map and anomaly_score.
Expects `patch_scores` keyword to be passed explicitly
Expects `feature_map_shape` keyword to be passed explicitly
Example
>>> anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)
>>> map, score = anomaly_map_generator(patch_scores=numpy_array)
>>> map, score = anomaly_map_generator(patch_scores=numpy_array, feature_map_shape=feature_map_shape)
Raises:
ValueError: If `patch_scores` key is not found
Expand All @@ -99,8 +102,13 @@ def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if "patch_scores" not in kwargs:
raise ValueError(f"Expected key `patch_scores`. Found {kwargs.keys()}")

if "feature_map_shape" not in kwargs:
raise ValueError(f"Expected key `feature_map_shape`. Found {kwargs.keys()}")

patch_scores = kwargs["patch_scores"]
anomaly_map = self.compute_anomaly_map(patch_scores)
feature_map_shape = kwargs["feature_map_shape"]

anomaly_map = self.compute_anomaly_map(patch_scores, feature_map_shape)
anomaly_score = self.compute_anomaly_score(patch_scores)
return anomaly_map, anomaly_score

Expand Down Expand Up @@ -163,13 +171,16 @@ def forward(self, input_tensor: Tensor) -> Union[torch.Tensor, Tuple[torch.Tenso
if self.apply_tiling:
embedding = self.tiler.untile(embedding)

feature_map_shape = embedding.shape[-2:]
embedding = self.reshape_embedding(embedding)

if self.training:
output = embedding
else:
patch_scores = self.nearest_neighbors(embedding=embedding, n_neighbors=9)
anomaly_map, anomaly_score = self.anomaly_map_generator(patch_scores=patch_scores)
anomaly_map, anomaly_score = self.anomaly_map_generator(
patch_scores=patch_scores, feature_map_shape=feature_map_shape
)
output = (anomaly_map, anomaly_score)

return output
Expand Down

0 comments on commit b57e025

Please sign in to comment.