Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor base classes #2164

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/anomalib/callbacks/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from anomalib import TaskType
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
from anomalib.models import AnomalyModule
from anomalib.models import AnomalibModule

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(
def setup(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
stage: str | None = None,
) -> None:
"""Set image and pixel-level AnomalibMetricsCollection within Anomalib Model.
Expand Down Expand Up @@ -87,7 +87,7 @@ def setup(
self.pixel_metric_names if not isinstance(self.pixel_metric_names, str) else [self.pixel_metric_names]
)

if isinstance(pl_module, AnomalyModule):
if isinstance(pl_module, AnomalibModule):
pl_module.image_metrics = create_metric_collection(image_metric_names, "image_")
if hasattr(pl_module, "pixel_metrics"): # incase metrics are loaded from model checkpoint
new_metrics = create_metric_collection(pixel_metric_names)
Expand All @@ -101,7 +101,7 @@ def setup(
def on_validation_epoch_start(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
) -> None:
del trainer # Unused argument.

Expand All @@ -111,7 +111,7 @@ def on_validation_epoch_start(
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -126,7 +126,7 @@ def on_validation_batch_end(
def on_validation_epoch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
) -> None:
del trainer # Unused argument.

Expand All @@ -136,7 +136,7 @@ def on_validation_epoch_end(
def on_test_epoch_start(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
) -> None:
del trainer # Unused argument.

Expand All @@ -146,7 +146,7 @@ def on_test_epoch_start(
def on_test_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -161,13 +161,13 @@ def on_test_batch_end(
def on_test_epoch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
) -> None:
del trainer # Unused argument.

self._log_metrics(pl_module)

def _set_threshold(self, pl_module: AnomalyModule) -> None:
def _set_threshold(self, pl_module: AnomalibModule) -> None:
pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())

Expand All @@ -192,7 +192,7 @@ def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any
return output

@staticmethod
def _log_metrics(pl_module: AnomalyModule) -> None:
def _log_metrics(pl_module: AnomalibModule) -> None:
"""Log computed performance metrics."""
if pl_module.pixel_metrics._update_called: # noqa: SLF001
pl_module.log_dict(pl_module.pixel_metrics, prog_bar=True)
Expand Down
4 changes: 2 additions & 2 deletions src/anomalib/callbacks/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from lightning.pytorch import Callback, Trainer

from anomalib.models.components import AnomalyModule
from anomalib.models.components import AnomalibModule

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +27,7 @@ class LoadModelCallback(Callback):
def __init__(self, weights_path: str) -> None:
self.weights_path = weights_path

def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
def setup(self, trainer: Trainer, pl_module: AnomalibModule, stage: str | None = None) -> None:
"""Call when inference begins.

Loads the model weights from ``weights_path`` into the PyTorch module.
Expand Down
4 changes: 2 additions & 2 deletions src/anomalib/callbacks/normalization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from lightning.pytorch import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib.models.components import AnomalyModule
from anomalib.models.components import AnomalibModule


class NormalizationCallback(Callback, ABC):
"""Base normalization callback."""

@staticmethod
@abstractmethod
def _normalize_batch(batch: STEP_OUTPUT, pl_module: AnomalyModule) -> None:
def _normalize_batch(batch: STEP_OUTPUT, pl_module: AnomalibModule) -> None:
"""Normalize an output batch.

Args:
Expand Down
14 changes: 7 additions & 7 deletions src/anomalib/callbacks/normalization/min_max_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib.metrics import MinMax
from anomalib.models.components import AnomalyModule
from anomalib.models.components import AnomalibModule
from anomalib.utils.normalization.min_max import normalize

from .base import NormalizationCallback
Expand All @@ -22,7 +22,7 @@ class _MinMaxNormalizationCallback(NormalizationCallback):
Note: This callback is set within the Engine.
"""

def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None = None) -> None:
def setup(self, trainer: Trainer, pl_module: AnomalibModule, stage: str | None = None) -> None:
"""Add min_max metrics to normalization metrics."""
del trainer, stage # These variables are not used.

