diff --git a/anomalib/models/dfkde/model.py b/anomalib/models/dfkde/model.py
index 4c59b71d2e..6a423b44a6 100644
--- a/anomalib/models/dfkde/model.py
+++ b/anomalib/models/dfkde/model.py
@@ -16,72 +16,13 @@
 
 from typing import List, Union
 
-import torch
-import torchvision
 from omegaconf.dictconfig import DictConfig
 from omegaconf.listconfig import ListConfig
-from torch import Tensor, nn
+from torch import Tensor
 
-from anomalib.models.components import AnomalyModule, FeatureExtractor
+from anomalib.models.components import AnomalyModule
 
-from .normality_model import NormalityModel
-
-
-class DfkdeModel(nn.Module):
-    """DFKDE model.
-
-    Args:
-        backbone (str): Pre-trained model backbone.
-        filter_count (int): Number of filters.
-        threshold_steepness (float): Threshold steepness for normality model.
-        threshold_offset (float): Threshold offset for normality model.
-    """
-
-    def __init__(self, backbone: str, filter_count: int, threshold_steepness: float, threshold_offset: float) -> None:
-        super().__init__()
-        self.backbone = getattr(torchvision.models, backbone)
-        self.feature_extractor = FeatureExtractor(backbone=self.backbone(pretrained=True), layers=["avgpool"]).eval()
-
-        self.normality_model = NormalityModel(
-            filter_count=filter_count,
-            threshold_steepness=threshold_steepness,
-            threshold_offset=threshold_offset,
-        )
-
-    def get_features(self, batch: Tensor) -> Tensor:
-        """Extract features from the pretrained network.
-
-        Args:
-            batch (Tensor): Image batch.
-
-        Returns:
-            Tensor: Tensor containing extracted features.
-        """
-        self.feature_extractor.eval()
-        layer_outputs = self.feature_extractor(batch)
-        layer_outputs = torch.cat(list(layer_outputs.values())).detach()
-        return layer_outputs
-
-    def fit(self, embeddings: List[Tensor]):
-        """Fit normality model.
-
-        Args:
-            embeddings (List[Tensor]): Embeddings to fit.
-        """
-        _embeddings = torch.vstack(embeddings)
-        self.normality_model.fit(_embeddings)
-
-    def forward(self, batch: Tensor) -> Tensor:
-        """Prediction by normality model.
-
-        Args:
-            batch (Tensor): Input images.
-
-        Returns:
-            Tensor: Predictions
-        """
-        feature_vector = self.get_features(batch)
-        return self.normality_model.predict(feature_vector.view(feature_vector.shape[:2]))
+from .torch_model import DfkdeModel
 
 
 class DfkdeLightning(AnomalyModule):
@@ -96,8 +37,11 @@ def __init__(self, hparams: Union[DictConfig, ListConfig]):
         threshold_steepness = 0.05
         threshold_offset = 12
 
-        self.model: DfkdeModel = DfkdeModel(
-            hparams.model.backbone, hparams.model.max_training_points, threshold_steepness, threshold_offset
+        self.model = DfkdeModel(
+            backbone=hparams.model.backbone,
+            filter_count=hparams.model.max_training_points,
+            threshold_steepness=threshold_steepness,
+            threshold_offset=threshold_offset,
         )
 
         self.embeddings: List[Tensor] = []
diff --git a/anomalib/models/dfkde/normality_model.py b/anomalib/models/dfkde/torch_model.py
similarity index 75%
rename from anomalib/models/dfkde/normality_model.py
rename to anomalib/models/dfkde/torch_model.py
index b9302805a4..2dde510a9f 100644
--- a/anomalib/models/dfkde/normality_model.py
+++ b/anomalib/models/dfkde/torch_model.py
@@ -15,18 +15,20 @@
 # and limitations under the License.
 
 import random
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
+import torchvision
 from torch import Tensor, nn
 
-from anomalib.models.components import PCA, GaussianKDE
+from anomalib.models.components import PCA, FeatureExtractor, GaussianKDE
 
 
-class NormalityModel(nn.Module):
+class DfkdeModel(nn.Module):
     """Normality Model for the DFKDE algorithm.
 
     Args:
+        backbone (str): Pre-trained model backbone.
        n_comps (int, optional): Number of PCA components. Defaults to 16.
        pre_processing (str, optional): Preprocess features before passing to KDE.
            Options are between `norm` and `scale`. Defaults to "scale".
@@ -37,6 +39,7 @@ class NormalityModel(nn.Module):
 
     def __init__(
         self,
+        backbone: str,
         n_comps: int = 16,
         pre_processing: str = "scale",
         filter_count: int = 40000,
@@ -50,33 +53,51 @@ def __init__(
         self.threshold_steepness = threshold_steepness
         self.threshold_offset = threshold_offset
 
+        _backbone = getattr(torchvision.models, backbone)
+        self.feature_extractor = FeatureExtractor(backbone=_backbone(pretrained=True), layers=["avgpool"]).eval()
+
         self.pca_model = PCA(n_components=self.n_components)
         self.kde_model = GaussianKDE()
 
         self.register_buffer("max_length", Tensor(torch.Size([])))
         self.max_length = Tensor(torch.Size([]))
 
-    def fit(self, dataset: Tensor) -> bool:
-        """Fit a kde model to dataset.
+    def get_features(self, batch: Tensor) -> Tensor:
+        """Extract features from the pretrained network.
+
+        Args:
+            batch (Tensor): Image batch.
+
+        Returns:
+            Tensor: Tensor containing extracted features.
+        """
+        self.feature_extractor.eval()
+        layer_outputs = self.feature_extractor(batch)
+        layer_outputs = torch.cat(list(layer_outputs.values())).detach()
+        return layer_outputs
+
+    def fit(self, embeddings: List[Tensor]) -> bool:
+        """Fit a kde model to embeddings.
 
         Args:
-            dataset (Tensor): Input dataset to fit the model.
+            embeddings (Tensor): Input embeddings to fit the model.
 
         Returns:
            Boolean confirming whether the training is successful.
         """
+        _embeddings = torch.vstack(embeddings)
 
-        if dataset.shape[0] < self.n_components:
+        if _embeddings.shape[0] < self.n_components:
             print("Not enough features to commit. Not making a model.")
             return False
 
         # if max training points is non-zero and smaller than number of staged features, select random subset
-        if self.filter_count and dataset.shape[0] > self.filter_count:
+        if self.filter_count and _embeddings.shape[0] > self.filter_count:
             # pylint: disable=not-callable
-            selected_idx = torch.tensor(random.sample(range(dataset.shape[0]), self.filter_count))
-            selected_features = dataset[selected_idx]
+            selected_idx = torch.tensor(random.sample(range(_embeddings.shape[0]), self.filter_count))
+            selected_features = _embeddings[selected_idx]
         else:
-            selected_features = dataset
+            selected_features = _embeddings
 
         feature_stack = self.pca_model.fit_transform(selected_features)
         feature_stack, max_length = self.preprocess(feature_stack)
@@ -162,6 +183,15 @@ def to_probability(self, densities: Tensor) -> Tensor:
 
         return 1 / (1 + torch.exp(self.threshold_steepness * (densities - self.threshold_offset)))
 
-    def forward(self, features: Tensor) -> Tensor:
-        """Make module callable."""
-        return self.predict(features)
+    def forward(self, batch: Tensor) -> Tensor:
+        """Prediction by normality model.
+
+        Args:
+            batch (Tensor): Input images.
+
+        Returns:
+            Tensor: Predictions
+        """
+
+        feature_vector = self.get_features(batch)
+        return self.predict(feature_vector.view(feature_vector.shape[:2]))
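
For context, a minimal usage sketch of the relocated `DfkdeModel` (not part of the patch). The `resnet18` backbone name, batch and image sizes are illustrative assumptions; the threshold values mirror `DfkdeLightning.__init__` above, and the `(N, C)` flattening mirrors what `forward()` itself does.

```python
# Illustrative sketch only: drive the relocated DfkdeModel end to end.
# Backbone name, batch sizes and image size are assumptions; thresholds
# are copied from DfkdeLightning.__init__ in the diff above.
import torch

from anomalib.models.dfkde.torch_model import DfkdeModel

model = DfkdeModel(
    backbone="resnet18",        # assumed torchvision backbone; pretrained weights are downloaded
    filter_count=40000,
    threshold_steepness=0.05,
    threshold_offset=12,
)

# Training side: collect per-batch embeddings, flatten them to (N, C) the
# same way forward() does, then fit PCA + KDE once on the stacked features.
embeddings = []
for _ in range(4):
    feats = model.get_features(torch.rand(8, 3, 256, 256))  # (8, C, 1, 1) from "avgpool"
    embeddings.append(feats.view(feats.shape[:2]))           # -> (8, C)
model.fit(embeddings)

# Inference side: forward() extracts features internally and returns
# per-image anomaly probabilities.
with torch.no_grad():
    scores = model(torch.rand(4, 3, 256, 256))
print(scores.shape)
```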