Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor model implementations #225

Merged
merged 9 commits into from
Apr 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

# TODO(AlexanderDokuchaev): Workaround of wrapping by NNCF.
# Can't not wrap `spatial_softmax2d` if use import_module.
from anomalib.models.padim.model import PadimLightning # noqa: F401
from anomalib.models.padim.lightning_model import PadimLightning # noqa: F401
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this import needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was added by Alexander. I think it's needed for NNCF support

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah looks like it. But I don't understand how/why it's needed. It is not used anywhere



def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Expand Down Expand Up @@ -62,7 +62,7 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
raise ValueError(f"Unknown model {config.model.name} for OpenVINO model!")
else:
if config.model.name in torch_model_list:
module = import_module(f"anomalib.models.{config.model.name}.model")
module = import_module(f"anomalib.models.{config.model.name}")
model = getattr(module, f"{config.model.name.capitalize()}Lightning")
else:
raise ValueError(f"Unknown model {config.model.name}!")
Expand Down
9 changes: 5 additions & 4 deletions anomalib/models/cflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.

[CW-AD](https://arxiv.org/pdf/2107.12571v1.pdf)
"""
"""Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows."""

# Copyright (C) 2020 Intel Corporation
#
Expand All @@ -16,3 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from .lightning_model import CflowLightning

__all__ = ["CflowLightning"]
97 changes: 97 additions & 0 deletions anomalib/models/cflow/anomaly_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Anomaly Map Generator for CFlow model implementation."""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

from typing import List, Tuple, Union, cast

import torch
import torch.nn.functional as F
from omegaconf import ListConfig
from torch import Tensor


class AnomalyMapGenerator:
"""Generate Anomaly Heatmap."""

def __init__(
self,
image_size: Union[ListConfig, Tuple],
pool_layers: List[str],
):
self.distance = torch.nn.PairwiseDistance(p=2, keepdim=True)
self.image_size = image_size if isinstance(image_size, tuple) else tuple(image_size)
self.pool_layers: List[str] = pool_layers

def compute_anomaly_map(
self, distribution: Union[List[Tensor], List[List]], height: List[int], width: List[int]
) -> Tensor:
"""Compute the layer map based on likelihood estimation.

Args:
distribution: Probability distribution for each decoder block
height: blocks height
width: blocks width

Returns:
Final Anomaly Map

"""

test_map: List[Tensor] = []
for layer_idx in range(len(self.pool_layers)):
test_norm = torch.tensor(distribution[layer_idx], dtype=torch.double) # pylint: disable=not-callable
test_norm -= torch.max(test_norm) # normalize likelihoods to (-Inf:0] by subtracting a constant
test_prob = torch.exp(test_norm) # convert to probs in range [0:1]
test_mask = test_prob.reshape(-1, height[layer_idx], width[layer_idx])
# upsample
test_map.append(
F.interpolate(
test_mask.unsqueeze(1), size=self.image_size, mode="bilinear", align_corners=True
).squeeze()
)
# score aggregation
score_map = torch.zeros_like(test_map[0])
for layer_idx in range(len(self.pool_layers)):
score_map += test_map[layer_idx]
score_mask = score_map
# invert probs to anomaly scores
anomaly_map = score_mask.max() - score_mask

return anomaly_map

def __call__(self, **kwargs: Union[List[Tensor], List[int], List[List]]) -> Tensor:
"""Returns anomaly_map.

Expects `distribution`, `height` and 'width' keywords to be passed explicitly

Example
>>> anomaly_map_generator = AnomalyMapGenerator(image_size=tuple(hparams.model.input_size),
>>> pool_layers=pool_layers)
>>> output = self.anomaly_map_generator(distribution=dist, height=height, width=width)

Raises:
ValueError: `distribution`, `height` and 'width' keys are not found

Returns:
torch.Tensor: anomaly map
"""
if not ("distribution" in kwargs and "height" in kwargs and "width" in kwargs):
raise KeyError(f"Expected keys `distribution`, `height` and `width`. Found {kwargs.keys()}")

# placate mypy
distribution: List[Tensor] = cast(List[Tensor], kwargs["distribution"])
height: List[int] = cast(List[int], kwargs["height"])
width: List[int] = cast(List[int], kwargs["width"])
return self.compute_anomaly_map(distribution, height, width)
156 changes: 156 additions & 0 deletions anomalib/models/cflow/lightning_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""CFLOW: Real-Time Unsupervised Anomaly Detection via Conditional Normalizing Flows.

https://arxiv.org/pdf/2107.12571v1.pdf
"""

# Copyright (C) 2020 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.

import einops
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping
from torch import optim

from anomalib.models.cflow.torch_model import CflowModel
from anomalib.models.cflow.utils import get_logp, positional_encoding_2d
from anomalib.models.components import AnomalyModule

__all__ = ["CflowLightning"]


class CflowLightning(AnomalyModule):
"""PL Lightning Module for the CFLOW algorithm."""

def __init__(self, hparams):
super().__init__(hparams)

self.model: CflowModel = CflowModel(hparams)
self.loss_val = 0
self.automatic_optimization = False

def configure_callbacks(self):
"""Configure model-specific callbacks."""
early_stopping = EarlyStopping(
monitor=self.hparams.model.early_stopping.metric,
patience=self.hparams.model.early_stopping.patience,
mode=self.hparams.model.early_stopping.mode,
)
return [early_stopping]

def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures optimizers for each decoder.

Returns:
Optimizer: Adam optimizer for each decoder
"""
decoders_parameters = []
for decoder_idx in range(len(self.model.pool_layers)):
decoders_parameters.extend(list(self.model.decoders[decoder_idx].parameters()))

optimizer = optim.Adam(
params=decoders_parameters,
lr=self.hparams.model.lr,
)
return optimizer

def training_step(self, batch, _): # pylint: disable=arguments-differ
"""Training Step of CFLOW.

For each batch, decoder layers are trained with a dynamic fiber batch size.
Training step is performed manually as multiple training steps are involved
per batch of input images

Args:
batch: Input batch
_: Index of the batch.

Returns:
Loss value for the batch

"""
opt = self.optimizers()
self.model.encoder.eval()

images = batch["image"]
activation = self.model.encoder(images)
avg_loss = torch.zeros([1], dtype=torch.float64).to(images.device)

height = []
width = []
for layer_idx, layer in enumerate(self.model.pool_layers):
encoder_activations = activation[layer].detach() # BxCxHxW

batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
image_size = im_height * im_width
embedding_length = batch_size * image_size # number of rows in the conditional vector

height.append(im_height)
width.append(im_width)
# repeats positional encoding for the entire batch 1 C H W to B C H W
pos_encoding = einops.repeat(
positional_encoding_2d(self.model.condition_vector, im_height, im_width).unsqueeze(0),
"b c h w-> (tile b) c h w",
tile=batch_size,
).to(images.device)
c_r = einops.rearrange(pos_encoding, "b c h w -> (b h w) c") # BHWxP
e_r = einops.rearrange(encoder_activations, "b c h w -> (b h w) c") # BHWxC
perm = torch.randperm(embedding_length) # BHW
decoder = self.model.decoders[layer_idx].to(images.device)

fiber_batches = embedding_length // self.model.fiber_batch_size # number of fiber batches
assert fiber_batches > 0, "Make sure we have enough fibers, otherwise decrease N or batch-size!"

for batch_num in range(fiber_batches): # per-fiber processing
opt.zero_grad()
if batch_num < (fiber_batches - 1):
idx = torch.arange(
batch_num * self.model.fiber_batch_size, (batch_num + 1) * self.model.fiber_batch_size
)
else: # When non-full batch is encountered batch_num * N will go out of bounds
idx = torch.arange(batch_num * self.model.fiber_batch_size, embedding_length)
# get random vectors
c_p = c_r[perm[idx]] # NxP
e_p = e_r[perm[idx]] # NxC
# decoder returns the transformed variable z and the log Jacobian determinant
p_u, log_jac_det = decoder(e_p, [c_p])
#
decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim
loss = -F.logsigmoid(log_prob)
self.manual_backward(loss.mean())
opt.step()
avg_loss += loss.sum()

return {"loss": avg_loss}

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Validation Step of CFLOW.

Similar to the training step, encoder features
are extracted from the CNN for each batch, and anomaly
map is computed.

Args:
batch: Input batch
_: Index of the batch.

Returns:
Dictionary containing images, anomaly maps, true labels and masks.
These are required in `validation_epoch_end` for feature concatenation.

"""
batch["anomaly_maps"] = self.model(batch["image"])

return batch
Loading