Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert AnomalyMapGenerator to nn.Module #497

Merged
merged 1 commit into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions anomalib/models/cflow/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(
self,
image_size: Union[ListConfig, Tuple],
pool_layers: List[str],
):
super().__init__()
self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.pool_layers: List[str] = pool_layers
Expand Down Expand Up @@ -60,7 +61,7 @@ def compute_anomaly_map(

return anomaly_map

def __call__(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor:
def forward(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor:
"""Returns anomaly_map.

Expects `distribution`, `height` and 'width' keywords to be passed explicitly
Expand Down
7 changes: 4 additions & 3 deletions anomalib/models/fastflow/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(self, input_size: Union[ListConfig, Tuple]):
super().__init__()
self.input_size = input_size if isinstance(input_size, tuple) else tuple(input_size)

def __call__(self, hidden_variables: List[Tensor]) -> Tensor:
def forward(self, hidden_variables: List[Tensor]) -> Tensor:
"""Generate Anomaly Heatmap.

This implementation generates the heatmap based on the flow maps
Expand Down
12 changes: 6 additions & 6 deletions anomalib/models/padim/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def compute_anomaly_map(self, embedding: Tensor, mean: Tensor, inv_covariance: T

return smoothed_anomaly_map

def forward(self, **kwds):
def forward(self, **kwargs):
"""Returns anomaly_map.

Expects `embedding`, `mean` and `covariance` keywords to be passed explicitly.
Expand All @@ -125,11 +125,11 @@ def forward(self, **kwds):
torch.Tensor: anomaly map
"""

if not ("embedding" in kwds and "mean" in kwds and "inv_covariance" in kwds):
raise ValueError(f"Expected keys `embedding`, `mean` and `covariance`. Found {kwds.keys()}")
if not ("embedding" in kwargs and "mean" in kwargs and "inv_covariance" in kwargs):
raise ValueError(f"Expected keys `embedding`, `mean` and `covariance`. Found {kwargs.keys()}")

embedding: Tensor = kwds["embedding"]
mean: Tensor = kwds["mean"]
inv_covariance: Tensor = kwds["inv_covariance"]
embedding: Tensor = kwargs["embedding"]
mean: Tensor = kwargs["mean"]
inv_covariance: Tensor = kwargs["inv_covariance"]

return self.compute_anomaly_map(embedding, mean, inv_covariance)
7 changes: 4 additions & 3 deletions anomalib/models/reverse_distillation/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap.

Args:
Expand All @@ -32,6 +32,7 @@ class AnomalyMapGenerator:
"""

def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4, mode: str = "multiply"):
super().__init__()
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.sigma = sigma
self.kernel_size = 2 * int(4.0 * sigma + 0.5) + 1
Expand All @@ -40,7 +41,7 @@ def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4, mode: s
raise ValueError(f"Found mode {mode}. Only multiply and add are supported.")
self.mode = mode

def __call__(self, student_features: List[Tensor], teacher_features: List[Tensor]) -> Tensor:
def forward(self, student_features: List[Tensor], teacher_features: List[Tensor]) -> Tensor:
"""Computes anomaly map given encoder and decoder features.

Args:
Expand Down
15 changes: 8 additions & 7 deletions anomalib/models/stfpm/anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor
from torch import Tensor, nn


class AnomalyMapGenerator:
class AnomalyMapGenerator(nn.Module):
"""Generate Anomaly Heatmap."""

def __init__(
self,
image_size: Union[ListConfig, Tuple],
):
super().__init__()
self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)

Expand Down Expand Up @@ -59,7 +60,7 @@ def compute_anomaly_map(

return anomaly_map

def __call__(self, **kwds: Dict[str, Tensor]) -> torch.Tensor:
def forward(self, **kwargs: Dict[str, Tensor]) -> torch.Tensor:
"""Returns anomaly map.

Expects `teach_features` and `student_features` keywords to be passed explicitly.
Expand All @@ -78,10 +79,10 @@ def __call__(self, **kwds: Dict[str, Tensor]) -> torch.Tensor:
torch.Tensor: anomaly map
"""

if not ("teacher_features" in kwds and "student_features" in kwds):
raise ValueError(f"Expected keys `teacher_features` and `student_features. Found {kwds.keys()}")
if not ("teacher_features" in kwargs and "student_features" in kwargs):
raise ValueError(f"Expected keys `teacher_features` and `student_features. Found {kwargs.keys()}")

teacher_features: Dict[str, Tensor] = kwds["teacher_features"]
student_features: Dict[str, Tensor] = kwds["student_features"]
teacher_features: Dict[str, Tensor] = kwargs["teacher_features"]
student_features: Dict[str, Tensor] = kwargs["student_features"]

return self.compute_anomaly_map(teacher_features, student_features)
8 changes: 4 additions & 4 deletions tests/helpers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ def __init__(

def __call__(self, func):
@wraps(func)
def inner(*args, **kwds):
def inner(*args, **kwargs):
# If true, will use MVTech AD dataset for testing.
# Useful for nightly builds
if self.use_mvtec:
return func(*args, path=self.path, **kwds)
return func(*args, path=self.path, **kwargs)
else:
with GeneratedDummyDataset(
num_train=self.num_train,
Expand All @@ -145,8 +145,8 @@ def inner(*args, **kwds):
max_size=self.max_size,
seed=self.seed,
) as dataset_path:
kwds["category"] = "shapes"
return func(*args, path=dataset_path, **kwds)
kwargs["category"] = "shapes"
return func(*args, path=dataset_path, **kwargs)

return inner

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,9 @@ def test_dataloader(self) -> DataLoader:
return DataLoader(DummyDataset())


class DummyAnomalyMapGenerator:
class DummyAnomalyMapGenerator(nn.Module):
def __init__(self):
super().__init__()
self.input_size = (100, 100)
self.sigma = 4

Expand Down