Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🔨 Increase inference + openvino support #122

Merged
merged 4 commits into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions anomalib/models/cflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,17 @@ def forward(self, images):

"""

activation = self.encoder(images)
self.encoder.eval()
self.decoders.eval()
with torch.no_grad():
activation = self.encoder(images)

distribution = [[] for _ in self.pool_layers]
distribution = [torch.Tensor(0).to(images.device) for _ in self.pool_layers]

height: List[int] = []
width: List[int] = []
for layer_idx, layer in enumerate(self.pool_layers):
encoder_activations = activation[layer].detach() # BxCxHxW
encoder_activations = activation[layer] # BxCxHxW

batch_size, dim_feature_vector, im_height, im_width = encoder_activations.size()
image_size = im_height * im_width
Expand Down Expand Up @@ -206,13 +209,15 @@ def forward(self, images):
c_p = c_r[idx] # NxP
e_p = e_r[idx] # NxC
# decoder returns the transformed variable z and the log Jacobian determinant
p_u, log_jac_det = decoder(e_p, [c_p])
with torch.no_grad():
p_u, log_jac_det = decoder(e_p, [c_p])
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved
#
decoder_log_prob = get_logp(dim_feature_vector, p_u, log_jac_det)
log_prob = decoder_log_prob / dim_feature_vector # likelihood per dim
distribution[layer_idx] = distribution[layer_idx] + log_prob.detach().tolist()
distribution[layer_idx] = torch.cat((distribution[layer_idx], log_prob))

output = self.anomaly_map_generator(distribution=distribution, height=height, width=width)
self.decoders.train()

return output.to(images.device)

Expand Down
88 changes: 68 additions & 20 deletions anomalib/models/dfkde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,71 @@
import torchvision
from omegaconf.dictconfig import DictConfig
from omegaconf.listconfig import ListConfig
from torch import Tensor
from torch import Tensor, nn

from anomalib.models.components import AnomalyModule, FeatureExtractor

from .normality_model import NormalityModel


class DfkdeModel(nn.Module):
"""DFKDR model.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: DFKDE


Args:
backbone (nn.Module): Feature extraction backbone.
filter_count (int): Number of filters.
threshold_steepness (float): Threshold steepness for normality model.
threshold_offset (float): Threshold offset for normality model.
"""

def __init__(
self, backbone: nn.Module, filter_count: int, threshold_steepness: float, threshold_offset: float
) -> None:
super().__init__()
self.feature_extractor = FeatureExtractor(backbone=backbone(pretrained=True), layers=["avgpool"]).eval()

self.normality_model = NormalityModel(
filter_count=filter_count,
threshold_steepness=threshold_steepness,
threshold_offset=threshold_offset,
)
samet-akcay marked this conversation as resolved.
Show resolved Hide resolved

def get_features(self, batch: Tensor) -> Tensor:
"""Extract features from the pretrained network.

Args:
batch (Tensor): Image batch.

Returns:
Tensor: Tensor containing extracted features.
"""
self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch)
layer_outputs = torch.cat(list(layer_outputs.values())).detach()
return layer_outputs

def fit(self, embeddings: List[Tensor]):
"""Fit normality model.

Args:
embeddings (List[Tensor]): Embeddings to fit.
"""
_embeddings = torch.vstack(embeddings)
self.normality_model.fit(_embeddings)

def forward(self, batch: Tensor) -> Tensor:
"""Prediction by normality model.

Args:
batch (Tensor): Input images.

Returns:
Tensor: Predictions
"""
feature_vector = self.get_features(batch)
return self.normality_model.predict(feature_vector.view(feature_vector.shape[:2]))


class DfkdeLightning(AnomalyModule):
"""DFKDE: Deep Featured Kernel Density Estimation.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: Deep Feature


Expand All @@ -36,17 +94,14 @@ class DfkdeLightning(AnomalyModule):

def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(hparams)
self.threshold_steepness = 0.05
self.threshold_offset = 12
threshold_steepness = 0.05
threshold_offset = 12

self.backbone = getattr(torchvision.models, hparams.model.backbone)
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()

self.normality_model = NormalityModel(
filter_count=hparams.model.max_training_points,
threshold_steepness=self.threshold_steepness,
threshold_offset=self.threshold_offset,
backbone = getattr(torchvision.models, hparams.model.backbone)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you could create the backbone inside DfkdeModel and just pass the hparams.model.backbone to the constructor of that class. That would be more in line with other models (see Padim for example).

self.model: DfkdeModel = DfkdeModel(
backbone, hparams.model.max_training_points, threshold_steepness, threshold_offset
)

self.automatic_optimization = False
self.embeddings: List[Tensor] = []

Expand All @@ -66,9 +121,7 @@ def training_step(self, batch, _batch_idx): # pylint: disable=arguments-differ
Deep CNN features.
"""

self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch["image"])
embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze()
embedding = self.model.get_features(batch["image"]).squeeze()

# NOTE: `self.embedding` appends each batch embedding to
# store the training set embedding. We manually append these
Expand All @@ -81,8 +134,7 @@ def on_validation_start(self) -> None:
# NOTE: Previous anomalib versions fit Gaussian at the end of the epoch.
# This is not possible anymore with PyTorch Lightning v1.4.0 since validation
# is run within train epoch.
embeddings = torch.vstack(self.embeddings)
self.normality_model.fit(embeddings)
self.model.fit(self.embeddings)

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Validation Step of DFKDE.
Expand All @@ -95,10 +147,6 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Returns:
Dictionary containing probability, prediction and ground truth values.
"""

self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch["image"])
feature_vector = torch.hstack(list(layer_outputs.values())).detach()
batch["pred_scores"] = self.normality_model.predict(feature_vector.view(feature_vector.shape[:2]))
batch["pred_scores"] = self.model(batch["image"])

