diff --git a/anomalib/models/cflow/lightning_model.py b/anomalib/models/cflow/lightning_model.py
index 5eaa7ae762..3c715464ce 100644
--- a/anomalib/models/cflow/lightning_model.py
+++ b/anomalib/models/cflow/lightning_model.py
@@ -67,7 +67,6 @@ def __init__(
             clamp_alpha=clamp_alpha,
             permute_soft=permute_soft,
         )
-        self.loss_val = 0
         self.automatic_optimization = False
 
     def training_step(self, batch, _):  # pylint: disable=arguments-differ
diff --git a/anomalib/models/fastflow/__init__.py b/anomalib/models/fastflow/__init__.py
index 03c02995ca..e5e1d6d6ac 100644
--- a/anomalib/models/fastflow/__init__.py
+++ b/anomalib/models/fastflow/__init__.py
@@ -5,6 +5,7 @@
 #
 
 from .lightning_model import Fastflow, FastflowLightning
-from .torch_model import FastflowLoss, FastflowModel
+from .loss import FastflowLoss
+from .torch_model import FastflowModel
 
 __all__ = ["FastflowModel", "FastflowLoss", "FastflowLightning", "Fastflow"]
diff --git a/anomalib/models/fastflow/lightning_model.py b/anomalib/models/fastflow/lightning_model.py
index 5d83084c23..243c433830 100644
--- a/anomalib/models/fastflow/lightning_model.py
+++ b/anomalib/models/fastflow/lightning_model.py
@@ -14,7 +14,8 @@
 from torch import optim
 
 from anomalib.models.components import AnomalyModule
-from anomalib.models.fastflow.torch_model import FastflowLoss, FastflowModel
+from anomalib.models.fastflow.loss import FastflowLoss
+from anomalib.models.fastflow.torch_model import FastflowModel
 
 logger = logging.getLogger(__name__)
 
@@ -49,7 +50,7 @@ def __init__(
             conv3x3_only=conv3x3_only,
             hidden_ratio=hidden_ratio,
         )
-        self.loss_func = FastflowLoss()
+        self.loss = FastflowLoss()
 
     def training_step(self, batch, _):  # pylint: disable=arguments-differ
         """Forward-pass input and return the loss.
@@ -62,7 +63,7 @@ def training_step(self, batch, _):  # pylint: disable=arguments-differ
             STEP_OUTPUT: Dictionary containing the loss value.
         """
         hidden_variables, jacobians = self.model(batch["image"])
-        loss = self.loss_func(hidden_variables, jacobians)
+        loss = self.loss(hidden_variables, jacobians)
         return {"loss": loss}
 
     def validation_step(self, batch, _):  # pylint: disable=arguments-differ
diff --git a/anomalib/models/fastflow/loss.py b/anomalib/models/fastflow/loss.py
new file mode 100644
index 0000000000..608b0cfc87
--- /dev/null
+++ b/anomalib/models/fastflow/loss.py
@@ -0,0 +1,28 @@
+"""Loss function for the FastFlow Model Implementation."""
+
+# Copyright (C) 2022 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import List
+
+import torch
+from torch import Tensor, nn
+
+
+class FastflowLoss(nn.Module):
+    """FastFlow Loss."""
+
+    def forward(self, hidden_variables: List[Tensor], jacobians: List[Tensor]) -> Tensor:
+        """Calculate the Fastflow loss.
+
+        Args:
+            hidden_variables (List[Tensor]): Hidden variables from the fastflow model. f: X -> Z
+            jacobians (List[Tensor]): Log of the jacobian determinants from the fastflow model.
+
+        Returns:
+            Tensor: Fastflow loss computed based on the hidden variables and the log of the Jacobians.
+ """ + loss = torch.tensor(0.0, device=hidden_variables[0].device) # pylint: disable=not-callable + for (hidden_variable, jacobian) in zip(hidden_variables, jacobians): + loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian) + return loss diff --git a/anomalib/models/fastflow/torch_model.py b/anomalib/models/fastflow/torch_model.py index 298a6b5f96..991f2f9939 100644 --- a/anomalib/models/fastflow/torch_model.py +++ b/anomalib/models/fastflow/torch_model.py @@ -86,25 +86,6 @@ def create_fast_flow_block( return nodes -class FastflowLoss(nn.Module): - """FastFlow Loss.""" - - def forward(self, hidden_variables: List[Tensor], jacobians: List[Tensor]) -> Tensor: - """Calculate the Fastflow loss. - - Args: - hidden_variables (List[Tensor]): Hidden variables from the fastflow model. f: X -> Z - jacobians (List[Tensor]): Log of the jacobian determinants from the fastflow model. - - Returns: - Tensor: _description_ - """ - loss = torch.tensor(0.0, device=hidden_variables[0].device) - for (hidden_variable, jacobian) in zip(hidden_variables, jacobians): - loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian) - return loss - - class FastflowModel(nn.Module): """FastFlow. diff --git a/anomalib/models/ganomaly/config.yaml b/anomalib/models/ganomaly/config.yaml index f443d1eab1..9e56a73d40 100644 --- a/anomalib/models/ganomaly/config.yaml +++ b/anomalib/models/ganomaly/config.yaml @@ -51,7 +51,7 @@ metrics: adaptive: true project: - seed: 0 + seed: 42 path: ./results logging: diff --git a/anomalib/models/ganomaly/lightning_model.py b/anomalib/models/ganomaly/lightning_model.py index 9cab81f40c..f26367fb82 100644 --- a/anomalib/models/ganomaly/lightning_model.py +++ b/anomalib/models/ganomaly/lightning_model.py @@ -26,8 +26,8 @@ from pytorch_lightning.utilities.cli import MODEL_REGISTRY from torch import Tensor, optim -from anomalib.data.utils.image import pad_nextpow2 from anomalib.models.components import AnomalyModule +from anomalib.models.ganomaly.loss import DiscriminatorLoss, GeneratorLoss from .torch_model import GanomalyModel @@ -73,9 +73,6 @@ def __init__( latent_vec_size=latent_vec_size, extra_layers=extra_layers, add_final_conv_layer=add_final_conv_layer, - wadv=wadv, - wcon=wcon, - wenc=wenc, ) self.real_label = torch.ones(size=(batch_size,), dtype=torch.float32) @@ -84,6 +81,9 @@ def __init__( self.min_scores: Tensor = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable self.max_scores: Tensor = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable + self.generator_loss = GeneratorLoss(wadv, wcon, wenc) + self.discriminator_loss = DiscriminatorLoss() + def _reset_min_max(self): """Resets min_max scores.""" self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable @@ -99,24 +99,18 @@ def training_step(self, batch, _, optimizer_idx): # pylint: disable=arguments-d Returns: Dict[str, Tensor]: Loss """ - images = batch["image"] - padded_images = pad_nextpow2(images) - loss: Dict[str, Tensor] - - # Discriminator - if optimizer_idx == 0: - # forward pass - loss_discriminator = self.model.get_discriminator_loss(padded_images) - loss = {"loss": loss_discriminator} - - # Generator - else: - # forward pass - loss_generator = self.model.get_generator_loss(padded_images) - - loss = {"loss": loss_generator} - - return loss + # forward pass + padded, fake, latent_i, latent_o = self.model(batch["image"]) + pred_real, _ = self.model.discriminator(padded) + + if 
optimizer_idx == 0: # Discriminator + pred_fake, _ = self.model.discriminator(fake.detach()) + loss = self.discriminator_loss(pred_real, pred_fake) + else: # Generator + pred_fake, _ = self.model.discriminator(fake) + loss = self.generator_loss(latent_i, latent_o, padded, fake, pred_real, pred_fake) + + return {"loss": loss} def on_validation_start(self) -> None: """Reset min and max values for current validation epoch.""" diff --git a/anomalib/models/ganomaly/loss.py b/anomalib/models/ganomaly/loss.py new file mode 100644 index 0000000000..77a6e19fba --- /dev/null +++ b/anomalib/models/ganomaly/loss.py @@ -0,0 +1,79 @@ +"""Loss function for the GANomaly Model Implementation.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import torch +from torch import Tensor, nn + + +class GeneratorLoss(nn.Module): + """Generator loss for the GANomaly model. + + Args: + wadv (int, optional): Weight for adversarial loss. Defaults to 1. + wcon (int, optional): Image regeneration weight. Defaults to 50. + wenc (int, optional): Latent vector encoder weight. Defaults to 1. + """ + + def __init__(self, wadv=1, wcon=50, wenc=1): + super().__init__() + + self.loss_enc = nn.SmoothL1Loss() + self.loss_adv = nn.MSELoss() + self.loss_con = nn.L1Loss() + + self.wadv = wadv + self.wcon = wcon + self.wenc = wenc + + def forward( + self, latent_i: Tensor, latent_o: Tensor, images: Tensor, fake: Tensor, pred_real: Tensor, pred_fake: Tensor + ) -> Tensor: + """Compute the loss for a batch. + + Args: + latent_i (Tensor): Latent features of the first encoder. + latent_o (Tensor): Latent features of the second encoder. + images (Tensor): Real image that served as input of the generator. + fake (Tensor): Generated image. + pred_real (Tensor): Discriminator predictions for the real image. + pred_fake (Tensor): Discriminator predictions for the fake image. + + Returns: + Tensor: The computed generator loss. + """ + error_enc = self.loss_enc(latent_i, latent_o) + error_con = self.loss_con(images, fake) + error_adv = self.loss_adv(pred_real, pred_fake) + + loss = error_adv * self.wadv + error_con * self.wcon + error_enc * self.wenc + return loss + + +class DiscriminatorLoss(nn.Module): + """Discriminator loss for the GANomaly model.""" + + def __init__(self): + super().__init__() + + self.loss_bce = nn.BCELoss() + + def forward(self, pred_real, pred_fake): + """Compye the loss for a predicted batch. + + Args: + pred_real (Tensor): Discriminator predictions for the real image. + pred_fake (Tensor): Discriminator predictions for the fake image. + + Returns: + Tensor: The computed discriminator loss. + """ + error_discriminator_real = self.loss_bce( + pred_real, torch.ones(size=pred_real.shape, dtype=torch.float32, device=pred_real.device) + ) + error_discriminator_fake = self.loss_bce( + pred_fake, torch.zeros(size=pred_fake.shape, dtype=torch.float32, device=pred_fake.device) + ) + loss_discriminator = (error_discriminator_fake + error_discriminator_real) * 0.5 + return loss_discriminator diff --git a/anomalib/models/ganomaly/torch_model.py b/anomalib/models/ganomaly/torch_model.py index 88d76d6fc0..ccd273ea54 100644 --- a/anomalib/models/ganomaly/torch_model.py +++ b/anomalib/models/ganomaly/torch_model.py @@ -12,7 +12,7 @@ import math -from typing import Tuple +from typing import Tuple, Union import torch from torch import Tensor, nn @@ -286,9 +286,6 @@ class GanomalyModel(nn.Module): latent_vec_size (int): Size of autoencoder latent vector. 
         extra_layers (int, optional): Number of extra layers for encoder/decoder. Defaults to 0.
         add_final_conv_layer (bool, optional): Add convolution layer at the end. Defaults to True.
-        wadv (int, optional): Weight for adversarial loss. Defaults to 1.
-        wcon (int, optional): Image regeneration weight. Defaults to 50.
-        wenc (int, optional): Latent vector encoder weight. Defaults to 1.
     """
 
     def __init__(
@@ -299,9 +296,6 @@ def __init__(
         latent_vec_size: int,
         extra_layers: int = 0,
         add_final_conv_layer: bool = True,
-        wadv: int = 1,
-        wcon: int = 50,
-        wenc: int = 1,
     ) -> None:
         super().__init__()
         self.generator: Generator = Generator(
@@ -320,13 +314,6 @@ def __init__(
         )
         self.weights_init(self.generator)
         self.weights_init(self.discriminator)
-        self.loss_enc = nn.SmoothL1Loss()
-        self.loss_adv = nn.MSELoss()
-        self.loss_con = nn.L1Loss()
-        self.loss_bce = nn.BCELoss()
-        self.wadv = wadv
-        self.wcon = wcon
-        self.wenc = wenc
 
     @staticmethod
     def weights_init(module: nn.Module):
@@ -342,51 +329,7 @@ def weights_init(module: nn.Module):
             nn.init.normal_(module.weight.data, 1.0, 0.02)
             nn.init.constant_(module.bias.data, 0)
 
-    def get_discriminator_loss(self, images: Tensor) -> Tensor:
-        """Calculates loss for discriminator.
-
-        Args:
-            images (Tensor): Input images.
-
-        Returns:
-            Tensor: Discriminator loss.
-        """
-        fake, _, _ = self.generator(images)
-        pred_real, _ = self.discriminator(images)
-        pred_fake, _ = self.discriminator(fake.detach())
-
-        error_discriminator_real = self.loss_bce(
-            pred_real, torch.ones(size=pred_real.shape, dtype=torch.float32, device=pred_real.device)
-        )
-        error_discriminator_fake = self.loss_bce(
-            pred_fake, torch.zeros(size=pred_fake.shape, dtype=torch.float32, device=pred_fake.device)
-        )
-        loss_discriminator = (error_discriminator_fake + error_discriminator_real) * 0.5
-        return loss_discriminator
-
-    def get_generator_loss(self, images: Tensor) -> Tensor:
-        """Calculates loss for generator.
-
-        Args:
-            images (Tensor): Input images.
-
-        Returns:
-            Tensor: Generator loss.
-        """
-        fake, latent_i, latent_o = self.generator(images)
-        pred_real, _ = self.discriminator(images)
-        pred_fake, _ = self.discriminator(fake)
-
-        error_enc = self.loss_enc(latent_i, latent_o)
-
-        error_con = self.loss_con(images, fake)
-
-        error_adv = self.loss_adv(pred_real, pred_fake)
-
-        loss_generator = error_adv * self.wadv + error_con * self.wcon + error_enc * self.wenc
-        return loss_generator
-
-    def forward(self, batch: Tensor) -> Tensor:
+    def forward(self, batch: Tensor) -> Union[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
         """Get scores for batch.
 
         Args:
@@ -396,6 +339,7 @@ def forward(self, batch: Tensor) -> Tensor:
             Tensor: Regeneration scores.
""" padded_batch = pad_nextpow2(batch) - self.generator.eval() - _, latent_i, latent_o = self.generator(padded_batch) + fake, latent_i, latent_o = self.generator(padded_batch) + if self.training: + return padded_batch, fake, latent_i, latent_o return torch.mean(torch.pow((latent_i - latent_o), 2), dim=1).view(-1) # convert nx1x1 to n diff --git a/anomalib/models/stfpm/lightning_model.py b/anomalib/models/stfpm/lightning_model.py index f645a23dab..58f8da8950 100644 --- a/anomalib/models/stfpm/lightning_model.py +++ b/anomalib/models/stfpm/lightning_model.py @@ -27,6 +27,7 @@ from torch import optim from anomalib.models.components import AnomalyModule +from anomalib.models.stfpm.loss import STFPMLoss from anomalib.models.stfpm.torch_model import STFPMModel logger = logging.getLogger(__name__) @@ -59,6 +60,7 @@ def __init__( backbone=backbone, layers=layers, ) + self.loss = STFPMLoss() self.loss_val = 0 def training_step(self, batch, _): # pylint: disable=arguments-differ @@ -75,7 +77,7 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ """ self.model.teacher_model.eval() teacher_features, student_features = self.model.forward(batch["image"]) - loss = self.loss_val + self.model.loss(teacher_features, student_features) + loss = self.loss_val + self.loss(teacher_features, student_features) self.loss_val = 0 return {"loss": loss} diff --git a/anomalib/models/stfpm/loss.py b/anomalib/models/stfpm/loss.py new file mode 100644 index 0000000000..8f60ab2ec8 --- /dev/null +++ b/anomalib/models/stfpm/loss.py @@ -0,0 +1,74 @@ +"""Loss function for the STFPM Model Implementation.""" + +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, List + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + + +class STFPMLoss(nn.Module): + """Feature Pyramid Loss This class implmenents the feature pyramid loss function proposed in STFPM paper. + + Example: + >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor + >>> from anomalib.models.stfpm.loss import STFPMLoss + >>> from torchvision.models import resnet18 + + >>> layers = ['layer1', 'layer2', 'layer3'] + >>> teacher_model = FeatureExtractor(model=resnet18(pretrained=True), layers=layers) + >>> student_model = FeatureExtractor(model=resnet18(pretrained=False), layers=layers) + >>> loss = Loss() + + >>> inp = torch.rand((4, 3, 256, 256)) + >>> teacher_features = teacher_model(inp) + >>> student_features = student_model(inp) + >>> loss(student_features, teacher_features) + tensor(51.2015, grad_fn=) + """ + + def __init__(self): + super().__init__() + self.mse_loss = nn.MSELoss(reduction="sum") + + def compute_layer_loss(self, teacher_feats: Tensor, student_feats: Tensor) -> Tensor: + """Compute layer loss based on Equation (1) in Section 3.2 of the paper. + + Args: + teacher_feats (Tensor): Teacher features + student_feats (Tensor): Student features + + Returns: + L2 distance between teacher and student features. + """ + + height, width = teacher_feats.shape[2:] + + norm_teacher_features = F.normalize(teacher_feats) + norm_student_features = F.normalize(student_feats) + layer_loss = (0.5 / (width * height)) * self.mse_loss(norm_teacher_features, norm_student_features) + + return layer_loss + + def forward(self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]) -> Tensor: + """Compute the overall loss via the weighted average of the layer losses computed by the cosine similarity. 
+
+        Args:
+            teacher_features (Dict[str, Tensor]): Teacher features
+            student_features (Dict[str, Tensor]): Student features
+
+        Returns:
+            Total loss, which is the weighted average of the layer losses.
+        """
+
+        layer_losses: List[Tensor] = []
+        for layer in teacher_features.keys():
+            loss = self.compute_layer_loss(teacher_features[layer], student_features[layer])
+            layer_losses.append(loss)
+
+        total_loss = torch.stack(layer_losses).sum()
+
+        return total_loss
diff --git a/anomalib/models/stfpm/torch_model.py b/anomalib/models/stfpm/torch_model.py
index 9b3973c2a9..ca3485812f 100644
--- a/anomalib/models/stfpm/torch_model.py
+++ b/anomalib/models/stfpm/torch_model.py
@@ -16,8 +16,6 @@
 
 from typing import Dict, List, Optional, Tuple
 
-import torch
-import torch.nn.functional as F
 import torchvision
 from torch import Tensor, nn
 
@@ -26,70 +24,6 @@
 from anomalib.pre_processing import Tiler
 
 
-class Loss(nn.Module):
-    """Feature Pyramid Loss This class implmenents the feature pyramid loss function proposed in STFPM paper.
-
-    Example:
-        >>> from anomalib.models.components.feature_extractors.feature_extractor import FeatureExtractor
-        >>> from anomalib.models.stfpm.torch_model import Loss
-        >>> from torchvision.models import resnet18
-
-        >>> layers = ['layer1', 'layer2', 'layer3']
-        >>> teacher_model = FeatureExtractor(model=resnet18(pretrained=True), layers=layers)
-        >>> student_model = FeatureExtractor(model=resnet18(pretrained=False), layers=layers)
-        >>> loss = Loss()
-
-        >>> inp = torch.rand((4, 3, 256, 256))
-        >>> teacher_features = teacher_model(inp)
-        >>> student_features = student_model(inp)
-        >>> loss(student_features, teacher_features)
-        tensor(51.2015, grad_fn=<SumBackward0>)
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.mse_loss = nn.MSELoss(reduction="sum")
-
-    def compute_layer_loss(self, teacher_feats: Tensor, student_feats: Tensor) -> Tensor:
-        """Compute layer loss based on Equation (1) in Section 3.2 of the paper.
-
-        Args:
-            teacher_feats (Tensor): Teacher features
-            student_feats (Tensor): Student features
-
-        Returns:
-            L2 distance between teacher and student features.
-        """
-
-        height, width = teacher_feats.shape[2:]
-
-        norm_teacher_features = F.normalize(teacher_feats)
-        norm_student_features = F.normalize(student_feats)
-        layer_loss = (0.5 / (width * height)) * self.mse_loss(norm_teacher_features, norm_student_features)
-
-        return layer_loss
-
-    def forward(self, teacher_features: Dict[str, Tensor], student_features: Dict[str, Tensor]) -> Tensor:
-        """Compute the overall loss via the weighted average of the layer losses computed by the cosine similarity.
-
-        Args:
-            teacher_features (Dict[str, Tensor]): Teacher features
-            student_features (Dict[str, Tensor]): Student features
-
-        Returns:
-            Total loss, which is the weighted average of the layer losses.
-        """
-
-        layer_losses: List[Tensor] = []
-        for layer in teacher_features.keys():
-            loss = self.compute_layer_loss(teacher_features[layer], student_features[layer])
-            layer_losses.append(loss)
-
-        total_loss = torch.stack(layer_losses).sum()
-
-        return total_loss
-
-
 class STFPMModel(nn.Module):
     """STFPM: Student-Teacher Feature Pyramid Matching for Unsupervised Anomaly Detection.
 
@@ -116,8 +50,6 @@ def __init__(
         for parameters in self.teacher_model.parameters():
             parameters.requires_grad = False
 
-        self.loss = Loss()
-
         # Create the anomaly heatmap generator whether tiling is set.
         # TODO: Check whether Tiler is properly initialized here.
         if self.tiler:
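
Reviewer note: a minimal sketch of how the extracted FastflowLoss behaves in isolation, assuming the patch is applied; the tensor shapes below are illustrative placeholders, not outputs of a real backbone.

import torch

from anomalib.models.fastflow.loss import FastflowLoss

loss_fn = FastflowLoss()
# Two illustrative flow scales of (batch, channels, height, width) hidden variables.
hidden_variables = [torch.randn(4, 64, 32, 32), torch.randn(4, 128, 16, 16)]
# One log-Jacobian determinant per image at each scale (shapes assumed here).
jacobians = [torch.randn(4), torch.randn(4)]
loss = loss_fn(hidden_variables, jacobians)  # scalar: sum over scales of mean(0.5 * ||z||^2 - log|det J|)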
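
Likewise, a sketch of the two new GANomaly loss modules used standalone, matching the call sites in training_step; every shape here is a placeholder, and DiscriminatorLoss expects predictions already in (0, 1) because it wraps nn.BCELoss.

import torch

from anomalib.models.ganomaly.loss import DiscriminatorLoss, GeneratorLoss

generator_loss = GeneratorLoss(wadv=1, wcon=50, wenc=1)
discriminator_loss = DiscriminatorLoss()

images = torch.randn(4, 3, 256, 256)  # stand-in for the padded real batch
fake = torch.randn(4, 3, 256, 256)    # stand-in for generator reconstructions
latent_i = torch.randn(4, 100, 1, 1)  # encoder latents; latent_vec_size=100 is assumed
latent_o = torch.randn(4, 100, 1, 1)
pred_real = torch.rand(4, 1)          # discriminator outputs squashed into (0, 1)
pred_fake = torch.rand(4, 1)

g_loss = generator_loss(latent_i, latent_o, images, fake, pred_real, pred_fake)
d_loss = discriminator_loss(pred_real, pred_fake)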
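
Finally, beyond the doctest already in the new module, a hedged sketch of STFPMLoss on synthetic feature pyramids; the layer names and shapes are made up, only the Dict[str, Tensor] contract matters.

import torch

from anomalib.models.stfpm.loss import STFPMLoss

loss_fn = STFPMLoss()
# Fake teacher/student pyramids keyed by layer name (shapes are illustrative).
teacher = {"layer1": torch.randn(4, 64, 64, 64), "layer2": torch.randn(4, 128, 32, 32)}
student = {name: torch.randn_like(feat) for name, feat in teacher.items()}
total = loss_fn(teacher, student)  # sum of per-layer scaled L2 distances between normalized features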