✨ Add Reverse Distillation #343

Merged · 23 commits · Jun 9, 2022
Changes from 1 commit
28 changes: 6 additions & 22 deletions anomalib/models/__init__.py
@@ -15,9 +15,7 @@
# and limitations under the License.

import os
import re
from importlib import import_module
from re import Match
from typing import List, Union

from omegaconf import DictConfig, ListConfig
@@ -26,30 +24,16 @@
from anomalib.models.components import AnomalyModule


def _snake_to_camel_case(model_name: str) -> str:
"""Convert model name from snake case to camel case.
def _snake_to_pascal_case(model_name: str) -> str:
"""Convert model name from snake case to Pascal case.

Args:
model_name (str): Model name in snake case.

Returns:
str: Model name in camel case.
str: Model name in Pascal case.
"""

def _capitalize(match_object: Match) -> str:
"""Capitalizes regex matches to camel case.

Args:
match_object (Match): Input from regex substitute.

Returns:
str: Camel case string.
"""
ret = match_object.group(1).capitalize()
ret += match_object.group(3).capitalize() if match_object.group(3) is not None else ""
return ret

return re.sub(r"([a-z]+)(_([a-z]+))?", _capitalize, model_name)
return "".join([split.capitalize() for split in model_name.split("_")])


def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
@@ -58,7 +42,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Works only when the convention for model naming is followed.

The convention for writing model classes is
`anomalib.models.<model_name>.model.<ModelName>Lightning`
`anomalib.models.<model_name>.lightning_model.<ModelName>Lightning`
`anomalib.models.stfpm.lightning_model.StfpmLightning`

Args:
@@ -86,7 +70,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:

if config.model.name in model_list:
module = import_module(f"anomalib.models.{config.model.name}")
model = getattr(module, f"{_snake_to_camel_case(config.model.name)}Lightning")(config)
model = getattr(module, f"{_snake_to_pascal_case(config.model.name)}Lightning")(config)

else:
raise ValueError(f"Unknown model {config.model.name}!")
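
Note: the simplified helper and the `anomalib.models.<model_name>.lightning_model.<ModelName>Lightning` convention can be exercised in isolation. A minimal, standalone sketch (not the actual anomalib module; `load_lightning_class` is a hypothetical helper):

```python
from importlib import import_module


def snake_to_pascal_case(model_name: str) -> str:
    """Convert e.g. ``reverse_distillation`` to ``ReverseDistillation``."""
    return "".join(split.capitalize() for split in model_name.split("_"))


def load_lightning_class(model_name: str):
    """Resolve the Lightning class following the naming convention above.

    Assumes anomalib is installed and that each model package re-exports
    ``<ModelName>Lightning`` from its ``lightning_model`` submodule.
    """
    module = import_module(f"anomalib.models.{model_name}")
    return getattr(module, f"{snake_to_pascal_case(model_name)}Lightning")


print(snake_to_pascal_case("reverse_distillation"))  # ReverseDistillation
print(snake_to_pascal_case("stfpm"))                 # Stfpm
```
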
42 changes: 25 additions & 17 deletions anomalib/models/reverse_distillation/LICENSE
@@ -1,21 +1,29 @@
MIT License
Copyright (c) 2022 Intel Corporation
SPDX-License-Identifier: Apache-2.0

Copyright (c) 2022 hq-deng
Some files in this folder are based on the original Reverse Distillation implementation by hq-deng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
Original license
----------------

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
MIT License

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Copyright (c) 2022 hq-deng

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
13 changes: 1 addition & 12 deletions anomalib/models/reverse_distillation/__init__.py
@@ -1,18 +1,7 @@
"""Reverse Distillation Model."""

# Copyright (C) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
# SPDX-License-Identifier: Apache-2.0

from .lightning_model import ReverseDistillationLightning

33 changes: 16 additions & 17 deletions anomalib/models/reverse_distillation/anomaly_map.py
@@ -24,50 +24,49 @@ class AnomalyMapGenerator:
Args:
image_size (Union[ListConfig, Tuple]): Size of original image used for upscaling the anomaly map.
sigma (int): Standard deviation of the gaussian kernel used to smooth anomaly map.
mode (str, optional): Operation used to generate anomaly map. Options are `add` and `multiply`.
Defaults to "multiply".

Raises:
ValueError: In case modes other than multiply and add are passed.
"""

def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4):
def __init__(self, image_size: Union[ListConfig, Tuple], sigma: int = 4, mode: str = "multiply"):
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.sigma = sigma
self.kernel_size = 2 * int(4.0 * sigma + 0.5) + 1

def __call__(
self, student_features: List[Tensor], teacher_features: List[Tensor], mode: str = "multiply"
) -> Tensor:
if mode not in ("add", "multiply"):
raise ValueError(f"Found mode {mode}. Only multiply and add are supported.")
self.mode = mode

def __call__(self, student_features: List[Tensor], teacher_features: List[Tensor]) -> Tensor:
"""Computes anomaly map given encoder and decoder features.

Args:
student_features (List[Tensor]): List of encoder features
teacher_features (List[Tensor]): List of decoder features
mode (str, optional): Operation used to generate anomaly map. Options are `add` and `multiply`.
Defaults to "multiply".

Raises:
ValueError: In case modes other than multiply and add are passed.

Returns:
Tensor: Anomaly maps of length batch.
"""
if mode == "multiply":
if self.mode == "multiply":
anomaly_map = torch.ones(
[student_features[0].shape[0], 1, *self.image_size], device=student_features[0].device
) # b c h w
elif mode == "add":
elif self.mode == "add":
anomaly_map = torch.zeros(
[student_features[0].shape[0], 1, *self.image_size], device=student_features[0].device
)
else:
raise ValueError(f"Found mode {mode}. Only multiply and add are supported.")

for student_feature, teacher_feature in zip(student_features, teacher_features):
distance_map = 1 - F.cosine_similarity(student_feature, teacher_feature)
distance_map = torch.unsqueeze(distance_map, dim=1)
distance_map = F.interpolate(distance_map, size=self.image_size, mode="bilinear", align_corners=True)
if mode == "multiply":
if self.mode == "multiply":
anomaly_map *= distance_map
elif mode == "add":
elif self.mode == "add":
anomaly_map += distance_map
else:
raise ValueError(f"Operation {mode} not supported. Only ``add`` and ``multiply`` are supported")

anomaly_map = gaussian_blur2d(
anomaly_map, kernel_size=(self.kernel_size, self.kernel_size), sigma=(self.sigma, self.sigma)
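
Note: the combination mode is now fixed at construction time instead of being passed on every call. A rough standalone sketch of the multiply/add fusion of cosine-distance maps (simplified; the real class also applies a Gaussian blur via kornia's `gaussian_blur2d` afterwards):

```python
from typing import List, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor


def combine_distance_maps(
    student_features: List[Tensor],
    teacher_features: List[Tensor],
    image_size: Tuple[int, int],
    mode: str = "multiply",
) -> Tensor:
    """Fuse per-layer cosine-distance maps into a single anomaly map (simplified)."""
    if mode not in ("add", "multiply"):
        raise ValueError(f"Found mode {mode}. Only multiply and add are supported.")

    batch_size = student_features[0].shape[0]
    if mode == "multiply":
        anomaly_map = torch.ones(batch_size, 1, *image_size)  # neutral element for *
    else:
        anomaly_map = torch.zeros(batch_size, 1, *image_size)  # neutral element for +

    for student_feature, teacher_feature in zip(student_features, teacher_features):
        distance_map = 1 - F.cosine_similarity(student_feature, teacher_feature)  # (B, H, W)
        distance_map = F.interpolate(
            distance_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=True
        )
        anomaly_map = anomaly_map * distance_map if mode == "multiply" else anomaly_map + distance_map

    return anomaly_map


# Toy usage with two feature levels for a batch of 2 images:
student = [torch.randn(2, 8, 32, 32), torch.randn(2, 16, 16, 16)]
teacher = [torch.randn(2, 8, 32, 32), torch.randn(2, 16, 16, 16)]
print(combine_distance_maps(student, teacher, (64, 64), mode="add").shape)  # torch.Size([2, 1, 64, 64])
```
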
1 change: 1 addition & 0 deletions anomalib/models/reverse_distillation/config.yaml
@@ -36,6 +36,7 @@ model:
beta1: 0.5
beta2: 0.99
normalization_method: min_max # options: [null, min_max, cdf]
anomaly_map_mode: multiply

metrics:
image:
11 changes: 8 additions & 3 deletions anomalib/models/reverse_distillation/lightning_model.py
@@ -43,10 +43,12 @@ class ReverseDistillation(AnomalyModule):
layers (List[str]): Layers to extract features from the backbone CNN
"""

def __init__(self, input_size: Tuple[int, int], backbone: str, layers: List[str]):
def __init__(self, input_size: Tuple[int, int], backbone: str, layers: List[str], anomaly_map_mode: str):
super().__init__()
logger.info("Initializing Reverse Distillation Lightning model.")
self.model = ReverseDistillationModel(backbone=backbone, layers=layers, input_size=input_size)
self.model = ReverseDistillationModel(
backbone=backbone, layers=layers, input_size=input_size, anomaly_map_mode=anomaly_map_mode
)
self.loss = ReverseDistillationLoss()

def training_step(self, batch, _) -> Dict[str, Tensor]: # type: ignore
@@ -93,7 +95,10 @@ class ReverseDistillationLightning(ReverseDistillation):

def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(
input_size=hparams.model.input_size, backbone=hparams.model.backbone, layers=hparams.model.layers
input_size=hparams.model.input_size,
backbone=hparams.model.backbone,
layers=hparams.model.layers,
anomaly_map_mode=hparams.model.anomaly_map_mode,
)
self.hparams: Union[DictConfig, ListConfig] # type: ignore
self.save_hyperparameters(hparams)
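
Note: with the new `anomaly_map_mode` key, the option travels from config.yaml through `ReverseDistillationLightning` into `ReverseDistillationModel`, and from there into `AnomalyMapGenerator`. A small sketch of reading such a key with OmegaConf (the values below are illustrative placeholders, not necessarily the shipped defaults):

```python
from omegaconf import OmegaConf

# Hypothetical minimal config mirroring only the keys touched by this PR.
config = OmegaConf.create(
    {
        "model": {
            "name": "reverse_distillation",
            "input_size": [256, 256],                   # placeholder value
            "backbone": "wide_resnet50_2",              # placeholder value
            "layers": ["layer1", "layer2", "layer3"],   # placeholder value
            "anomaly_map_mode": "multiply",             # options: multiply, add
        }
    }
)

# The Lightning wrapper forwards this straight to the torch model:
#   ReverseDistillation(..., anomaly_map_mode=config.model.anomaly_map_mode)
print(config.model.anomaly_map_mode)  # multiply
```
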
10 changes: 6 additions & 4 deletions anomalib/models/reverse_distillation/torch_model.py
@@ -35,25 +35,25 @@ class ReverseDistillationModel(nn.Module):
backbone (str): Name of the backbone used for encoder and decoder
input_size (Tuple[int, int]): Size of input image
layers (List[str]): Name of layers from which the features are extracted.
anomaly_map_mode (str): Mode used to generate anomaly map. Options are between ``multiply`` and ``add``.
"""

def __init__(self, backbone: str, input_size: Tuple[int, int], layers: List[str]):
def __init__(self, backbone: str, input_size: Tuple[int, int], layers: List[str], anomaly_map_mode: str):
super().__init__()
self.tiler: Optional[Tiler] = None

encoder_backbone = getattr(torchvision.models, backbone)
# TODO replace with TIMM feature extractor
self.encoder = FeatureExtractor(backbone=encoder_backbone(pretrained=True), layers=layers)
self.bottleneck = get_bottleneck_layer(backbone)
self.encoder.eval()
self.decoder = get_decoder(backbone)

if self.tiler:
image_size = (self.tiler.tile_size_h, self.tiler.tile_size_w)
else:
image_size = input_size

self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(image_size))
self.anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(image_size), mode=anomaly_map_mode)

def forward(self, images: Tensor) -> Union[Tensor, Tuple[List[Tensor], List[Tensor]]]:
"""Forward-pass images to the network.
@@ -68,6 +68,8 @@ def forward(self, images: Tensor) -> Union[Tensor, Tuple[List[Tensor], List[Tensor]]]:
Union[Tensor, Tuple[List[Tensor],List[Tensor]]]: Encoder and decoder features in training mode,
else anomaly maps.
"""
self.encoder.eval()

Contributor: Are the benchmarking results still the same after this change?

Collaborator (author): I am running it right now. But that was a good catch. I saw that the encoder was in training mode in the train step.

Collaborator (author): Interestingly, apart from the pixel-level score, the image-level scores are lower.
[benchmark result screenshots]

Contributor: The average results also seem to be a bit lower than those reported in the original paper (98.5% image AUROC, 97.8% pixel AUROC). For now I would suggest merging this PR, but it would be good to investigate whether this difference can be explained.

Contributor: Does the official implementation give the exact same results reported in the paper?


if self.tiler:
images = self.tiler.tile(images)
encoder_features = self.encoder(images)
@@ -77,7 +79,7 @@ def forward(self, images: Tensor) -> Union[Tensor, Tuple[List[Tensor], List[Tensor]]]:
if self.training:
output = encoder_features, decoder_features
else:
output = self.anomaly_map_generator(encoder_features, decoder_features, mode="add")
output = self.anomaly_map_generator(encoder_features, decoder_features)
if self.tiler:
output = self.tiler.untile(output)

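
Note on the `self.encoder.eval()` call added to `forward`: Lightning puts the whole module back into train mode during fitting, so without this call the frozen encoder's BatchNorm layers would keep updating their running statistics from training batches. A tiny PyTorch-only illustration of the train/eval difference (unrelated to anomalib itself):

```python
import torch
from torch import nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 8, 8)

bn.train()
_ = bn(x)  # train mode: normalizes with batch statistics and updates running stats

bn.eval()
_ = bn(x)  # eval mode: normalizes with the (now frozen) running statistics

print(bn.training)            # False
print(bn.running_mean.shape)  # torch.Size([3])
```
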