From 3805677bcb28b40e9b51cc46a15979be8ad58c1e Mon Sep 17 00:00:00 2001 From: Samet Akcay Date: Mon, 13 Jun 2022 03:37:37 -0700 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=9A=20Move=20logging=20to=20the=20gett?= =?UTF-8?q?er=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- anomalib/data/__init__.py | 5 +++++ anomalib/models/__init__.py | 5 +++++ anomalib/utils/callbacks/__init__.py | 6 ++++++ anomalib/utils/loggers/__init__.py | 12 ++++++++---- tools/train.py | 13 +++---------- 5 files changed, 27 insertions(+), 14 deletions(-) diff --git a/anomalib/data/__init__.py b/anomalib/data/__init__.py index 9dbba592b8..1e1bcc2235 100644 --- a/anomalib/data/__init__.py +++ b/anomalib/data/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. +import logging from typing import Union from omegaconf import DictConfig, ListConfig @@ -24,6 +25,8 @@ from .inference import InferenceDataset from .mvtec import MVTec +logger = logging.getLogger(__name__) + def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule: """Get Anomaly Datamodule. @@ -34,6 +37,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule Returns: PyTorch Lightning DataModule """ + logger.info("Loading the datamodule") + datamodule: LightningDataModule if config.dataset.format.lower() == "mvtec": diff --git a/anomalib/models/__init__.py b/anomalib/models/__init__.py index 47720f586f..69954bb178 100644 --- a/anomalib/models/__init__.py +++ b/anomalib/models/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. +import logging import os from importlib import import_module from typing import List, Union @@ -23,6 +24,8 @@ from anomalib.models.components import AnomalyModule +logger = logging.getLogger(__name__) + def _snake_to_pascal_case(model_name: str) -> str: """Convert model name from snake case to Pascal case. @@ -54,6 +57,8 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule: Returns: AnomalyModule: Anomaly Model """ + logger.info("Loading the model.") + model_list: List[str] = [ "cflow", "dfkde", diff --git a/anomalib/utils/callbacks/__init__.py b/anomalib/utils/callbacks/__init__.py index 4917f1d3f3..868a4f5018 100644 --- a/anomalib/utils/callbacks/__init__.py +++ b/anomalib/utils/callbacks/__init__.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions # and limitations under the License. +import logging import os import warnings from importlib import import_module @@ -41,6 +42,9 @@ ] +logger = logging.getLogger(__name__) + + def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]: """Return base callbacks for all the lightning models. @@ -50,6 +54,8 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]: Return: (List[Callback]): List of callbacks. """ + logger.info("Loading the callbacks") + callbacks: List[Callback] = [] monitor_metric = None if "early_stopping" not in config.model.keys() else config.model.early_stopping.metric diff --git a/anomalib/utils/loggers/__init__.py b/anomalib/utils/loggers/__init__.py index 5cabc70ef2..ad03e26de5 100644 --- a/anomalib/utils/loggers/__init__.py +++ b/anomalib/utils/loggers/__init__.py @@ -35,6 +35,9 @@ AVAILABLE_LOGGERS = ["tensorboard", "wandb", "csv"] +logger = logging.getLogger(__name__) + + class UnknownLogger(Exception): """This is raised when the logger option in `config.yaml` file is set incorrectly.""" @@ -75,6 +78,7 @@ def get_experiment_logger( Returns: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]: Logger """ + logger.info("Loading the experiment logger(s)") # TODO remove when logger is deprecated from project if "logger" in config.project.keys(): @@ -95,8 +99,8 @@ def get_experiment_logger( if isinstance(config.logging.logger, str): config.logging.logger = [config.logging.logger] - for logger in config.logging.logger: - if logger == "tensorboard": + for experiment_logger in config.logging.logger: + if experiment_logger == "tensorboard": logger_list.append( AnomalibTensorBoardLogger( name="Tensorboard Logs", @@ -104,7 +108,7 @@ def get_experiment_logger( log_graph=config.logging.log_graph, ) ) - elif logger == "wandb": + elif experiment_logger == "wandb": wandb_logdir = os.path.join(config.project.path, "logs") os.makedirs(wandb_logdir, exist_ok=True) name = ( @@ -119,7 +123,7 @@ def get_experiment_logger( save_dir=wandb_logdir, ) ) - elif logger == "csv": + elif experiment_logger == "csv": logger_list.append(CSVLogger(save_dir=os.path.join(config.project.path, "logs"))) else: raise UnknownLogger( diff --git a/tools/train.py b/tools/train.py index bbe859de90..8f6f8b865e 100644 --- a/tools/train.py +++ b/tools/train.py @@ -64,23 +64,16 @@ def train(): args = get_args() configure_logger(level=args.log_level) + if args.log_level == "ERROR": + warnings.filterwarnings("ignore") + config = get_configurable_parameters(model_name=args.model, config_path=args.config) if config.project.seed != 0: seed_everything(config.project.seed) - if args.log_level == "ERROR": - warnings.filterwarnings("ignore") - - logger.info("Loading the datamodule") datamodule = get_datamodule(config) - - logger.info("Loading the model.") model = get_model(config) - - logger.info("Loading the experiment logger(s)") experiment_logger = get_experiment_logger(config) - - logger.info("Loading the callbacks") callbacks = get_callbacks(config) trainer = Trainer(**config.trainer, logger=experiment_logger, callbacks=callbacks)