Add support for multiple loggers #903

Merged Feb 25, 2020 (14 commits)
2 changes: 1 addition & 1 deletion pytorch_lightning/loggers/__init__.py
@@ -74,7 +74,7 @@ def any_lightning_module_function_or_hook(...):
"""
from os import environ

-from .base import LightningLoggerBase, rank_zero_only
+from .base import LightningLoggerBase, LightningLoggerList, rank_zero_only
from .tensorboard import TensorBoardLogger

__all__ = ['TensorBoardLogger']
50 changes: 49 additions & 1 deletion pytorch_lightning/loggers/base.py
@@ -29,7 +29,7 @@ def experiment(self):
    def log_metrics(self, metrics, step):
        """Record metrics.

-       :param float metric: Dictionary with metric names as keys and measured quanties as values
+       :param dict metrics: Dictionary with metric names as keys and measured quantities as values
        :param int|None step: Step number at which the metrics should be recorded
        """
        raise NotImplementedError()
@@ -72,3 +72,51 @@ def name(self):
    def version(self):
        """Return the experiment version."""
        raise NotImplementedError("Sub-classes must provide a version property")


class LightningLoggerList(LightningLoggerBase):
"""The `LoggerList` class is used to iterate all logging actions over the given `logger_list`.

:param logger_list: An iterable collection of loggers
"""

def __init__(self, logger_list):
super().__init__()
self._logger_list = logger_list

@property
def experiment(self):
return [logger.experiment() for logger in self._logger_list]
Member: this may be tricky since it returns a list instead of an object...

Member Author: Yeah, this is really just a precaution anyway, as we don't use the experiment property internally. However, the experiment property is a different object for every logger, so I think this is acceptable.

Member: I know, but we do not know about all user usage of these loggers lol

Member Author: I'm not sure there's any other option here, any suggestions?

    def log_metrics(self, metrics, step):
        return [logger.log_metrics(metrics, step) for logger in self._logger_list]

    def log_hyperparams(self, params):
        return [logger.log_hyperparams(params) for logger in self._logger_list]

    def save(self):
        return [logger.save() for logger in self._logger_list]

    def finalize(self, status):
        return [logger.finalize(status) for logger in self._logger_list]

    def close(self):
        return [logger.close() for logger in self._logger_list]

    @property
    def rank(self):
        return self._rank

    @rank.setter
    def rank(self, value):
        self._rank = value
        for logger in self._logger_list:
            logger.rank = value

    @property
    def name(self):
        return '_'.join([str(logger.name) for logger in self._logger_list])

    @property
    def version(self):
        return '_'.join([str(logger.version) for logger in self._logger_list])
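
For illustration, a minimal sketch of how the fan-out behaves. This is not part of the PR's diff; `DummyLogger` is a hypothetical stand-in for any `LightningLoggerBase` subclass:

from pytorch_lightning.loggers import LightningLoggerBase, LightningLoggerList


# DummyLogger is a hypothetical minimal logger, used only for this sketch.
class DummyLogger(LightningLoggerBase):
    def __init__(self, name):
        super().__init__()
        self._name = name
        self.logged = []

    def log_metrics(self, metrics, step):
        # record each call so the fan-out is observable
        self.logged.append((metrics, step))

    def log_hyperparams(self, params):
        pass

    @property
    def name(self):
        return self._name

    @property
    def version(self):
        return '0'


loggers = LightningLoggerList([DummyLogger('a'), DummyLogger('b')])
loggers.log_metrics({'loss': 0.5}, step=1)  # forwarded to both loggers
assert loggers.name == 'a_b'      # names are joined with '_'
assert loggers.version == '0_0'   # versions likewise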
10 changes: 8 additions & 2 deletions pytorch_lightning/trainer/logging.py
@@ -3,7 +3,7 @@
import torch

from pytorch_lightning.core import memory
-from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerList


class TrainerLoggingMixin(ABC):
@@ -34,7 +34,13 @@ def configure_logger(self, logger):
        elif logger is False:
            self.logger = None
        else:
-           self.logger = logger
+           try:
+               _ = iter(logger)
+               # can call iter on logger, make it a logger list
+               self.logger = LightningLoggerList(logger)
+           except TypeError:
+               # can't call iter, must just be a regular logger
+               self.logger = logger
        self.logger.rank = 0

    def log_metrics(self, metrics, grad_norm_dic, step=None):
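With this change, Trainer accepts either a single logger or any iterable of loggers. A minimal usage sketch, assuming the `CustomLogger` test helper defined in the tests below:

from pytorch_lightning import Trainer

# A single logger is stored on the trainer as-is...
trainer = Trainer(logger=CustomLogger())

# ...while an iterable of loggers is wrapped in a LightningLoggerList,
# so trainer.logger fans each call out to every member.
trainer = Trainer(logger=[CustomLogger(), CustomLogger()])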
89 changes: 67 additions & 22 deletions tests/test_logging.py
@@ -330,34 +330,35 @@ def test_tensorboard_log_hyperparams(tmpdir):
    logger.log_hyperparams(hparams)


-def test_custom_logger(tmpdir):
-    class CustomLogger(LightningLoggerBase):
-        def __init__(self):
-            super().__init__()
-            self.hparams_logged = None
-            self.metrics_logged = None
-            self.finalized = False
+class CustomLogger(LightningLoggerBase):
+    def __init__(self):
+        super().__init__()
+        self.hparams_logged = None
+        self.metrics_logged = None
+        self.finalized = False

-        @rank_zero_only
-        def log_hyperparams(self, params):
-            self.hparams_logged = params
+    @rank_zero_only
+    def log_hyperparams(self, params):
+        self.hparams_logged = params

-        @rank_zero_only
-        def log_metrics(self, metrics, step):
-            self.metrics_logged = metrics
+    @rank_zero_only
+    def log_metrics(self, metrics, step):
+        self.metrics_logged = metrics

-        @rank_zero_only
-        def finalize(self, status):
-            self.finalized_status = status
+    @rank_zero_only
+    def finalize(self, status):
+        self.finalized_status = status

-        @property
-        def name(self):
-            return "name"
+    @property
+    def name(self):
+        return "name"

-        @property
-        def version(self):
-            return "1"
+    @property
+    def version(self):
+        return "1"
+
+
+def test_custom_logger(tmpdir):
    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

@@ -378,6 +379,50 @@ def version(self):
    assert logger.finalized_status == "success"


def test_multiple_loggers(tmpdir):
    hparams = tutils.get_hparams()
    model = LightningTestModel(hparams)

    logger1 = CustomLogger()
    logger2 = CustomLogger()

    trainer_options = dict(
        max_epochs=1,
        train_percent_check=0.05,
        logger=[logger1, logger2],
        default_save_path=tmpdir
    )

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)
    assert result == 1, "Training failed"

    assert logger1.hparams_logged == hparams
    assert logger1.metrics_logged != {}
    assert logger1.finalized_status == "success"

    assert logger2.hparams_logged == hparams
    assert logger2.metrics_logged != {}
    assert logger2.finalized_status == "success"


def test_multiple_loggers_pickle(tmpdir):
    """Verify that pickling a trainer with multiple loggers works."""

    logger1 = CustomLogger()
    logger2 = CustomLogger()

    trainer_options = dict(max_epochs=1, logger=[logger1, logger2])

    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0}, 0)

    assert logger1.metrics_logged != {}
    assert logger2.metrics_logged != {}


def test_adding_step_key(tmpdir):
    logged_step = 0