Expand All @@ -34,7 +34,7 @@ def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str | None =
msg,
)

def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
def on_test_start(self, trainer: Trainer, pl_module: AnomalibModule) -> None:
"""Call when the test begins."""
del trainer # `trainer` variable is not used.

Expand All @@ -45,7 +45,7 @@ def on_test_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -67,7 +67,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -81,7 +81,7 @@ def on_test_batch_end(
def on_predict_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: Any, # noqa: ANN401
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -93,7 +93,7 @@ def on_predict_batch_end(
self._normalize_batch(outputs, pl_module)

@staticmethod
def _normalize_batch(outputs: Any, pl_module: AnomalyModule) -> None: # noqa: ANN401
def _normalize_batch(outputs: Any, pl_module: AnomalibModule) -> None: # noqa: ANN401
"""Normalize a batch of predictions."""
image_threshold = pl_module.image_threshold.value.cpu()
pixel_threshold = pl_module.pixel_threshold.value.cpu()
Expand Down
12 changes: 6 additions & 6 deletions src/anomalib/callbacks/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT

from anomalib.data.utils import boxes_to_anomaly_maps, boxes_to_masks, masks_to_boxes
from anomalib.models import AnomalyModule
from anomalib.models import AnomalibModule


class _PostProcessorCallback(Callback):
Expand All @@ -26,7 +26,7 @@ def __init__(self) -> None:
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -40,7 +40,7 @@ def on_validation_batch_end(
def on_test_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -54,7 +54,7 @@ def on_test_batch_end(
def on_predict_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: Any, # noqa: ANN401
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -65,15 +65,15 @@ def on_predict_batch_end(
if outputs is not None:
self.post_process(trainer, pl_module, outputs)

def post_process(self, trainer: Trainer, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
def post_process(self, trainer: Trainer, pl_module: AnomalibModule, outputs: STEP_OUTPUT) -> None:
if isinstance(outputs, dict):
self._post_process(outputs)
if trainer.predicting or trainer.testing:
self._compute_scores_and_labels(pl_module, outputs)

@staticmethod
def _compute_scores_and_labels(
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: dict[str, Any],
) -> None:
if "pred_scores" in outputs:
Expand Down
30 changes: 15 additions & 15 deletions src/anomalib/callbacks/thresholding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from lightning.pytorch.utilities.types import STEP_OUTPUT
from omegaconf import DictConfig, ListConfig

from anomalib.metrics.threshold import BaseThreshold
from anomalib.models import AnomalyModule
from anomalib.metrics.threshold import Threshold
from anomalib.models import AnomalibModule
from anomalib.utils.types import THRESHOLD


Expand All @@ -28,24 +28,24 @@ def __init__(
) -> None:
super().__init__()
self._initialize_thresholds(threshold)
self.image_threshold: BaseThreshold
self.pixel_threshold: BaseThreshold
self.image_threshold: Threshold
self.pixel_threshold: Threshold

def setup(self, trainer: Trainer, pl_module: AnomalyModule, stage: str) -> None:
def setup(self, trainer: Trainer, pl_module: AnomalibModule, stage: str) -> None:
del trainer, stage # Unused arguments.
if not hasattr(pl_module, "image_threshold"):
pl_module.image_threshold = self.image_threshold
if not hasattr(pl_module, "pixel_threshold"):
pl_module.pixel_threshold = self.pixel_threshold

def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
def on_validation_epoch_start(self, trainer: Trainer, pl_module: AnomalibModule) -> None:
del trainer # Unused argument.
self._reset(pl_module)

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: AnomalyModule,
pl_module: AnomalibModule,
outputs: STEP_OUTPUT | None,
batch: Any, # noqa: ANN401
batch_idx: int,
Expand All @@ -56,7 +56,7 @@ def on_validation_batch_end(
self._outputs_to_cpu(outputs)
self._update(pl_module, outputs)

def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalibModule) -> None:
del trainer # Unused argument.
self._compute(pl_module)

Expand Down Expand Up @@ -86,13 +86,13 @@ def _initialize_thresholds(
# When only a single threshold class is passed.
# This initializes image and pixel thresholds with the same class
# >>> _initialize_thresholds(F1AdaptiveThreshold())
if isinstance(threshold, BaseThreshold):
if isinstance(threshold, Threshold):
self.image_threshold = threshold
self.pixel_threshold = threshold.clone()

# When a tuple of threshold classes are passed
# >>> _initialize_thresholds((ManualThreshold(0.5), ManualThreshold(0.5)))
elif isinstance(threshold, tuple) and isinstance(threshold[0], BaseThreshold):
elif isinstance(threshold, tuple) and isinstance(threshold[0], Threshold):
self.image_threshold = threshold[0]
self.pixel_threshold = threshold[1]
# When the passed threshold is not an instance of a Threshold class.
Expand Down Expand Up @@ -133,7 +133,7 @@ def _load_from_config(self, threshold: DictConfig | str | ListConfig | list[dict
msg = f"Invalid threshold config {threshold}"
raise TypeError(msg)

def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str | float]) -> BaseThreshold:
def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str | float]) -> Threshold:
"""Return the instantiated threshold object.

Example:
Expand All @@ -151,7 +151,7 @@ def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str
>>> __get_threshold_from_config(config)

Returns:
(BaseThreshold): Instance of threshold object.
(Threshold): Instance of threshold object.
"""
if isinstance(threshold, str):
threshold = DictConfig({"class_path": threshold})
Expand All @@ -170,7 +170,7 @@ def _get_threshold_from_config(self, threshold: DictConfig | str | dict[str, str
class_ = getattr(module, class_path)
return class_(**init_args)

def _reset(self, pl_module: AnomalyModule) -> None:
def _reset(self, pl_module: AnomalibModule) -> None:
pl_module.image_threshold.reset()
pl_module.pixel_threshold.reset()

Expand All @@ -182,14 +182,14 @@ def _outputs_to_cpu(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
output = output.cpu()
return output

def _update(self, pl_module: AnomalyModule, outputs: STEP_OUTPUT) -> None:
def _update(self, pl_module: AnomalibModule, outputs: STEP_OUTPUT) -> None:
pl_module.image_threshold.cpu()
pl_module.image_threshold.update(outputs["pred_scores"], outputs["label"].int())
if "mask" in outputs and "anomaly_maps" in outputs:
pl_module.pixel_threshold.cpu()
pl_module.pixel_threshold.update(outputs["anomaly_maps"], outputs["mask"].int())

def _compute(self, pl_module: AnomalyModule) -> None:
def _compute(self, pl_module: AnomalibModule) -> None:
pl_module.image_threshold.compute()
if pl_module.pixel_threshold._update_called: # noqa: SLF001
pl_module.pixel_threshold.compute()
Expand Down
4 changes: 2 additions & 2 deletions src/anomalib/callbacks/tiler_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from lightning.pytorch.callbacks import Callback

from anomalib.data.utils.tiler import ImageUpscaleMode, Tiler
from anomalib.models.components import AnomalyModule
from anomalib.models.components import AnomalibModule

__all__ = ["TilerConfigurationCallback"]

Expand Down Expand Up @@ -61,7 +61,7 @@
del trainer, stage # These variables are not used.

if self.enable:
if isinstance(pl_module, AnomalyModule) and hasattr(pl_module.model, "tiler"):
if isinstance(pl_module, AnomalibModule) and hasattr(pl_module.model, "tiler"):

Check warning on line 64 in src/anomalib/callbacks/tiler_configuration.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/callbacks/tiler_configuration.py#L64

Added line #L64 was not covered by tests
pl_module.model.tiler = Tiler(
tile_size=self.tile_size,
stride=self.stride,
Expand Down
Loading
Loading