return batch
31 changes: 24 additions & 7 deletions anomalib/models/dfm/dfm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import torch
from torch import Tensor, nn

from anomalib.models.components import PCA, DynamicBufferModule
from anomalib.models.components import PCA, DynamicBufferModule, FeatureExtractor


class SingleClassGaussian(DynamicBufferModule):
Expand Down Expand Up @@ -87,12 +87,13 @@ class DFMModel(nn.Module):
score_type (str, optional): Scoring type. Options are `fre` and `nll`. Defaults to "fre".
"""

def __init__(self, n_comps: float = 0.97, score_type: str = "fre"):
def __init__(self, backbone: nn.Module, n_comps: float = 0.97, score_type: str = "fre"):
super().__init__()
self.n_components = n_comps
self.pca_model = PCA(n_components=self.n_components)
self.gaussian_model = SingleClassGaussian()
self.score_type = score_type
self.feature_extractor = FeatureExtractor(backbone=backbone(pretrained=True), layers=["avgpool"]).eval()

def fit(self, dataset: Tensor) -> None:
"""Fit a pca transformation and a Gaussian model to dataset.
Expand Down Expand Up @@ -128,12 +129,28 @@ def score(self, features: Tensor) -> Tensor:

return score

def forward(self, dataset: Tensor) -> None:
"""Provides the same functionality as `fit`.
def get_features(self, batch: Tensor) -> Tensor:
"""Extract features from the pretrained network.

Transforms the input dataset based on singular values calculated earlier.
Args:
batch (Tensor): Image batch.

Returns:
Tensor: Tensor containing extracted features.
"""
self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch)
layer_outputs = torch.cat(list(layer_outputs.values())).detach()
return layer_outputs

def forward(self, batch: Tensor) -> Tensor:
"""Computer score from input images.

Args:
dataset (Tensor): Input dataset
batch (Tensor): Input images

Returns:
Tensor: Scores
"""
self.fit(dataset)
feature_vector = self.get_features(batch)
return self.score(feature_vector.view(feature_vector.shape[:2]))
23 changes: 8 additions & 15 deletions anomalib/models/dfm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from omegaconf import DictConfig, ListConfig
from torch import Tensor

from anomalib.models.components import AnomalyModule, FeatureExtractor
from anomalib.models.components import AnomalyModule

from .dfm_model import DFMModel

Expand All @@ -32,10 +32,10 @@ class DfmLightning(AnomalyModule):
def __init__(self, hparams: Union[DictConfig, ListConfig]):
super().__init__(hparams)

self.backbone = getattr(torchvision.models, hparams.model.backbone)
self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()

self.dfm_model = DFMModel(n_comps=hparams.model.pca_level, score_type=hparams.model.score_type)
backbone = getattr(torchvision.models, hparams.model.backbone)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, let's pass the string from the hparams and then instantiate the backbone inside the pytorch model.

self.model: DFMModel = DFMModel(
backbone=backbone, n_comps=hparams.model.pca_level, score_type=hparams.model.score_type
)
self.automatic_optimization = False
self.embeddings: List[Tensor] = []

Expand All @@ -56,10 +56,7 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ
Returns:
Deep CNN features.
"""

self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch["image"])
embedding = torch.hstack(list(layer_outputs.values())).detach().squeeze()
embedding = self.model.get_features(batch["image"]).squeeze()

# NOTE: `self.embedding` appends each batch embedding to
# store the training set embedding. We manually append these
Expand All @@ -73,7 +70,7 @@ def on_validation_start(self) -> None:
# This is not possible anymore with PyTorch Lightning v1.4.0 since validation
# is run within train epoch.
embeddings = torch.vstack(self.embeddings)
self.dfm_model.fit(embeddings)
self.model.fit(embeddings)

def validation_step(self, batch, _): # pylint: disable=arguments-differ
"""Validation Step of DFM.
Expand All @@ -86,10 +83,6 @@ def validation_step(self, batch, _): # pylint: disable=arguments-differ
Returns:
Dictionary containing FRE anomaly scores and ground-truth.
"""

self.feature_extractor.eval()
layer_outputs = self.feature_extractor(batch["image"])
feature_vector = torch.hstack(list(layer_outputs.values())).detach()
batch["pred_scores"] = self.dfm_model.score(feature_vector.view(feature_vector.shape[:2]))
batch["pred_scores"] = self.model(batch["image"])

return batch
Loading