Skip to content

Commit

Permalink
🚚 Move logging to the getter functions (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
samet-akcay committed Jun 13, 2022
1 parent 028acbf commit 141eb95
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 14 deletions.
5 changes: 5 additions & 0 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
from typing import Union

from omegaconf import DictConfig, ListConfig
Expand All @@ -24,6 +25,8 @@
from .inference import InferenceDataset
from .mvtec import MVTec

logger = logging.getLogger(__name__)


def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule:
"""Get Anomaly Datamodule.
Expand All @@ -34,6 +37,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
Returns:
PyTorch Lightning DataModule
"""
logger.info("Loading the datamodule")

datamodule: LightningDataModule

if config.dataset.format.lower() == "mvtec":
Expand Down
5 changes: 5 additions & 0 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
import os
from importlib import import_module
from typing import List, Union
Expand All @@ -23,6 +24,8 @@

from anomalib.models.components import AnomalyModule

logger = logging.getLogger(__name__)


def _snake_to_pascal_case(model_name: str) -> str:
"""Convert model name from snake case to Pascal case.
Expand Down Expand Up @@ -54,6 +57,8 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly Model
"""
logger.info("Loading the model.")

model_list: List[str] = [
"cflow",
"dfkde",
Expand Down
6 changes: 6 additions & 0 deletions anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
import os
import warnings
from importlib import import_module
Expand Down Expand Up @@ -41,6 +42,9 @@
]


logger = logging.getLogger(__name__)


def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
"""Return base callbacks for all the lightning models.
Expand All @@ -50,6 +54,8 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
Return:
(List[Callback]): List of callbacks.
"""
logger.info("Loading the callbacks")

callbacks: List[Callback] = []

monitor_metric = None if "early_stopping" not in config.model.keys() else config.model.early_stopping.metric
Expand Down
12 changes: 8 additions & 4 deletions anomalib/utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
AVAILABLE_LOGGERS = ["tensorboard", "wandb", "csv"]


logger = logging.getLogger(__name__)


class UnknownLogger(Exception):
"""This is raised when the logger option in `config.yaml` file is set incorrectly."""

Expand Down Expand Up @@ -75,6 +78,7 @@ def get_experiment_logger(
Returns:
Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]: Logger
"""
logger.info("Loading the experiment logger(s)")

# TODO remove when logger is deprecated from project
if "logger" in config.project.keys():
Expand All @@ -95,16 +99,16 @@ def get_experiment_logger(
if isinstance(config.logging.logger, str):
config.logging.logger = [config.logging.logger]

for logger in config.logging.logger:
if logger == "tensorboard":
for experiment_logger in config.logging.logger:
if experiment_logger == "tensorboard":
logger_list.append(
AnomalibTensorBoardLogger(
name="Tensorboard Logs",
save_dir=os.path.join(config.project.path, "logs"),
log_graph=config.logging.log_graph,
)
)
elif logger == "wandb":
elif experiment_logger == "wandb":
wandb_logdir = os.path.join(config.project.path, "logs")
os.makedirs(wandb_logdir, exist_ok=True)
name = (
Expand All @@ -119,7 +123,7 @@ def get_experiment_logger(
save_dir=wandb_logdir,
)
)
elif logger == "csv":
elif experiment_logger == "csv":
logger_list.append(CSVLogger(save_dir=os.path.join(config.project.path, "logs")))
else:
raise UnknownLogger(
Expand Down
13 changes: 3 additions & 10 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,16 @@ def train():
args = get_args()
configure_logger(level=args.log_level)

if args.log_level == "ERROR":
warnings.filterwarnings("ignore")

config = get_configurable_parameters(model_name=args.model, config_path=args.config)
if config.project.seed != 0:
seed_everything(config.project.seed)

if args.log_level == "ERROR":
warnings.filterwarnings("ignore")

logger.info("Loading the datamodule")
datamodule = get_datamodule(config)

logger.info("Loading the model.")
model = get_model(config)

logger.info("Loading the experiment logger(s)")
experiment_logger = get_experiment_logger(config)

logger.info("Loading the callbacks")
callbacks = get_callbacks(config)

trainer = Trainer(**config.trainer, logger=experiment_logger, callbacks=callbacks)
Expand Down

0 comments on commit 141eb95

Please sign in to comment.