Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSPCAB implementation #500

Merged
merged 11 commits into from
Aug 15, 2022
8 changes: 8 additions & 0 deletions anomalib/models/components/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Neural network layers."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .sspcab import SSPCAB

__all__ = ["SSPCAB"]
81 changes: 81 additions & 0 deletions anomalib/models/components/layers/sspcab.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""SSPCAB: Self-Supervised Predictive Convolutional Attention Block for reconstruction-based models.

Paper https://arxiv.org/abs/2111.09099
"""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class AttentionModule(nn.Module):
"""Squeeze and excitation block that acts as the attention module in SSPCAB.

Args:
channels (int): Number of input channels.
reduction_ratio (int): Reduction ratio of the attention module.
"""

def __init__(self, in_channels: int, reduction_ratio: int = 8):
super().__init__()

out_channels = in_channels // reduction_ratio
self.fc1 = nn.Linear(in_channels, out_channels)
self.fc2 = nn.Linear(out_channels, in_channels)

def forward(self, inputs: Tensor) -> Tensor:
"""Forward pass through the attention module."""
# reduce feature map to 1d vector through global average pooling
avg_pooled = inputs.mean(dim=(2, 3))
djdameln marked this conversation as resolved.
Show resolved Hide resolved

# squeeze and excite
act = self.fc1(avg_pooled)
act = F.relu(act)
act = self.fc2(act)
act = F.sigmoid(act)

# multiply with input
se_out = inputs * act.view(act.shape[0], act.shape[1], 1, 1)

return se_out


class SSPCAB(nn.Module):
"""SSPCAB block.

Args:
in_channels (int): Number of input channels.
kernel_size (int): Size of the receptive fields of the masked convolution kernel.
dilation (int): Dilation factor of the masked convolution kernel.
reduction_ratio (int): Reduction ratio of the attention module.
"""

def __init__(self, in_channels: int, kernel_size: int = 1, dilation: int = 1, reduction_ratio: int = 8):
super().__init__()

self.pad = kernel_size + dilation
self.crop = 2 * (kernel_size + dilation)

self.masked_conv1 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv2 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv3 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)
self.masked_conv4 = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size)

self.attention_module = AttentionModule(in_channels=in_channels, reduction_ratio=reduction_ratio)

def forward(self, inputs: Tensor) -> Tensor:
"""Forward pass through the SSPCAB block."""
# compute masked convolution
padded = F.pad(inputs, (self.pad,) * 4)
masked_out = torch.zeros_like(inputs)
masked_out += self.masked_conv1(padded[..., : -self.crop, : -self.crop])
masked_out += self.masked_conv2(padded[..., : -self.crop, self.crop :])
masked_out += self.masked_conv3(padded[..., self.crop :, : -self.crop])
masked_out += self.masked_conv4(padded[..., self.crop :, self.crop :])

# apply channel attention module
sspcab_out = self.attention_module(masked_out)
return sspcab_out
2 changes: 2 additions & 0 deletions anomalib/models/draem/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ model:
name: draem
anomaly_source_path: null # optional, e.g. ./datasets/dtd
lr: 0.0001
enable_sspcab: false
sspcab_lambda: 0.1
Comment on lines +27 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When sspcab_lambda is set to 0, it is not added to the overall loss. In this case, can we assume that sspcab_lambda=0 would be the same as enable_sspcab=false?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if yes, it would be possible to just have a single argument, sspcab_lambda. When it is 0, it would be disabled.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be possible, but I think it would be more clear for the user to have an explicit parameter to enable/disable SSPCAB. Especially because we would have to add some logic to enable/disable the sspcab block in the architecture of the torch model based on the value of the lambda parameter. As a user I would not expect the value of the lambda parameter to affect the architecture of the model.

early_stopping:
patience: 20
metric: pixel_AUROC
Expand Down
46 changes: 42 additions & 4 deletions anomalib/models/draem/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import Optional, Union
from typing import Callable, Dict, Optional, Union

import torch
from omegaconf import DictConfig, ListConfig
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import Tensor, nn

from anomalib.models.components import AnomalyModule
from anomalib.models.draem.loss import DraemLoss
Expand All @@ -30,12 +31,40 @@ class Draem(AnomalyModule):
be used if left empty.
"""

def __init__(self, anomaly_source_path: Optional[str] = None):
def __init__(
self, enable_sspcab: bool = False, sspcab_lambda: float = 0.1, anomaly_source_path: Optional[str] = None
):
super().__init__()

