🚜 Refactor loss computation #364

Merged · 6 commits · Jun 13, 2022
Changes from 5 commits
1 change: 0 additions & 1 deletion anomalib/models/cflow/lightning_model.py
@@ -67,7 +67,6 @@ def __init__(
clamp_alpha=clamp_alpha,
permute_soft=permute_soft,
)
self.loss_val = 0
self.automatic_optimization = False

def training_step(self, batch, _): # pylint: disable=arguments-differ
3 changes: 2 additions & 1 deletion anomalib/models/fastflow/__init__.py
@@ -5,6 +5,7 @@
#

from .lightning_model import Fastflow, FastflowLightning
from .torch_model import FastflowLoss, FastflowModel
from .loss import FastflowLoss
from .torch_model import FastflowModel

__all__ = ["FastflowModel", "FastflowLoss", "FastflowLightning", "Fastflow"]
7 changes: 4 additions & 3 deletions anomalib/models/fastflow/lightning_model.py
@@ -14,7 +14,8 @@
from torch import optim

from anomalib.models.components import AnomalyModule
from anomalib.models.fastflow.torch_model import FastflowLoss, FastflowModel
from anomalib.models.fastflow.loss import FastflowLoss
from anomalib.models.fastflow.torch_model import FastflowModel

logger = logging.getLogger(__name__)

@@ -49,7 +50,7 @@ def __init__(
conv3x3_only=conv3x3_only,
hidden_ratio=hidden_ratio,
)
self.loss_func = FastflowLoss()
self.loss = FastflowLoss()

def training_step(self, batch, _): # pylint: disable=arguments-differ
"""Forward-pass input and return the loss.
@@ -62,7 +63,7 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ
STEP_OUTPUT: Dictionary containing the loss value.
"""
hidden_variables, jacobians = self.model(batch["image"])
loss = self.loss_func(hidden_variables, jacobians)
loss = self.loss(hidden_variables, jacobians)
return {"loss": loss}

def validation_step(self, batch, _): # pylint: disable=arguments-differ
28 changes: 28 additions & 0 deletions anomalib/models/fastflow/loss.py
@@ -0,0 +1,28 @@
"""Loss function for the FastFlow Model Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from typing import List

import torch
from torch import Tensor, nn


class FastflowLoss(nn.Module):
"""FastFlow Loss."""

def forward(self, hidden_variables: List[Tensor], jacobians: List[Tensor]) -> Tensor:
"""Calculate the Fastflow loss.

Args:
hidden_variables (List[Tensor]): Hidden variables from the fastflow model. f: X -> Z
jacobians (List[Tensor]): Log of the jacobian determinants from the fastflow model.

Returns:
Tensor: _description_
[Review comment, Contributor]: This is a leftover from the previous PR, but would be good to add a description here in this PR.
"""
loss = torch.tensor(0.0, device=hidden_variables[0].device) # pylint: disable=not-callable
for (hidden_variable, jacobian) in zip(hidden_variables, jacobians):
loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian)
return loss
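
The extracted FastflowLoss computes the negative log-likelihood of the hidden variables under a standard normal prior, summed across flow scales. A minimal usage sketch (tensor shapes below are illustrative assumptions, not fixed by the model):

import torch

from anomalib.models.fastflow.loss import FastflowLoss

# Two flow scales for a batch of 4; shapes are hypothetical.
hidden_variables = [torch.randn(4, 64, 32, 32), torch.randn(4, 128, 16, 16)]
log_jacobians = [torch.randn(4), torch.randn(4)]  # log|det J| per batch element

criterion = FastflowLoss()
loss = criterion(hidden_variables, log_jacobians)
# Per scale: batch mean of 0.5 * sum(z**2, dim=(1, 2, 3)) - log|det J|.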
19 changes: 0 additions & 19 deletions anomalib/models/fastflow/torch_model.py
@@ -86,25 +86,6 @@ def create_fast_flow_block(
return nodes


class FastflowLoss(nn.Module):
"""FastFlow Loss."""

def forward(self, hidden_variables: List[Tensor], jacobians: List[Tensor]) -> Tensor:
"""Calculate the Fastflow loss.

Args:
hidden_variables (List[Tensor]): Hidden variables from the fastflow model. f: X -> Z
jacobians (List[Tensor]): Log of the jacobian determinants from the fastflow model.

Returns:
Tensor: _description_
"""
loss = torch.tensor(0.0, device=hidden_variables[0].device)
for (hidden_variable, jacobian) in zip(hidden_variables, jacobians):
loss += torch.mean(0.5 * torch.sum(hidden_variable**2, dim=(1, 2, 3)) - jacobian)
return loss


class FastflowModel(nn.Module):
"""FastFlow.

2 changes: 1 addition & 1 deletion anomalib/models/ganomaly/config.yaml
@@ -51,7 +51,7 @@ metrics:
adaptive: true

project:
seed: 0
seed: 42
path: ./results

logging:
38 changes: 16 additions & 22 deletions anomalib/models/ganomaly/lightning_model.py
@@ -26,8 +26,8 @@
from pytorch_lightning.utilities.cli import MODEL_REGISTRY
from torch import Tensor, optim

from anomalib.data.utils.image import pad_nextpow2
from anomalib.models.components import AnomalyModule
from anomalib.models.ganomaly.loss import DiscriminatorLoss, GeneratorLoss

from .torch_model import GanomalyModel

@@ -73,9 +73,6 @@ def __init__(
latent_vec_size=latent_vec_size,
extra_layers=extra_layers,
add_final_conv_layer=add_final_conv_layer,
wadv=wadv,
wcon=wcon,
wenc=wenc,
)

self.real_label = torch.ones(size=(batch_size,), dtype=torch.float32)
@@ -84,6 +81,9 @@
self.min_scores: Tensor = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
self.max_scores: Tensor = torch.tensor(float("-inf"), dtype=torch.float32) # pylint: disable=not-callable

self.generator_loss = GeneratorLoss(wadv, wcon, wenc)
self.discriminator_loss = DiscriminatorLoss()

def _reset_min_max(self):
"""Resets min_max scores."""
self.min_scores = torch.tensor(float("inf"), dtype=torch.float32) # pylint: disable=not-callable
@@ -99,24 +99,18 @@ def training_step(self, batch, _, optimizer_idx): # pylint: disable=arguments-differ
Returns:
Dict[str, Tensor]: Loss
"""
images = batch["image"]
padded_images = pad_nextpow2(images)
loss: Dict[str, Tensor]

# Discriminator
if optimizer_idx == 0:
# forward pass
loss_discriminator = self.model.get_discriminator_loss(padded_images)
loss = {"loss": loss_discriminator}

# Generator
else:
# forward pass
loss_generator = self.model.get_generator_loss(padded_images)

loss = {"loss": loss_generator}

return loss
# forward pass
padded, fake, latent_i, latent_o = self.model(batch["image"])
pred_real, _ = self.model.discriminator(padded)

if optimizer_idx == 0: # Discriminator
pred_fake, _ = self.model.discriminator(fake.detach())
loss = self.discriminator_loss(pred_real, pred_fake)
else: # Generator
pred_fake, _ = self.model.discriminator(fake)
loss = self.generator_loss(latent_i, latent_o, padded, fake, pred_real, pred_fake)

return {"loss": loss}

def on_validation_start(self) -> None:
"""Reset min and max values for current validation epoch."""
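
With the losses extracted, training_step performs a single forward pass and routes the outputs to DiscriminatorLoss or GeneratorLoss based on optimizer_idx. A schematic of the dispatch Lightning performs when two optimizers are configured (an illustrative pseudo-driver, not Lightning's actual internals):

def run_gan_step(module, batch, batch_idx, optimizers):
    # Lightning calls training_step once per optimizer per batch:
    # index 0 updates the discriminator, index 1 the generator.
    for optimizer_idx, optimizer in enumerate(optimizers):
        output = module.training_step(batch, batch_idx, optimizer_idx)
        optimizer.zero_grad()
        output["loss"].backward()
        optimizer.step()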
79 changes: 79 additions & 0 deletions anomalib/models/ganomaly/loss.py
@@ -0,0 +1,79 @@
"""Loss function for the GANomaly Model Implementation."""

# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import torch
from torch import Tensor, nn


class GeneratorLoss(nn.Module):
"""Generator loss for the GANomaly model.

Args:
wadv (int, optional): Weight for adversarial loss. Defaults to 1.
wcon (int, optional): Image regeneration weight. Defaults to 50.
wenc (int, optional): Latent vector encoder weight. Defaults to 1.
"""

def __init__(self, wadv=1, wcon=50, wenc=1):
super().__init__()

self.loss_enc = nn.SmoothL1Loss()
self.loss_adv = nn.MSELoss()
self.loss_con = nn.L1Loss()

self.wadv = wadv
self.wcon = wcon
self.wenc = wenc

def forward(
self, latent_i: Tensor, latent_o: Tensor, images: Tensor, fake: Tensor, pred_real: Tensor, pred_fake: Tensor
) -> Tensor:
"""Compute the loss for a batch.

Args:
latent_i (Tensor): Latent features of the first encoder.
latent_o (Tensor): Latent features of the second encoder.
images (Tensor): Real image that served as input of the generator.
fake (Tensor): Generated image.
pred_real (Tensor): Discriminator predictions for the real image.
pred_fake (Tensor): Discriminator predictions for the fake image.

Returns:
Tensor: The computed generator loss.
"""
error_enc = self.loss_enc(latent_i, latent_o)
error_con = self.loss_con(images, fake)
error_adv = self.loss_adv(pred_real, pred_fake)

loss = error_adv * self.wadv + error_con * self.wcon + error_enc * self.wenc
return loss


class DiscriminatorLoss(nn.Module):
"""Discriminator loss for the GANomaly model."""

def __init__(self):
super().__init__()

self.loss_bce = nn.BCELoss()

def forward(self, pred_real, pred_fake):
"""Compye the loss for a predicted batch.

Args:
pred_real (Tensor): Discriminator predictions for the real image.
pred_fake (Tensor): Discriminator predictions for the fake image.

Returns:
Tensor: The computed discriminator loss.
"""
error_discriminator_real = self.loss_bce(
pred_real, torch.ones(size=pred_real.shape, dtype=torch.float32, device=pred_real.device)
)
error_discriminator_fake = self.loss_bce(
pred_fake, torch.zeros(size=pred_fake.shape, dtype=torch.float32, device=pred_fake.device)
)
loss_discriminator = (error_discriminator_fake + error_discriminator_real) * 0.5
return loss_discriminator
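
Both new loss modules are plain nn.Module subclasses, so they can be exercised in isolation. A minimal sketch with dummy tensors (all shapes are assumptions; the discriminator outputs are kept in (0, 1) because nn.BCELoss requires probabilities):

import torch

from anomalib.models.ganomaly.loss import DiscriminatorLoss, GeneratorLoss

batch_size = 4
images = torch.rand(batch_size, 3, 64, 64)     # real inputs (hypothetical shape)
fake = torch.rand(batch_size, 3, 64, 64)       # generator reconstruction
latent_i = torch.randn(batch_size, 100, 1, 1)  # first-encoder features
latent_o = torch.randn(batch_size, 100, 1, 1)  # second-encoder features
pred_real = torch.rand(batch_size, 1)          # probabilities for real images
pred_fake = torch.rand(batch_size, 1)          # probabilities for fake images

generator_loss = GeneratorLoss(wadv=1, wcon=50, wenc=1)
discriminator_loss = DiscriminatorLoss()

g_loss = generator_loss(latent_i, latent_o, images, fake, pred_real, pred_fake)
d_loss = discriminator_loss(pred_real, pred_fake)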
66 changes: 5 additions & 61 deletions anomalib/models/ganomaly/torch_model.py
@@ -12,7 +12,7 @@


import math
from typing import Tuple
from typing import Tuple, Union

import torch
from torch import Tensor, nn
@@ -286,9 +286,6 @@ class GanomalyModel(nn.Module):
latent_vec_size (int): Size of autoencoder latent vector.
extra_layers (int, optional): Number of extra layers for encoder/decoder. Defaults to 0.
add_final_conv_layer (bool, optional): Add convolution layer at the end. Defaults to True.
wadv (int, optional): Weight for adversarial loss. Defaults to 1.
wcon (int, optional): Image regeneration weight. Defaults to 50.
wenc (int, optional): Latent vector encoder weight. Defaults to 1.
"""

def __init__(
@@ -299,9 +296,6 @@ def __init__(
latent_vec_size: int,
extra_layers: int = 0,
add_final_conv_layer: bool = True,
wadv: int = 1,
wcon: int = 50,
wenc: int = 1,
) -> None:
super().__init__()
self.generator: Generator = Generator(
@@ -320,13 +314,6 @@
)
self.weights_init(self.generator)
self.weights_init(self.discriminator)
self.loss_enc = nn.SmoothL1Loss()
self.loss_adv = nn.MSELoss()
self.loss_con = nn.L1Loss()
self.loss_bce = nn.BCELoss()
self.wadv = wadv
self.wcon = wcon
self.wenc = wenc

@staticmethod
def weights_init(module: nn.Module):
@@ -342,51 +329,7 @@ def weights_init(module: nn.Module):
nn.init.normal_(module.weight.data, 1.0, 0.02)
nn.init.constant_(module.bias.data, 0)

def get_discriminator_loss(self, images: Tensor) -> Tensor:
"""Calculates loss for discriminator.

Args:
images (Tensor): Input images.

Returns:
Tensor: Discriminator loss.
"""
fake, _, _ = self.generator(images)
pred_real, _ = self.discriminator(images)
pred_fake, _ = self.discriminator(fake.detach())

error_discriminator_real = self.loss_bce(
pred_real, torch.ones(size=pred_real.shape, dtype=torch.float32, device=pred_real.device)
)
error_discriminator_fake = self.loss_bce(
pred_fake, torch.zeros(size=pred_fake.shape, dtype=torch.float32, device=pred_fake.device)
)
loss_discriminator = (error_discriminator_fake + error_discriminator_real) * 0.5
return loss_discriminator

def get_generator_loss(self, images: Tensor) -> Tensor:
"""Calculates loss for generator.

Args:
images (Tensor): Input images.

Returns:
Tensor: Generator loss.
"""
fake, latent_i, latent_o = self.generator(images)
pred_real, _ = self.discriminator(images)
pred_fake, _ = self.discriminator(fake)

error_enc = self.loss_enc(latent_i, latent_o)

error_con = self.loss_con(images, fake)

error_adv = self.loss_adv(pred_real, pred_fake)

loss_generator = error_adv * self.wadv + error_con * self.wcon + error_enc * self.wenc
return loss_generator

def forward(self, batch: Tensor) -> Tensor:
def forward(self, batch: Tensor) -> Union[Tuple[Tensor, Tensor, Tensor, Tensor], Tensor]:
"""Get scores for batch.

Args:
@@ -396,6 +339,7 @@ def forward(self, batch: Tensor) -> Tensor:
Tensor: Regeneration scores.
"""
padded_batch = pad_nextpow2(batch)
self.generator.eval()
_, latent_i, latent_o = self.generator(padded_batch)
fake, latent_i, latent_o = self.generator(padded_batch)
if self.training:
return padded_batch, fake, latent_i, latent_o
return torch.mean(torch.pow((latent_i - latent_o), 2), dim=1).view(-1) # convert nx1x1 to n
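
forward now has two contracts: in training mode it returns the tensors the new loss modules consume; in eval mode it returns per-image anomaly scores. A sketch of both paths (the constructor arguments below are illustrative assumptions, not the project defaults):

import torch

from anomalib.models.ganomaly.torch_model import GanomalyModel

model = GanomalyModel(  # argument values are hypothetical
    input_size=(64, 64),
    num_input_channels=3,
    n_features=64,
    latent_vec_size=100,
)

batch = torch.rand(8, 3, 64, 64)

model.train()
padded, fake, latent_i, latent_o = model(batch)  # inputs for the loss modules

model.eval()
scores = model(batch)  # shape (8,): mean squared distance between the latents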
4 changes: 3 additions & 1 deletion anomalib/models/stfpm/lightning_model.py
@@ -27,6 +27,7 @@
from torch import optim

from anomalib.models.components import AnomalyModule
from anomalib.models.stfpm.loss import STFPMLoss
from anomalib.models.stfpm.torch_model import STFPMModel

logger = logging.getLogger(__name__)
@@ -59,6 +60,7 @@ def __init__(
backbone=backbone,
layers=layers,
)
self.loss = STFPMLoss()
self.loss_val = 0

def training_step(self, batch, _): # pylint: disable=arguments-differ
@@ -75,7 +77,7 @@ def training_step(self, batch, _): # pylint: disable=arguments-differ
"""
self.model.teacher_model.eval()
teacher_features, student_features = self.model.forward(batch["image"])
loss = self.loss_val + self.model.loss(teacher_features, student_features)
loss = self.loss_val + self.loss(teacher_features, student_features)
self.loss_val = 0
return {"loss": loss}

Expand Down