🔨 Increase inference + openvino support (#122)
* 🔨 Increase inference + openvino support

* Rename fit_normality_model to model

* Address DFKDE comments

* Address DFM pylint comments

Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
Co-authored-by: Samet Akcay <samet.akcay@intel.com>
3 people authored Mar 8, 2022
1 parent aa49b7d commit 0d23715
Showing 7 changed files with 251 additions and 117 deletions.
15 changes: 10 additions & 5 deletions anomalib/models/cflow/model.py
@@ -165,14 +165,17 @@ def forward(self, images):
"""

activation = self.encoder(images)
self.encoder.eval()
self.decoders.eval()
with torch.no_grad():
activation = self.encoder(images)

distribution = [[] for _ in self.pool_layers]
distribution = [torch.Tensor(0).to(images.device) for _ in self.pool_layers]

height: List[int] = []
width: List[int] = []
for layer_idx, layer in enumerate(self.pool_layers):
encoder_activations = activation[layer].detach() # BxCxHxW
encoder_activations = activation[layer] # BxCxHxW

batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
image_size = im_height * im_width
@@ -206,13 +209,15 @@ def forward(self, images):
                 c_p = c_r[idx]  # NxP
                 e_p = e_r[idx]  # NxC
                 # decoder returns the transformed variable z and the log Jacobian determinant
-                p_u, log_jac_det = decoder(e_p, [c_p])
+                with torch.no_grad():
+                    p_u, log_jac_det = decoder(e_p, [c_p])
                 #
                 decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
                 log_prob = decoder_log_prob / dim_feature_vector  # likelihood per dim
-                distribution[layer_idx] = distribution[layer_idx] + log_prob.detach().tolist()
+                distribution[layer_idx] = torch.cat((distribution[layer_idx], log_prob))
 
         output = self.anomaly_map_generator(distribution=distribution, height=height, width=width)
+        self.decoders.train()
 
         return output.to(images.device)

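A note on the two hunks above: running the encoder and the flow decoders under `torch.no_grad()` avoids building an autograd graph at inference time, and accumulating the per-layer log-likelihoods with `torch.cat` instead of Python lists keeps the scores as on-device tensors, which an exported (ONNX/OpenVINO) graph can carry end to end. A minimal sketch of the same pattern, using hypothetical stand-in modules rather than the actual CFLOW encoder and decoders:

    import torch
    from torch import nn

    # Stand-ins for illustration only; the real model uses a CNN encoder
    # and conditional normalizing-flow decoders.
    encoder = nn.Linear(8, 4)
    decoder = nn.Linear(4, 4)

    images = torch.randn(2, 8)
    encoder.eval()
    decoder.eval()
    with torch.no_grad():  # no gradient bookkeeping during inference
        activation = encoder(images)
        log_prob = decoder(activation).flatten()

    # Accumulate scores as a tensor on the input device rather than as a
    # Python list, so the whole forward pass stays traceable.
    distribution = torch.Tensor(0).to(images.device)
    distribution = torch.cat((distribution, log_prob))
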
88 changes: 67 additions & 21 deletions anomalib/models/dfkde/model.py
@@ -20,33 +20,86 @@
 import torchvision
 from omegaconf.dictconfig import DictConfig
 from omegaconf.listconfig import ListConfig
-from torch import Tensor
+from torch import Tensor, nn
 
 from anomalib.models.components import AnomalyModule, FeatureExtractor
 
 from .normality_model import NormalityModel
 
 
+class DfkdeModel(nn.Module):
+    """DFKDE model.
+
+    Args:
+        backbone (str): Pre-trained model backbone.
+        filter_count (int): Number of filters.
+        threshold_steepness (float): Threshold steepness for normality model.
+        threshold_offset (float): Threshold offset for normality model.
+    """
+
+    def __init__(self, backbone: str, filter_count: int, threshold_steepness: float, threshold_offset: float) -> None:
+        super().__init__()
+        self.backbone = getattr(torchvision.models, backbone)
+        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()
+
+        self.normality_model = NormalityModel(
+            filter_count=filter_count,
+            threshold_steepness=threshold_steepness,
+            threshold_offset=threshold_offset,
+        )
+
+    def get_features(self, batch: Tensor) -> Tensor:
+        """Extract features from the pretrained network.
+
+        Args:
+            batch (Tensor): Image batch.
+
+        Returns:
+            Tensor: Tensor containing extracted features.
+        """
+        self.feature_extractor.eval()
+        layer_outputs = self.feature_extractor(batch)
+        layer_outputs = torch.cat(list(layer_outputs.values())).detach()
+        return layer_outputs
+
+    def fit(self, embeddings: List[Tensor]):
+        """Fit normality model.
+
+        Args:
+            embeddings (List[Tensor]): Embeddings to fit.
+        """
+        _embeddings = torch.vstack(embeddings)
+        self.normality_model.fit(_embeddings)
+
+    def forward(self, batch: Tensor) -> Tensor:
+        """Prediction by normality model.
+
+        Args:
+            batch (Tensor): Input images.
+
+        Returns:
+            Tensor: Predictions
+        """
+        feature_vector = self.get_features(batch)
+        return self.normality_model.predict(feature_vector.view(feature_vector.shape[:2]))
+
+
 class DfkdeLightning(AnomalyModule):
-    """DFKDE: Deep Featured Kernel Density Estimation.
+    """DFKDE: Deep Feature Kernel Density Estimation.
 
     Args:
         hparams (Union[DictConfig, ListConfig]): Model params
     """
 
     def __init__(self, hparams: Union[DictConfig, ListConfig]):
         super().__init__(hparams)
-        self.threshold_steepness = 0.05
-        self.threshold_offset = 12
-
-        self.backbone = getattr(torchvision.models, hparams.model.backbone)
-        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()
+        threshold_steepness = 0.05
+        threshold_offset = 12
 
-        self.normality_model = NormalityModel(
-            filter_count=hparams.model.max_training_points,
-            threshold_steepness=self.threshold_steepness,
-            threshold_offset=self.threshold_offset,
+        self.model: DfkdeModel = DfkdeModel(
+            hparams.model.backbone, hparams.model.max_training_points, threshold_steepness, threshold_offset
         )
 
         self.automatic_optimization = False
         self.embeddings: List[Tensor] = []

@@ -66,9 +119,7 @@ def training_step(self, batch, _batch_idx):  # pylint: disable=arguments-differ
             Deep CNN features.
         """
 
-        self.feature_extractor.eval()
-        layer_outputs = self.feature_extractor(batch["image"])
-        embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze()
+        embedding = self.model.get_features(batch["image"]).squeeze()
 
         # NOTE: `self.embedding` appends each batch embedding to
         # store the training set embedding. We manually append these
@@ -81,8 +132,7 @@ def on_validation_start(self) -> None:
         # NOTE: Previous anomalib versions fit Gaussian at the end of the epoch.
         # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
        # is run within train epoch.
-        embeddings = torch.vstack(self.embeddings)
-        self.normality_model.fit(embeddings)
+        self.model.fit(self.embeddings)
 
     def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         """Validation Step of DFKDE.
@@ -95,10 +145,6 @@ def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         Returns:
             Dictionary containing probability, prediction and ground truth values.
         """
-
-        self.feature_extractor.eval()
-        layer_outputs = self.feature_extractor(batch["image"])
-        feature_vector = torch.hstack(list(layer_outputs.values())).detach()
-        batch["pred_scores"] = self.normality_model.predict(feature_vector.view(feature_vector.shape[:2]))
+        batch["pred_scores"] = self.model(batch["image"])
 
         return batch
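
The shape of this refactor: feature extraction and the KDE normality model now live together in a plain `nn.Module` (`DfkdeModel`), so `forward()` maps images straight to anomaly scores and can be exported on its own, while the Lightning class only collects embeddings and delegates. A rough usage sketch; the hyperparameter values are illustrative rather than taken from the repository config, and the pretrained backbone weights are downloaded on first use:

    import torch

    from anomalib.models.dfkde.model import DfkdeModel  # module path as of this commit

    model = DfkdeModel(
        backbone="resnet18",  # illustrative choice
        filter_count=40000,   # plays the role of hparams.model.max_training_points
        threshold_steepness=0.05,
        threshold_offset=12,
    )

    # Training: collect squeezed per-batch embeddings, then fit the KDE once.
    embeddings = [model.get_features(torch.randn(8, 3, 256, 256)).squeeze() for _ in range(2)]
    model.fit(embeddings)

    # Inference: forward() extracts features and returns anomaly scores.
    scores = model(torch.randn(4, 3, 256, 256))
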
34 changes: 27 additions & 7 deletions anomalib/models/dfm/dfm_model.py
@@ -17,9 +17,10 @@
 import math
 
 import torch
+import torchvision
 from torch import Tensor, nn
 
-from anomalib.models.components import PCA, DynamicBufferModule
+from anomalib.models.components import PCA, DynamicBufferModule, FeatureExtractor
 
 
 class SingleClassGaussian(DynamicBufferModule):
@@ -83,16 +84,19 @@ class DFMModel(nn.Module):
"""Model for the DFM algorithm.
Args:
backbone (str): Pre-trained model backbone.
n_comps (float, optional): Ratio from which number of components for PCA are calculated. Defaults to 0.97.
score_type (str, optional): Scoring type. Options are `fre` and `nll`. Defaults to "fre".
"""

def __init__(self, n_comps: float = 0.97, score_type: str = "fre"):
def __init__(self, backbone: str, n_comps: float = 0.97, score_type: str = "fre"):
super().__init__()
self.backbone = getattr(torchvision.models, backbone)
self.n_components = n_comps
self.pca_model = PCA(n_components=self.n_components)
self.gaussian_model = SingleClassGaussian()
self.score_type = score_type
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()

def fit(self, dataset: Tensor) -> None:
"""Fit a pca transformation and a Gaussian model to dataset.
@@ -128,12 +132,28 @@ def score(self, features: Tensor) -> Tensor:

         return score
 
-    def forward(self, dataset: Tensor) -> None:
-        """Provides the same functionality as `fit`.
+    def get_features(self, batch: Tensor) -> Tensor:
+        """Extract features from the pretrained network.
 
-        Transforms the input dataset based on singular values calculated earlier.
+        Args:
+            batch (Tensor): Image batch.
 
-        Args:
-            dataset (Tensor): Input dataset
+        Returns:
+            Tensor: Tensor containing extracted features.
         """
-        self.fit(dataset)
+        self.feature_extractor.eval()
+        layer_outputs = self.feature_extractor(batch)
+        layer_outputs = torch.cat(list(layer_outputs.values())).detach()
+        return layer_outputs
+
+    def forward(self, batch: Tensor) -> Tensor:
+        """Compute score from input images.
+
+        Args:
+            batch (Tensor): Input images
+
+        Returns:
+            Tensor: Scores
+        """
+        feature_vector = self.get_features(batch)
+        return self.score(feature_vector.view(feature_vector.shape[:2]))
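
With the feature extractor folded into `DFMModel`, `forward()` now covers the whole inference path: extract backbone features, project them with the fitted PCA, and score with FRE or the Gaussian NLL. A hedged usage sketch (values illustrative; `n_comps` and `score_type` normally come from `hparams.model.pca_level` and `hparams.model.score_type`):

    import torch

    from anomalib.models.dfm.dfm_model import DFMModel  # module path as of this commit

    model = DFMModel(backbone="resnet18", n_comps=0.97, score_type="fre")

    # Fit the PCA subspace and Gaussian on embeddings of normal training images.
    train_images = torch.randn(32, 3, 256, 256)
    features = model.get_features(train_images).squeeze()  # (N, C)
    model.fit(features)

    # Inference: images -> features -> PCA projection -> anomaly scores.
    scores = model(torch.randn(4, 3, 256, 256))
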
23 changes: 7 additions & 16 deletions anomalib/models/dfm/model.py
@@ -17,11 +17,10 @@
 from typing import List, Union
 
 import torch
-import torchvision
 from omegaconf import DictConfig, ListConfig
 from torch import Tensor
 
-from anomalib.models.components import AnomalyModule, FeatureExtractor
+from anomalib.models.components import AnomalyModule
 
 from .dfm_model import DFMModel

@@ -32,10 +31,9 @@ class DfmLightning(AnomalyModule):
     def __init__(self, hparams: Union[DictConfig, ListConfig]):
         super().__init__(hparams)
 
-        self.backbone = getattr(torchvision.models, hparams.model.backbone)
-        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()
-
-        self.dfm_model = DFMModel(n_comps=hparams.model.pca_level, score_type=hparams.model.score_type)
+        self.model: DFMModel = DFMModel(
+            backbone=hparams.model.backbone, n_comps=hparams.model.pca_level, score_type=hparams.model.score_type
+        )
         self.automatic_optimization = False
         self.embeddings: List[Tensor] = []

@@ -56,10 +54,7 @@ def training_step(self, batch, _):  # pylint: disable=arguments-differ
         Returns:
             Deep CNN features.
         """
-
-        self.feature_extractor.eval()
-        layer_outputs = self.feature_extractor(batch["image"])
-        embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze()
+        embedding = self.model.get_features(batch["image"]).squeeze()
 
         # NOTE: `self.embedding` appends each batch embedding to
         # store the training set embedding. We manually append these
@@ -73,7 +68,7 @@ def on_validation_start(self) -> None:
         # This is not possible anymore with PyTorch Lightning v1.4.0 since validation
         # is run within train epoch.
         embeddings = torch.vstack(self.embeddings)
-        self.dfm_model.fit(embeddings)
+        self.model.fit(embeddings)
 
     def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         """Validation Step of DFM.
@@ -86,10 +81,6 @@ def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         Returns:
             Dictionary containing FRE anomaly scores and ground-truth.
         """
-
-        self.feature_extractor.eval()
-        layer_outputs = self.feature_extractor(batch["image"])
-        feature_vector = torch.hstack(list(layer_outputs.values())).detach()
-        batch["pred_scores"] = self.dfm_model.score(feature_vector.view(feature_vector.shape[:2]))
+        batch["pred_scores"] = self.model(batch["image"])
 
         return batch
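
Both refactors rely on the same training pattern: with `automatic_optimization = False` there is no gradient step, each `training_step` just stashes the batch embedding, and the subspace/KDE model is fitted once in `on_validation_start`, because PyTorch Lightning >= 1.4 runs validation inside the train epoch. A minimal sketch of that hook ordering, with a hypothetical stand-in model rather than an anomalib class:

    from typing import List

    import torch


    class StubModel:
        """Hypothetical stand-in for the PCA/KDE normality model."""

        def fit(self, embeddings: torch.Tensor) -> None:
            self.mean = embeddings.mean(dim=0)  # placeholder for real fitting


    class AccumulateThenFit:
        """Mimics the Lightning hook order used by DFKDE and DFM."""

        def __init__(self) -> None:
            self.model = StubModel()
            self.embeddings: List[torch.Tensor] = []

        def training_step(self, embedding: torch.Tensor) -> None:
            # No loss and no optimizer step: just collect training embeddings.
            self.embeddings.append(embedding)

        def on_validation_start(self) -> None:
            # Fit once on everything collected before validation begins.
            self.model.fit(torch.vstack(self.embeddings))


    module = AccumulateThenFit()
    for _ in range(3):
        module.training_step(torch.randn(4, 8))
    module.on_validation_start()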
(Diffs for the remaining 3 of the 7 changed files are not shown.)
