Move logger initialization (#750)
kuynzereb authored and williamFalcon committed Jan 26, 2020
1 parent cc12ff3 commit 7deec2c
Showing 3 changed files with 21 additions and 19 deletions.

18 changes: 1 addition & 17 deletions pytorch_lightning/trainer/callback_config.py

@@ -2,7 +2,6 @@
 from abc import ABC
 
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
-from pytorch_lightning.logging import TensorBoardLogger
 
 
 class TrainerCallbackConfigMixin(ABC):
@@ -50,7 +49,7 @@ def configure_checkpoint_callback(self):
         if self.weights_save_path is None:
             self.weights_save_path = self.default_save_path
 
-    def configure_early_stopping(self, early_stop_callback, logger):
+    def configure_early_stopping(self, early_stop_callback):
         if early_stop_callback is True:
             self.early_stop_callback = EarlyStopping(
                 monitor='val_loss',
@@ -75,18 +74,3 @@ def configure_early_stopping(self, early_stop_callback, logger):
         else:
             self.early_stop_callback = early_stop_callback
             self.enable_early_stop = True
-
-        # configure logger
-        if logger is True:
-            # default logger
-            self.logger = TensorBoardLogger(
-                save_dir=self.default_save_path,
-                version=self.slurm_job_id,
-                name='lightning_logs'
-            )
-            self.logger.rank = 0
-        elif logger is False:
-            self.logger = None
-        else:
-            self.logger = logger
-            self.logger.rank = 0

16 changes: 16 additions & 0 deletions pytorch_lightning/trainer/logging.py

@@ -3,6 +3,7 @@
 import torch
 
 from pytorch_lightning.core import memory
+from pytorch_lightning.logging import TensorBoardLogger
 
 
 class TrainerLoggingMixin(ABC):
@@ -21,6 +22,21 @@ def __init__(self):
         self.use_ddp2 = None
         self.num_gpus = None
 
+    def configure_logger(self, logger):
+        if logger is True:
+            # default logger
+            self.logger = TensorBoardLogger(
+                save_dir=self.default_save_path,
+                version=self.slurm_job_id,
+                name='lightning_logs'
+            )
+            self.logger.rank = 0
+        elif logger is False:
+            self.logger = None
+        else:
+            self.logger = logger
+            self.logger.rank = 0
+
     def log_metrics(self, metrics, grad_norm_dic, step=None):
         """Logs the metric dict passed in.
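
The moved logic now sits on TrainerLoggingMixin rather than TrainerCallbackConfigMixin, yet the Trainer can still invoke it from __init__ because it mixes in both classes. Below is a minimal, runnable sketch of that mixin composition; the two mixin names match the diff, while the method bodies are simplified stand-ins, not the library's actual implementations.

# Simplified stand-ins for illustration only; not the library's real code.
class TrainerLoggingMixin:
    def configure_logger(self, logger):
        # mirrors the True / False / instance branches added above
        if logger is True:
            self.logger = 'default TensorBoardLogger'  # placeholder for the real default
        elif logger is False:
            self.logger = None
        else:
            self.logger = logger


class TrainerCallbackConfigMixin:
    def configure_early_stopping(self, early_stop_callback):
        # early stopping no longer needs to know about the logger
        self.early_stop_callback = early_stop_callback


class Trainer(TrainerLoggingMixin, TrainerCallbackConfigMixin):
    def __init__(self, logger=True, early_stop_callback=None):
        # both helpers resolve on self because Trainer mixes in both classes
        self.configure_logger(logger)
        self.configure_early_stopping(early_stop_callback)


trainer = Trainer(logger=False)
assert trainer.logger is None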

6 changes: 4 additions & 2 deletions pytorch_lightning/trainer/trainer.py

@@ -558,10 +558,12 @@ def __init__(
         self.current_epoch = 0
         self.total_batches = 0
 
+        # configure logger
+        self.configure_logger(logger)
+
         # configure early stop callback
         # creates a default one if none passed in
         self.early_stop_callback = None
-        self.configure_early_stopping(early_stop_callback, logger)
+        self.configure_early_stopping(early_stop_callback)
 
         self.reduce_lr_on_plateau_scheduler = None
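
For reference, a short usage sketch of what the logger argument resolves to after this change. The import path and keyword arguments are taken from the diff above; treat the surrounding behaviour as an assumption about this era of the API rather than a guarantee.

from pytorch_lightning import Trainer
from pytorch_lightning.logging import TensorBoardLogger

# logger=True: configure_logger builds a TensorBoardLogger under
# default_save_path with the name 'lightning_logs'
trainer = Trainer(logger=True)

# logger=False: self.logger is set to None, so no experiment logger is attached
trainer = Trainer(logger=False)

# passing a logger instance: it is used as-is, with its rank pinned to 0
tb_logger = TensorBoardLogger(save_dir='my_logs', name='my_experiment')
trainer = Trainer(logger=tb_logger)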
