Skip to content

Commit

Permalink
Fix onnx export by rewriting GaussianBlur (#476)
Browse files Browse the repository at this point in the history
* Fix onnx export by rewriting GaussianBlur

* Address codacy complaints.

Reame variable to something other than `input`

* Move GaussianBlur2d to anomalib.post_processing

* Move blur to `anomlib.models.components.filters`
  • Loading branch information
ORippler committed Aug 4, 2022
1 parent 6dfb283 commit e19428f
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 18 deletions.
2 changes: 2 additions & 0 deletions anomalib/models/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .base import AnomalyModule, DynamicBufferModule
from .dimensionality_reduction import PCA, SparseRandomProjection
from .feature_extractors import FeatureExtractor
from .filters import GaussianBlur2d
from .sampling import KCenterGreedy
from .stats import GaussianKDE, MultiVariateGaussian

Expand All @@ -28,5 +29,6 @@
"FeatureExtractor",
"KCenterGreedy",
"GaussianKDE",
"GaussianBlur2d",
"MultiVariateGaussian",
]
5 changes: 5 additions & 0 deletions anomalib/models/components/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Implements filters used by models."""

from .blur import GaussianBlur2d

__all__ = ["GaussianBlur2d"]
76 changes: 76 additions & 0 deletions anomalib/models/components/filters/blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Gaussian blurring via pytorch."""
from typing import Tuple, Union

from kornia.filters import get_gaussian_kernel2d
from kornia.filters.filter import _compute_padding
from kornia.filters.kernels import normalize_kernel2d
from torch import Tensor, nn
from torch.nn import functional as F


class GaussianBlur2d(nn.Module):
"""Compute GaussianBlur in 2d.
Makes use of kornia functions, but most notably the kernel is not computed
during the forward pass, and does not depend on the input size. As a caveat,
the number of channels that are expected have to be provided during initialization.
"""

def __init__(
self,
kernel_size: Union[Tuple[int, int], int],
sigma: Union[Tuple[float, float], float],
channels: int,
normalize: bool = True,
border_type: str = "reflect",
padding: str = "same",
) -> None:
"""Initialize model, setup kernel etc..
Args:
kernel_size (Union[Tuple[int, int], int]): size of the Gaussian kernel to use.
sigma (Union[Tuple[float, float], float]): standard deviation to use for constructing the Gaussian kernel.
channels (int): channels of the input
normalize (bool, optional): Whether to normalize the kernel or not (i.e. all elements sum to 1).
Defaults to True.
border_type (str, optional): Border type to use for padding of the input. Defaults to "reflect".
padding (str, optional): Type of padding to apply. Defaults to "same".
"""
super().__init__()
kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
sigma = sigma if isinstance(sigma, tuple) else (sigma, sigma)
self.kernel: Tensor
self.register_buffer("kernel", get_gaussian_kernel2d(kernel_size=kernel_size, sigma=sigma))
if normalize:
self.kernel = normalize_kernel2d(self.kernel)
self.channels = channels
self.kernel.unsqueeze_(0).unsqueeze_(0)
self.kernel = self.kernel.expand(self.channels, -1, -1, -1)
self.border_type = border_type
self.padding = padding
self.height, self.width = self.kernel.shape[-2:]
self.padding_shape = _compute_padding([self.height, self.width])

def forward(self, input_tensor: Tensor) -> Tensor:
"""Blur the input with the computed Gaussian.
Args:
input_tensor (Tensor): Input tensor to be blurred.
Returns:
Tensor: Blurred output tensor.
"""
batch, channel, height, width = input_tensor.size()

if self.padding == "same":
input_tensor = F.pad(input_tensor, self.padding_shape, mode=self.border_type)

# convolve the tensor with the kernel.
output = F.conv2d(input_tensor, self.kernel, groups=self.channels, padding=0, stride=1)

if self.padding == "same":
out = output.view(batch, channel, height, width)
else:
out = output.view(batch, channel, height - self.height + 1, width - self.width + 1)

return out
24 changes: 12 additions & 12 deletions anomalib/models/padim/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

import torch
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn

from anomalib.models.components import GaussianBlur2d

class AnomalyMapGenerator:

class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap.
Args:
Expand All @@ -32,8 +33,10 @@ class AnomalyMapGenerator:
"""

def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4):
super().__init__()
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.sigma = sigma
kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1)

@staticmethod
def compute_distance(embedding: Tensor, stats: List[Tensor]) -> Tensor:
Expand All @@ -57,7 +60,7 @@ def compute_distance(embedding: Tensor, stats: List[Tensor]) -> Tensor:
delta = (embedding - mean).permute(2, 0, 1)

distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0)
distances = distances.reshape(batch, height, width)
distances = distances.reshape(batch, 1, height, width)
distances = distances.clamp(0).sqrt()

return distances
Expand All @@ -73,7 +76,7 @@ def up_sample(self, distance: Tensor) -> Tensor:
"""

score_map = F.interpolate(
distance.unsqueeze(1),
distance,
size=self.image_size,
mode="bilinear",
align_corners=False,
Expand All @@ -90,11 +93,8 @@ def smooth_anomaly_map(self, anomaly_map: Tensor) -> Tensor:
Filtered anomaly scores
"""

kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1
sigma = torch.as_tensor(self.sigma).to(anomaly_map.device)
anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(sigma, sigma))

return anomaly_map
blurred_anomaly_map = self.blur(anomaly_map)
return blurred_anomaly_map

def compute_anomaly_map(self, embedding: Tensor, mean: Tensor, inv_covariance: Tensor) -> Tensor:
"""Compute anomaly score.
Expand All @@ -120,7 +120,7 @@ def compute_anomaly_map(self, embedding: Tensor, mean: Tensor, inv_covariance: T

return smoothed_anomaly_map

def __call__(self, **kwds):
def forward(self, **kwds):
"""Returns anomaly_map.
Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly.
Expand Down
15 changes: 9 additions & 6 deletions anomalib/models/patchcore/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,24 @@

import torch
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d
from omegaconf import ListConfig
from torch import nn

from anomalib.models.components import GaussianBlur2d

class AnomalyMapGenerator:

class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(
self,
input_size: Union[ListConfig, Tuple],
sigma: int = 4,
) -> None:
super().__init__()
self.input_size = input_size
self.sigma = sigma
kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
self.blur = GaussianBlur2d(kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma), channels=1)

def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: torch.Size) -> torch.Tensor:
"""Pixel Level Anomaly Heatmap.
Expand All @@ -49,8 +53,7 @@ def compute_anomaly_map(self, patch_scores: torch.Tensor, feature_map_shape: tor
anomaly_map = patch_scores[:, 0].reshape((batch_size, 1, width, height))
anomaly_map = F.interpolate(anomaly_map, size=(self.input_size[0], self.input_size[1]))

kernel_size = 2 * int(4.0 * self.sigma + 0.5) + 1
anomaly_map = gaussian_blur2d(anomaly_map, (kernel_size, kernel_size), sigma=(self.sigma, self.sigma))
anomaly_map = self.blur(anomaly_map)

return anomaly_map

Expand All @@ -69,7 +72,7 @@ def compute_anomaly_score(patch_scores: torch.Tensor) -> torch.Tensor:
score = weights * torch.max(patch_scores[:, 0])
return score

def __call__(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward(self, **kwargs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Returns anomaly_map and anomaly_score.
Expects `patch_scores` keyword to be passed explicitly
Expand Down
1 change: 1 addition & 0 deletions tests/pre_merge/models/components/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test individual components."""
18 changes: 18 additions & 0 deletions tests/pre_merge/models/components/test_blur.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
import torch
from kornia.filters import GaussianBlur2d as korniaGaussianBlur2d

from anomalib.models.components import GaussianBlur2d


@pytest.mark.parametrize("kernel_size", [(33, 33), (9, 9), (11, 5), (3, 3)])
@pytest.mark.parametrize("sigma", [(4.0, 4.0), (1.9, 3.0), (2.0, 1.5)])
@pytest.mark.parametrize("channels", list(range(1, 6)))
def test_blur_equivalence(kernel_size, sigma, channels):
for _ in range(10):
input_tensor = torch.randn((3, channels, 128, 128))
kornia = korniaGaussianBlur2d(kernel_size, sigma, separable=False)
blur_kornia = kornia(input_tensor)
gaussian = GaussianBlur2d(kernel_size, sigma, channels)
blur_gaussian = gaussian(input_tensor)
torch.testing.assert_allclose(blur_kornia, blur_gaussian)

1 comment on commit e19428f

@voidmain443
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#1798 error is seems with the GaussianBlur2d with channels dimensions in the tensor is note equal and it seems working with the Pacthcore() . Can you explain with this Pacthcore library how it works/?

Please sign in to comment.