Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚚 Move logging from train.py to the getter functions #365

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions anomalib/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
from typing import Union

from omegaconf import DictConfig, ListConfig
Expand All @@ -24,6 +25,8 @@
from .inference import InferenceDataset
from .mvtec import MVTec

logger = logging.getLogger(__name__)


def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule:
"""Get Anomaly Datamodule.
Expand All @@ -34,6 +37,8 @@ def get_datamodule(config: Union[DictConfig, ListConfig]) -> LightningDataModule
Returns:
PyTorch Lightning DataModule
"""
logger.info("Loading the datamodule")

datamodule: LightningDataModule

if config.dataset.format.lower() == "mvtec":
Expand Down
5 changes: 5 additions & 0 deletions anomalib/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
import os
from importlib import import_module
from typing import List, Union
Expand All @@ -23,6 +24,8 @@

from anomalib.models.components import AnomalyModule

logger = logging.getLogger(__name__)


def _snake_to_pascal_case(model_name: str) -> str:
"""Convert model name from snake case to Pascal case.
Expand Down Expand Up @@ -54,6 +57,8 @@ def get_model(config: Union[DictConfig, ListConfig]) -> AnomalyModule:
Returns:
AnomalyModule: Anomaly Model
"""
logger.info("Loading the model.")

model_list: List[str] = [
"cflow",
"dfkde",
Expand Down
6 changes: 6 additions & 0 deletions anomalib/utils/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import logging
import os
import warnings
from importlib import import_module
Expand Down Expand Up @@ -41,6 +42,9 @@
]


logger = logging.getLogger(__name__)


def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
"""Return base callbacks for all the lightning models.

Expand All @@ -50,6 +54,8 @@ def get_callbacks(config: Union[ListConfig, DictConfig]) -> List[Callback]:
Return:
(List[Callback]): List of callbacks.
"""
logger.info("Loading the callbacks")

callbacks: List[Callback] = []

monitor_metric = None if "early_stopping" not in config.model.keys() else config.model.early_stopping.metric
Expand Down
12 changes: 8 additions & 4 deletions anomalib/utils/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
AVAILABLE_LOGGERS = ["tensorboard", "wandb", "csv"]


logger = logging.getLogger(__name__)


class UnknownLogger(Exception):
"""This is raised when the logger option in `config.yaml` file is set incorrectly."""

Expand Down Expand Up @@ -75,6 +78,7 @@ def get_experiment_logger(
Returns:
Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]: Logger
"""
logger.info("Loading the experiment logger(s)")

# TODO remove when logger is deprecated from project
if "logger" in config.project.keys():
Expand All @@ -95,16 +99,16 @@ def get_experiment_logger(
if isinstance(config.logging.logger, str):
config.logging.logger = [config.logging.logger]

for logger in config.logging.logger:
if logger == "tensorboard":
for experiment_logger in config.logging.logger:
if experiment_logger == "tensorboard":
logger_list.append(
AnomalibTensorBoardLogger(
name="Tensorboard Logs",
save_dir=os.path.join(config.project.path, "logs"),
log_graph=config.logging.log_graph,
)
)
elif logger == "wandb":
elif experiment_logger == "wandb":
wandb_logdir = os.path.join(config.project.path, "logs")
os.makedirs(wandb_logdir, exist_ok=True)
name = (
Expand All @@ -119,7 +123,7 @@ def get_experiment_logger(
save_dir=wandb_logdir,
)
)
elif logger == "csv":
elif experiment_logger == "csv":
logger_list.append(CSVLogger(save_dir=os.path.join(config.project.path, "logs")))
else:
raise UnknownLogger(
Expand Down
13 changes: 3 additions & 10 deletions tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,16 @@ def train():
args = get_args()
configure_logger(level=args.log_level)

if args.log_level == "ERROR":
warnings.filterwarnings("ignore")

config = get_configurable_parameters(model_name=args.model, config_path=args.config)
if config.project.seed != 0:
seed_everything(config.project.seed)

if args.log_level == "ERROR":
warnings.filterwarnings("ignore")

logger.info("Loading the datamodule")
datamodule = get_datamodule(config)

logger.info("Loading the model.")
model = get_model(config)

logger.info("Loading the experiment logger(s)")
experiment_logger = get_experiment_logger(config)

logger.info("Loading the callbacks")
callbacks = get_callbacks(config)

trainer = Trainer(**config.trainer, logger=experiment_logger, callbacks=callbacks)
Expand Down