🐞 Fix image loggers #233

Merged 2 commits on Apr 19, 2022
anomalib/utils/callbacks/visualizer_callback.py (16 additions & 11 deletions)
```diff
@@ -15,7 +15,7 @@
 # and limitations under the License.
 
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, cast
 from warnings import warn
 
 import pytorch_lightning as pl
```
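The only change in this hunk is the extra `cast` import, which backs the type narrowing used in `_add_images` below. As a reminder (a generic sketch, not code from this PR), `typing.cast` does nothing at runtime; it only changes the type the checker assigns to a value:

```python
from typing import cast

class Base: ...

class ImageCapable(Base):
    def add_image(self) -> None:
        print("image logged")

obj: Base = ImageCapable()
# No runtime check or conversion happens here; cast() only tells mypy to
# treat obj as ImageCapable so the add_image() call type-checks.
cast(ImageCapable, obj).add_image()
```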
```diff
@@ -50,6 +50,7 @@ def _add_images(
         self,
         visualizer: Visualizer,
         module: AnomalyModule,
+        trainer: pl.Trainer,
         filename: Path,
     ):
         """Save image to logger/local storage.
```
```diff
@@ -59,19 +60,21 @@ def _add_images(
 
         Args:
             visualizer (Visualizer): Visualizer object from which the `figure` is saved/logged.
-            module (AnomalyModule): Anomaly module which holds reference to `hparams` and `logger`.
+            module (AnomalyModule): Anomaly module which holds reference to `hparams`.
+            trainer (Trainer): Pytorch Lightning trainer which holds reference to `logger`.
             filename (Path): Path of the input image. This name is used as name for the generated image.
         """
-
-        # store current logger type as a string
-        logger_type = type(module.logger).__name__.lower()
-
+        # Store the names of the available loggers and the logger objects in a dict
+        available_loggers = {
+            type(logger).__name__.lower().rstrip("logger").lstrip("anomalib"): logger for logger in trainer.loggers
+        }
         # save image to respective logger
         for log_to in module.hparams.project.log_images_to:
             if log_to in loggers.AVAILABLE_LOGGERS:
                 # check if logger object is same as the requested object
-                if log_to in logger_type and module.logger is not None and isinstance(module.logger, ImageLoggerBase):
-                    module.logger.add_image(
+                if log_to in available_loggers and isinstance(available_loggers[log_to], ImageLoggerBase):
+                    logger: ImageLoggerBase = cast(ImageLoggerBase, available_loggers[log_to])  # placate mypy
+                    logger.add_image(
                         image=visualizer.figure,
                         name=filename.parent.name + "_" + filename.name,
                         global_step=module.global_step,
```
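The new dict comprehension keys each of the trainer's loggers by a normalized class name, so entries in `log_images_to` such as `wandb` or `tensorboard` can be looked up directly. A small illustration of what the normalization yields (the class names mirror anomalib's loggers, but the snippet itself is illustrative, not from the PR):

```python
# Illustrative stand-ins for anomalib's logger classes.
class AnomalibWandbLogger: ...
class AnomalibTensorBoardLogger: ...

available_loggers = {
    type(logger).__name__.lower().rstrip("logger").lstrip("anomalib"): logger
    for logger in [AnomalibWandbLogger(), AnomalibTensorBoardLogger()]
}
print(sorted(available_loggers))  # ['tensorboard', 'wandb']
```

One caveat worth knowing: `rstrip`/`lstrip` remove characters from a set rather than a literal suffix/prefix, so the normalization is correct for these class names but could over-strip others (a hypothetical `BaseLogger`, for example, would normalize to `"s"`, not `"base"`).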
```diff
@@ -81,13 +84,15 @@ def _add_images(
                         f"Requested {log_to} logging but logger object is of type: {type(module.logger)}."
                         f" Skipping logging to {log_to}"
                     )
+            else:
+                warn(f"{log_to} not in the list of supported image loggers.")
 
         if "local" in module.hparams.project.log_images_to:
             visualizer.save(Path(module.hparams.project.path) / "images" / filename.parent.name / filename.name)
 
     def on_test_batch_end(
         self,
-        _trainer: pl.Trainer,
+        trainer: pl.Trainer,
         pl_module: AnomalyModule,
         outputs: Optional[STEP_OUTPUT],
         _batch: Any,
```
```diff
@@ -97,7 +102,7 @@ def on_test_batch_end(
         """Log images at the end of every batch.
 
         Args:
-            _trainer (Trainer): Pytorch lightning trainer object (unused).
+            trainer (Trainer): Pytorch lightning trainer object.
             pl_module (LightningModule): Lightning modules derived from BaseAnomalyLightning object as
                 currently only they support logging images.
             outputs (Dict[str, Any]): Outputs of the current test step.
```
```diff
@@ -147,7 +152,7 @@ def on_test_batch_end(
                 )
                 visualizer.add_image(image=image_classified, title="Classified Image")
 
-            self._add_images(visualizer, pl_module, Path(filename))
+            self._add_images(visualizer, pl_module, trainer, Path(filename))
             visualizer.close()
 
     def on_test_end(self, _trainer: pl.Trainer, pl_module: AnomalyModule) -> None:
```
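Taken together, the fix stops relying on the single `module.logger` reference and instead walks `trainer.loggers`, which PyTorch Lightning's multi-logger support populates with every configured logger. The dispatch pattern, condensed into a self-contained sketch with stand-in classes (hypothetical names, not anomalib's real ones):

```python
from typing import List, cast
from warnings import warn

class ImageLoggerBase:
    """Stand-in for anomalib's image-capable logger base class."""
    def add_image(self, name: str) -> None:
        print(f"{type(self).__name__} logged {name}")

class AnomalibWandbLogger(ImageLoggerBase): ...
class CSVLogger: ...  # a logger with no image support

def add_images(requested: List[str], trainer_loggers: list) -> None:
    # Key each logger by its normalized class name, mirroring _add_images.
    available = {
        type(lg).__name__.lower().rstrip("logger").lstrip("anomalib"): lg
        for lg in trainer_loggers
    }
    for log_to in requested:
        if log_to in available and isinstance(available[log_to], ImageLoggerBase):
            # cast() placates the type checker; isinstance() did the real check.
            cast(ImageLoggerBase, available[log_to]).add_image(name="heatmap.png")
        else:
            warn(f"Requested {log_to} logging but no image-capable {log_to} logger is attached.")

add_images(["wandb", "csv"], [AnomalibWandbLogger(), CSVLogger()])
# AnomalibWandbLogger logs the image; a warning is emitted for "csv".
```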
anomalib/utils/loggers/wandb.py (0 additions & 2 deletions)
```diff
@@ -89,7 +89,6 @@ def __init__(
         log_model: Union[str, bool] = False,
         experiment=None,
         prefix: Optional[str] = "",
-        sync_step: Optional[bool] = None,
         **kwargs
     ) -> None:
         super().__init__(
```
```diff
@@ -103,7 +102,6 @@ def __init__(
             log_model=log_model,
             experiment=experiment,
             prefix=prefix,
-            sync_step=sync_step,
             **kwargs
         )
         self.image_list: List[wandb.Image] = []  # Cache images
```
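These deletions track upstream PyTorch Lightning, whose `WandbLogger` dropped the `sync_step` argument in recent releases, so forwarding it from the anomalib wrapper failed on newer PL versions. A hedged sketch of the resulting usage (the import path and project name are assumptions for illustration):

```python
from anomalib.utils.loggers import AnomalibWandbLogger

# After this PR the wrapper is constructed without sync_step.
logger = AnomalibWandbLogger(
    project="anomalib",  # placeholder project name
    log_model=False,
)
# Passing sync_step=... here would now be swallowed by **kwargs and
# forwarded to code that no longer understands it.
```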