Commit ada55e3

Load model did not work correctly as DFMModel did not inherit `nn.Module` (#5)
* Load model did not work correctly as DFMModel did not inherit
  `nn.Module`

- Change: `DFMModel` is now a subclass of `nn.Module`
- Change: `SingleClassGaussian` is now a subclass of `DynamicBufferModule`
- Fix: Missing/incorrect docstrings in `pca.py` and `dfm_model.py`
- Fix: Load model test in `test_model.py` compares metrics for
  classification models.
- Fix: Rename `SingleclassGaussian` to `SingleClassGaussian`
- Fix: Incorrect input region in `generate_random_anomaly_image` in
  dummy dataset helpers

* Address PR comments

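For context on the headline fix: PyTorch only tracks parameters and buffers that live inside an `nn.Module` tree, so a plain-class wrapper silently drops its children's state from checkpoints. A minimal sketch of the mechanism (class names here are illustrative, not from the codebase):

    import torch
    from torch import nn

    class Stats(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("mean", torch.zeros(4))  # tracked by state_dict

    class PlainWrapper:  # pre-fix situation: `stats.mean` never reaches a checkpoint
        def __init__(self):
            self.stats = Stats()

    class ModuleWrapper(nn.Module):  # post-fix: the buffer is saved and restored
        def __init__(self):
            super().__init__()
            self.stats = Stats()

    print(ModuleWrapper().state_dict().keys())  # odict_keys(['stats.mean'])
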
Co-authored-by: Ashwin Vaidya <ashwinitinvaidya@gmail.com>
ashwinvaidya17 and Ashwin Vaidya committed Nov 26, 2021
1 parent cba7d28 commit ada55e3
Showing 5 changed files with 195 additions and 97 deletions.
86 changes: 68 additions & 18 deletions anomalib/core/model/pca.py
@@ -16,62 +16,112 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
 
+from typing import Union
+
 import torch
+from torch import Tensor
 
 from anomalib.core.model.dynamic_module import DynamicBufferModule
 
 
 class PCA(DynamicBufferModule):
     """
     Principal Component Analysis (PCA)
+    Args:
+        n_components (float): Number of components. Can be either an integer number of components
+            or a ratio between 0 and 1.
     """
 
-    def __init__(self, n_components: int):
+    def __init__(self, n_components: Union[float, int]):
         super().__init__()
         self.n_components = n_components
 
-        self.register_buffer("singular_vectors", torch.Tensor())
-        self.register_buffer("mean", torch.Tensor())
+        self.register_buffer("singular_vectors", Tensor())
+        self.register_buffer("mean", Tensor())
+        self.register_buffer("num_components", Tensor())
+
+        self.singular_vectors: Tensor
+        self.singular_values: Tensor
+        self.mean: Tensor
+        self.num_components: Tensor
+
+    def fit(self, dataset: Tensor) -> None:
+        """
+        Fits the PCA model to the dataset
+        Args:
+            dataset (Tensor): Input dataset to fit the model.
+        """
+        mean = dataset.mean(dim=0)
+        dataset -= mean
+
+        _, sig, v_h = torch.linalg.svd(dataset.double())
+        num_components: int
+        if self.n_components <= 1:
+            variance_ratios = torch.cumsum(sig * sig, dim=0) / torch.sum(sig * sig)
+            num_components = torch.nonzero(variance_ratios >= self.n_components)[0]
+        else:
+            num_components = int(self.n_components)
+
+        self.num_components = Tensor([num_components])
 
-        self.singular_vectors: torch.Tensor
-        self.mean: torch.Tensor
+        self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components].float()
+        self.singular_values = sig[:num_components].float()
+        self.mean = mean
 
-    def fit_transform(self, dataset: torch.Tensor) -> torch.Tensor:
+    def fit_transform(self, dataset: Tensor) -> Tensor:
         """
         Args:
-            dataset: torch.Tensor:
-        Returns:
+            dataset (Tensor): Dataset to which the PCA is fit and transformed
+        Returns: Transformed dataset
         """
         mean = dataset.mean(dim=0)
         dataset -= mean
+        num_components = int(self.n_components)
+        self.num_components = Tensor([num_components])
 
-        self.singular_vectors = torch.svd(dataset)[-1]
+        v_h = torch.linalg.svd(dataset)[-1]
+        self.singular_vectors = v_h.transpose(-2, -1)[:, :num_components]
         self.mean = mean
 
-        return torch.matmul(dataset, self.singular_vectors[:, : self.n_components])
+        return torch.matmul(dataset, self.singular_vectors)
 
-    def transform(self, features: torch.Tensor) -> torch.Tensor:
+    def transform(self, features: Tensor) -> Tensor:
         """
+        Transforms the features based on singular vectors calculated earlier.
         Args:
-            features: torch.Tensor:
-        Returns:
+            features (Tensor): Input features
+        Returns: Transformed features
         """
 
         features -= self.mean
-        return torch.matmul(features, self.singular_vectors[:, : self.n_components])
+        return torch.matmul(features, self.singular_vectors)
 
-    def forward(self, features: torch.Tensor) -> torch.Tensor:
+    def inverse_transform(self, features: Tensor) -> Tensor:
         """
+        Inverts the transformed features
         Args:
-            features: torch.Tensor:
-        Returns:
+            features (Tensor): Transformed features
+        Returns: Inverse features
         """
+        inv_features = torch.matmul(features, self.singular_vectors.transpose(-2, -1))
+        return inv_features
+
+    def forward(self, features: Tensor) -> Tensor:
+        """
+        Transforms the features
+        Args:
+            features (Tensor): Input features
+        Returns: Transformed features
+        """
         return self.transform(features)
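
As a quick sanity check of the new `PCA` interface, a hedged usage sketch (shapes are arbitrary; note that `fit` and `transform` subtract the mean from their inputs in place, hence the clones):

    import torch
    from anomalib.core.model.pca import PCA

    pca = PCA(n_components=0.97)   # a ratio <= 1 keeps ~97% of the variance
    data = torch.randn(200, 64)    # 200 samples, 64-dimensional features

    pca.fit(data.clone())
    projected = pca.transform(data.clone())      # (200, num_components)
    restored = pca.inverse_transform(projected)  # back in the centered feature space

    # Reconstruction error is small when most of the variance is retained:
    centered = data - data.mean(dim=0)
    print(torch.mean((centered - restored) ** 2))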
108 changes: 67 additions & 41 deletions anomalib/models/dfm/dfm_model.py
@@ -16,107 +16,133 @@
 # See the License for the specific language governing permissions
 # and limitations under the License.
 
-import numpy as np
+import math
 
 import torch
-from sklearn.decomposition import PCA
+from torch import Tensor, nn
 
+from anomalib.core.model.dynamic_module import DynamicBufferModule
+from anomalib.core.model.pca import PCA
 
-class SingleclassGaussian:
+
+class SingleClassGaussian(DynamicBufferModule):
     """
     Model Gaussian distribution over a set of points
     """
 
     def __init__(self):
-        self.mean_vec = None
-        self.u_mat = None
-        self.sigma_mat = None
+        super().__init__()
+        self.register_buffer("mean_vec", Tensor())
+        self.register_buffer("u_mat", Tensor())
+        self.register_buffer("sigma_mat", Tensor())
+
+        self.mean_vec: Tensor
+        self.u_mat: Tensor
+        self.sigma_mat: Tensor
 
-    def fit(self, dataset):
+    def fit(self, dataset: Tensor) -> None:
         """
         Fit a Gaussian model to dataset X.
         Covariance matrix is not calculated directly using:
             C = X.X^T
         Instead, it is represented in terms of the Singular Value Decomposition of X:
             X = U.S.V^T
         Hence,
             C = U.S^2.U^T
         This simplifies the calculation of the log-likelihood without requiring full matrix inversion.
         Args:
-            dataset: Input dataset to fit the model.
-            dataset: torch.Tensor:
-        Returns:
+            dataset (Tensor): Input dataset to fit the model.
         """
 
         num_samples = dataset.shape[1]
         self.mean_vec = torch.mean(dataset, dim=1)
-        data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / torch.sqrt(torch.Tensor([num_samples]))
+        data_centered = (dataset - self.mean_vec.reshape(-1, 1)) / math.sqrt(num_samples)
         self.u_mat, self.sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False)
 
-    def score_samples(self, features):
+    def score_samples(self, features: Tensor) -> Tensor:
         """
         Compute the NLL (negative log likelihood) scores
         Args:
-            x: semantic features on which density modeling is performed.
+            features (Tensor): Semantic features on which density modeling is performed.
         Returns:
-            nll: numpy array of scores
+            nll (Tensor): Torch tensor of scores
         """
         features_transformed = torch.matmul(features - self.mean_vec, self.u_mat / self.sigma_mat)
-        nll = torch.sum(features_transformed * features_transformed, dim=1) + 2 * np.sum(np.log(self.sigma_mat))
+        nll = torch.sum(features_transformed * features_transformed, dim=1) + 2 * torch.sum(torch.log(self.sigma_mat))
         return nll
 
+    def forward(self, dataset: Tensor) -> None:
+        """
+        Provides the same functionality as `fit`.
+        Args:
+            dataset (Tensor): Input dataset
+        """
+        self.fit(dataset)
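To unpack the docstring above: with X centered and scaled by 1/sqrt(n), the covariance is C = X.X^T = U.S^2.U^T, so the Mahalanobis term in the NLL reduces to ||S^-1.U^T.(x - mean)||^2 and log|C| = 2 * sum(log s_i), which is exactly the `2 * torch.sum(torch.log(self.sigma_mat))` term in `score_samples`. A small numerical check of the identity (a standalone sketch, not part of the codebase):

    import torch

    torch.manual_seed(0)
    n_features, n_samples = 5, 100
    dataset = torch.randn(n_features, n_samples)   # same layout as `fit`: features x samples

    mean_vec = dataset.mean(dim=1, keepdim=True)
    data_centered = (dataset - mean_vec) / n_samples ** 0.5
    u_mat, sigma_mat, _ = torch.linalg.svd(data_centered, full_matrices=False)

    cov = data_centered @ data_centered.T          # C = X.X^T
    assert torch.allclose(cov, u_mat @ torch.diag(sigma_mat**2) @ u_mat.T, atol=1e-5)

    # log-determinant directly from the singular values, no matrix inversion needed:
    print(torch.logdet(cov).item(), (2 * torch.sum(torch.log(sigma_mat))).item())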

-class DFMModel:
+class DFMModel(nn.Module):
     """
     Model for the DFM algorithm
     Args:
         n_comps (float, optional): Ratio from which number of components for PCA are calculated. Defaults to 0.97.
         score_type (str, optional): Scoring type. Options are `fre` and `nll`. Defaults to "fre".
     """
 
     def __init__(self, n_comps: float = 0.97, score_type: str = "fre"):
+        super().__init__()
         self.n_components = n_comps
         self.pca_model = PCA(n_components=self.n_components)
-        self.gaussian_model = SingleclassGaussian()
+        self.gaussian_model = SingleClassGaussian()
         self.score_type = score_type
 
-    def fit(self, dataset: torch.Tensor):
+    def fit(self, dataset: Tensor) -> None:
         """
         Fit a PCA transformation and a Gaussian model to the dataset
         Args:
-            dataset: Input dataset to fit the model.
-            dataset: torch.Tensor:
-        Returns:
+            dataset (Tensor): Input dataset to fit the model.
         """
 
-        selected_features = dataset.cpu().numpy()
-        self.pca_model.fit(selected_features)
-        features_reduced = torch.Tensor(self.pca_model.transform(selected_features))
+        self.pca_model.fit(dataset)
+        features_reduced = self.pca_model.transform(dataset)
         self.gaussian_model.fit(features_reduced.T)
 
-    def score(self, sem_feats: torch.Tensor) -> np.array:
+    def score(self, features: Tensor) -> Tensor:
         """
         Compute the PCA-based feature reconstruction error (FRE) scores and
         the Gaussian density-based NLL scores
         Args:
-            sem_feats: semantic features on which PCA and density modeling is performed.
+            features (Tensor): Semantic features on which PCA and density modeling are performed.
         Returns:
-            score: numpy array of scores
+            score (Tensor): Tensor of scores
         """
-        feats_orig = sem_feats.cpu().numpy()
-        feats_projected = self.pca_model.transform(feats_orig)
+        feats_projected = self.pca_model.transform(features)
         if self.score_type == "nll":
             score = self.gaussian_model.score_samples(feats_projected)
         elif self.score_type == "fre":
             feats_reconstructed = self.pca_model.inverse_transform(feats_projected)
-            score = np.sum(np.square(feats_orig - feats_reconstructed), axis=1)
+            score = torch.sum(torch.square(features - feats_reconstructed), dim=1)
         else:
             raise ValueError(f"unsupported score type: {self.score_type}")
 
         return score
+
+    def forward(self, dataset: Tensor) -> None:
+        """
+        Provides the same functionality as `fit`.
+        Args:
+            dataset (Tensor): Input dataset
+        """
+        self.fit(dataset)
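
With `DFMModel` now an `nn.Module` and its children registering buffers, the model's state can round-trip through `state_dict`, which is the load-model path this commit repairs. A hedged end-to-end sketch (feature sizes are arbitrary; assumes `DynamicBufferModule` resizes its empty buffers on load, as its use here suggests):

    import torch
    from anomalib.models.dfm.dfm_model import DFMModel

    model = DFMModel(n_comps=0.97, score_type="fre")
    train_features = torch.randn(500, 512)      # e.g. pooled CNN embeddings
    model.fit(train_features)

    scores = model.score(torch.randn(8, 512))   # higher score = more anomalous
    print(scores.shape)                         # torch.Size([8])

    # Before this commit none of these buffers reached the checkpoint:
    restored = DFMModel(n_comps=0.97, score_type="fre")
    restored.load_state_dict(model.state_dict())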
16 changes: 5 additions & 11 deletions anomalib/models/dfm/model.py
@@ -20,12 +20,10 @@
 
 import torch
 from omegaconf import DictConfig, ListConfig
-from torch import Tensor
 from torchvision.models import resnet18
 
 from anomalib.core.model import AnomalyModule
 from anomalib.core.model.feature_extractor import FeatureExtractor
-from anomalib.core.results import ClassificationResults
 from anomalib.models.dfm.dfm_model import DFMModel
 
@@ -36,14 +34,9 @@ class DfmLightning(AnomalyModule):
 
     def __init__(self, hparams: Union[DictConfig, ListConfig]):
         super().__init__(hparams)
-        self.save_hyperparameters(hparams)
-        self.threshold_steepness = 0.05
-        self.threshold_offset = 12
 
         self.feature_extractor = FeatureExtractor(backbone=resnet18(pretrained=True), layers=["avgpool"]).eval()
 
         self.dfm_model = DFMModel(n_comps=hparams.model.pca_level, score_type=hparams.model.score_type)
-        self.results = ClassificationResults()
         self.automatic_optimization = False
 
     @staticmethod
@@ -58,8 +51,8 @@ def training_step(self, batch, _):  # pylint: disable=arguments-differ
         For each batch, features are extracted from the CNN.
         Args:
-            batch: Dict: Input batch
-            batch_idx: int: Index of the batch.
+            batch: Input batch
+            _: Index of the batch.
         Returns:
             Deep CNN features.
@@ -71,7 +64,7 @@
         feature_vector = torch.hstack(list(layer_outputs.values())).detach().squeeze()
         return {"feature_vector": feature_vector}
 
-    def training_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]]) -> None:
+    def training_epoch_end(self, outputs: List[Dict[str, Any]]) -> None:
         """Fit a KDE model on deep CNN features.
         Args:
@@ -102,5 +95,6 @@ def validation_step(self, batch, _):  # pylint: disable=arguments-differ
         self.feature_extractor.eval()
         layer_outputs = self.feature_extractor(batch["image"])
         feature_vector = torch.hstack(list(layer_outputs.values())).detach()
-        batch["pred_scores"] = torch.from_numpy(self.dfm_model.score(feature_vector.view(feature_vector.shape[:2])))
+        batch["pred_scores"] = self.dfm_model.score(feature_vector.view(feature_vector.shape[:2]))
 
         return batch
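
One detail worth noting in `validation_step` above: resnet18's `avgpool` output has shape `(N, C, 1, 1)`, and `view(feature_vector.shape[:2])` flattens only the trailing singleton dimensions, whereas the `.squeeze()` used in `training_step` would also drop a batch dimension of size 1. A tiny sketch of the shape flow:

    import torch

    feature_vector = torch.randn(4, 512, 1, 1)   # avgpool output for a batch of 4

    flat = feature_vector.view(feature_vector.shape[:2])
    print(flat.shape)                            # torch.Size([4, 512])

    single = torch.randn(1, 512, 1, 1)
    print(single.squeeze().shape)                # torch.Size([512]) -- batch dim dropped too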
2 changes: 1 addition & 1 deletion tests/helpers/dataset.py
@@ -50,7 +50,7 @@ def generate_random_anomaly_image(
 
     image: np.ndarray = np.full((image_height, image_width, 3), 255, dtype=np.uint8)
 
-    input_region = [0, 0, image_width, image_height]
+    input_region = [0, 0, image_width - 1, image_height - 1]
 
     for shape in shapes:
         shape_image = random_shapes(input_region, (image_height, image_width), max_shapes=max_shapes, shape=shape)
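
The one-line fix above treats `input_region` as inclusive pixel coordinates (an assumption consistent with the `- 1`): for a width-W image the last addressable column is W - 1, so the old bounds could place a shape one pixel outside the canvas. A minimal illustration:

    import numpy as np

    image_width, image_height = 128, 128
    image = np.full((image_height, image_width, 3), 255, dtype=np.uint8)

    x_max, y_max = image_width - 1, image_height - 1   # last valid pixel indices
    image[y_max, x_max] = 0                            # in bounds
    # image[image_height, image_width] would raise IndexError, which is the
    # off-by-one the old region bounds allowed.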
