Skip to content

Commit

Permalink
🛠 Fix config files and refactor dfkde (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
samet-akcay committed Jul 13, 2022
1 parent 8364577 commit ddd1b50
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 43 deletions.
4 changes: 2 additions & 2 deletions anomalib/models/dfkde/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ metrics:

visualization:
show_images: False # show images on the screen
save_images: True # save images to the file system
log_images: True # log images to the available loggers (if any)
save_images: False # save images to the file system
log_images: False # log images to the available loggers (if any)
image_save_path: null # path to which images will be saved
mode: full # options: ["full", "simple"]

Expand Down
78 changes: 39 additions & 39 deletions anomalib/models/dfkde/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,29 @@ def get_features(self, batch: Tensor) -> Tensor:
layer_outputs = torch.cat(list(layer_outputs.values())).detach()
return layer_outputs

def pre_process(self, feature_stack: Tensor, max_length: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Pre-process the CNN features.
Args:
feature_stack (Tensor): Features extracted from CNN
max_length (Optional[Tensor]): Used to unit normalize the feature_stack vector. If ``max_len`` is not
provided, the length is calculated from the ``feature_stack``. Defaults to None.
Returns:
(Tuple): Stacked features and length
"""

if max_length is None:
max_length = torch.max(torch.linalg.norm(feature_stack, ord=2, dim=1))

if self.pre_processing == "norm":
feature_stack /= torch.linalg.norm(feature_stack, ord=2, dim=1)[:, None]
elif self.pre_processing == "scale":
feature_stack /= max_length
else:
raise RuntimeError("Unknown pre-processing mode. Available modes are: Normalized and Scale.")
return feature_stack, max_length

def fit(self, embeddings: List[Tensor]) -> bool:
"""Fit a kde model to embeddings.
Expand All @@ -105,36 +128,13 @@ def fit(self, embeddings: List[Tensor]) -> bool:
selected_features = _embeddings

feature_stack = self.pca_model.fit_transform(selected_features)
feature_stack, max_length = self.preprocess(feature_stack)
feature_stack, max_length = self.pre_process(feature_stack)
self.max_length = max_length
self.kde_model.fit(feature_stack)

return True

def preprocess(self, feature_stack: Tensor, max_length: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Pre-process the CNN features.
Args:
feature_stack (Tensor): Features extracted from CNN
max_length (Optional[Tensor]): Used to unit normalize the feature_stack vector. If ``max_len`` is not
provided, the length is calculated from the ``feature_stack``. Defaults to None.
Returns:
(Tuple): Stacked features and length
"""

if max_length is None:
max_length = torch.max(torch.linalg.norm(feature_stack, ord=2, dim=1))

if self.pre_processing == "norm":
feature_stack /= torch.linalg.norm(feature_stack, ord=2, dim=1)[:, None]
elif self.pre_processing == "scale":
feature_stack /= max_length
else:
raise RuntimeError("Unknown pre-processing mode. Available modes are: Normalized and Scale.")
return feature_stack, max_length

def evaluate(self, features: Tensor, as_log_likelihood: Optional[bool] = False) -> Tensor:
def compute_kde_scores(self, features: Tensor, as_log_likelihood: Optional[bool] = False) -> Tensor:
"""Compute the KDE scores.
The scores calculated from the KDE model are converted to densities. If `as_log_likelihood` is set to true then
Expand All @@ -149,7 +149,7 @@ def evaluate(self, features: Tensor, as_log_likelihood: Optional[bool] = False)
"""

features = self.pca_model.transform(features)
features, _ = self.preprocess(features, self.max_length)
features, _ = self.pre_process(features, self.max_length)
# Scores are always assumed to be passed as a density
kde_scores = self.kde_model(features)

Expand All @@ -161,32 +161,32 @@ def evaluate(self, features: Tensor, as_log_likelihood: Optional[bool] = False)

return kde_scores

def predict(self, features: Tensor) -> Tensor:
"""Predicts the probability that the features belong to the anomalous class.
def compute_probabilities(self, scores: Tensor) -> Tensor:
"""Converts density scores to anomaly probabilities (see https://www.desmos.com/calculator/ifju7eesg7).
Args:
features (Tensor): Feature from which the output probabilities are detected.
scores (Tensor): density of an image.
Returns:
Detection probabilities
probability that image with {density} is anomalous
"""

densities = self.evaluate(features, as_log_likelihood=True)
probabilities = self.to_probability(densities)

return probabilities
return 1 / (1 + torch.exp(self.threshold_steepness * (scores - self.threshold_offset)))

def to_probability(self, densities: Tensor) -> Tensor:
"""Converts density scores to anomaly probabilities (see https://www.desmos.com/calculator/ifju7eesg7).
def predict(self, features: Tensor) -> Tensor:
"""Predicts the probability that the features belong to the anomalous class.
Args:
densities (Tensor): density of an image.
features (Tensor): Feature from which the output probabilities are detected.
Returns:
probability that image with {density} is anomalous
Detection probabilities
"""

return 1 / (1 + torch.exp(self.threshold_steepness * (densities - self.threshold_offset)))
scores = self.compute_kde_scores(features, as_log_likelihood=True)
probabilities = self.compute_probabilities(scores)

return probabilities

def forward(self, batch: Tensor) -> Tensor:
"""Prediction by normality model.
Expand Down
4 changes: 2 additions & 2 deletions anomalib/models/dfm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ metrics:

visualization:
show_images: False # show images on the screen
save_images: True # save images to the file system
log_images: True # log images to the available loggers (if any)
save_images: False # save images to the file system
log_images: False # log images to the available loggers (if any)
image_save_path: null # path to which images will be saved
mode: full # options: ["full", "simple"]

Expand Down

0 comments on commit ddd1b50

Please sign in to comment.