self.augmenter = Augmenter(anomaly_source_path)
self.model = DraemModel()
self.model = DraemModel(sspcab=enable_sspcab)
self.loss = DraemLoss()
self.sspcab = enable_sspcab

if self.sspcab:
self.sspcab_activations: Dict = {}
self.setup_sspcab()
self.sspcab_loss = nn.MSELoss()
self.sspcab_lambda = sspcab_lambda

def setup_sspcab(self):
"""Prepare the model for the SSPCAB training step by adding forward hooks for the SSPCAB layer activations."""

def get_activation(name: str) -> Callable:
"""Retrieves the activations.

Args:
name (str): Identifier for the retrieved activations.
"""

def hook(_, __, output: Tensor):
"""Hook for retrieving the activations."""
self.sspcab_activations[name] = output

return hook

self.model.reconstructive_subnetwork.encoder.mp4.register_forward_hook(get_activation("input"))
self.model.reconstructive_subnetwork.encoder.block5.register_forward_hook(get_activation("output"))
djdameln marked this conversation as resolved.
Show resolved Hide resolved

def training_step(self, batch, _): # pylint: disable=arguments-differ
"""Training Step of DRAEM.
Expand All @@ -56,6 +85,11 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ
reconstruction, prediction = self.model(augmented_image)
# Compute loss
loss = self.loss(input_image, reconstruction, anomaly_mask, prediction)

if self.sspcab:
loss += self.sspcab_lambda * self.sspcab_loss(
self.sspcab_activations["input"], self.sspcab_activations["output"]
)
return {"loss": loss}

def validation_step(self, batch, _):
Expand All @@ -80,7 +114,11 @@ class DraemLightning(Draem):
"""

def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(anomaly_source_path=hparams.model.anomaly_source_path)
super().__init__(
enable_sspcab=hparams.model.enable_sspcab,
sspcab_lambda=hparams.model.sspcab_lambda,
anomaly_source_path=hparams.model.anomaly_source_path,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)

Expand Down
31 changes: 18 additions & 13 deletions anomalib/models/draem/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
import torch
from torch import Tensor, nn

from anomalib.models.components.layers import SSPCAB


class DraemModel(nn.Module):
"""DRAEM PyTorch model consisting of the reconstructive and discriminative sub networks."""

def __init__(self):
def __init__(self, sspcab: bool = False):
super().__init__()
self.reconstructive_subnetwork = ReconstructiveSubNetwork()
self.reconstructive_subnetwork = ReconstructiveSubNetwork(sspcab=sspcab)
self.discriminative_subnetwork = DiscriminativeSubNetwork(in_channels=6, out_channels=2)

def forward(self, batch: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
Expand Down Expand Up @@ -50,9 +52,9 @@ class ReconstructiveSubNetwork(nn.Module):
base_width (int): Base dimensionality of the layers of the autoencoder.
"""

def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width=128):
def __init__(self, in_channels: int = 3, out_channels: int = 3, base_width=128, sspcab: bool = False):
super().__init__()
self.encoder = EncoderReconstructive(in_channels, base_width)
self.encoder = EncoderReconstructive(in_channels, base_width, sspcab=sspcab)
self.decoder = DecoderReconstructive(base_width, out_channels=out_channels)

def forward(self, batch: Tensor) -> Tensor:
Expand Down Expand Up @@ -321,7 +323,7 @@ class EncoderReconstructive(nn.Module):
base_width (int): Base dimensionality of the layers of the autoencoder.
"""

def __init__(self, in_channels: int, base_width: int):
def __init__(self, in_channels: int, base_width: int, sspcab: bool = False):
super().__init__()
self.block1 = nn.Sequential(
nn.Conv2d(in_channels, base_width, kernel_size=3, padding=1),
Expand Down Expand Up @@ -359,14 +361,17 @@ def __init__(self, in_channels: int, base_width: int):
nn.ReLU(inplace=True),
)
self.mp4 = nn.Sequential(nn.MaxPool2d(2))
self.block5 = nn.Sequential(
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
)
if sspcab:
self.block5 = SSPCAB(base_width * 8)
else:
self.block5 = nn.Sequential(
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
nn.Conv2d(base_width * 8, base_width * 8, kernel_size=3, padding=1),
nn.BatchNorm2d(base_width * 8),
nn.ReLU(inplace=True),
)

def forward(self, batch: Tensor) -> Tensor:
"""Encode a batch of input images to the salient space.
Expand Down