Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for multiple loggers #903

Merged
merged 14 commits into from
Feb 25, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868))
- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))

### Changed

Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def any_lightning_module_function_or_hook(...):
"""
from os import environ

from .base import LightningLoggerBase, rank_zero_only
from .base import LightningLoggerBase, LoggerCollection, rank_zero_only
from .tensorboard import TensorBoardLogger

__all__ = ['TensorBoardLogger']
__all__ = ['TensorBoardLogger', 'LoggerCollection']

try:
# needed to prevent ImportError and duplicated logs.
Expand Down
83 changes: 69 additions & 14 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import argparse
from abc import ABC
from functools import wraps
from typing import Union, Optional, Dict, Iterable, Any, Callable, List


def rank_zero_only(fn):
def rank_zero_only(fn: Callable):
"""Decorate a logger method to run it only on the process with rank 0.

:param fn: Function to decorate
Args:
fn: Function to decorate
"""

@wraps(fn)
Expand All @@ -23,52 +26,104 @@ def __init__(self):
self._rank = 0

@property
def experiment(self):
def experiment(self) -> Any:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
raise NotImplementedError()

def log_metrics(self, metrics, step):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
"""Record metrics.

:param float metric: Dictionary with metric names as keys and measured quanties as values
:param int|None step: Step number at which the metrics should be recorded
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
raise NotImplementedError()

def log_hyperparams(self, params):
def log_hyperparams(self, params: argparse.Namespace):
"""Record hyperparameters.

:param params: argparse.Namespace containing the hyperparameters
Args:
params: argparse.Namespace containing the hyperparameters
"""
raise NotImplementedError()

def save(self):
"""Save log data."""

def finalize(self, status):
def finalize(self, status: str):
"""Do any processing that is necessary to finalize an experiment.

:param status: Status that the experiment finished with (e.g. success, failed, aborted)
Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
"""

def close(self):
"""Do any cleanup that is necessary to close an experiment."""

@property
def rank(self):
def rank(self) -> int:
"""Process rank. In general, metrics should only be logged by the process with rank 0."""
return self._rank

@rank.setter
def rank(self, value):
def rank(self, value: int):
"""Set the process rank."""
self._rank = value

@property
def name(self):
def name(self) -> str:
"""Return the experiment name."""
raise NotImplementedError("Sub-classes must provide a name property")

@property
def version(self):
def version(self) -> Union[int, str]:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
"""Return the experiment version."""
raise NotImplementedError("Sub-classes must provide a version property")


class LoggerCollection(LightningLoggerBase):
"""The `LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`.

Args:
logger_iterable: An iterable collection of loggers
"""

def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
super().__init__()
self._logger_iterable = logger_iterable

@property
def experiment(self) -> List[Any]:
return [logger.experiment() for logger in self._logger_iterable]

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
[logger.log_metrics(metrics, step) for logger in self._logger_iterable]

def log_hyperparams(self, params: argparse.Namespace):
[logger.log_hyperparams(params) for logger in self._logger_iterable]

def save(self):
[logger.save() for logger in self._logger_iterable]

def finalize(self, status: str):
[logger.finalize(status) for logger in self._logger_iterable]

def close(self):
[logger.close() for logger in self._logger_iterable]

@property
def rank(self) -> int:
return self._rank

@rank.setter
def rank(self, value: int):
self._rank = value
for logger in self._logger_iterable:
logger.rank = value

@property
def name(self) -> str:
return '_'.join([str(logger.name) for logger in self._logger_iterable])

@property
def version(self) -> str:
return '_'.join([str(logger.version) for logger in self._logger_iterable])
28 changes: 15 additions & 13 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
CometLogger
-------------
"""

import argparse
from logging import getLogger
from typing import Optional, Union, Dict

try:
from comet_ml import Experiment as CometExperiment
Expand All @@ -33,9 +34,10 @@ class CometLogger(LightningLoggerBase):
Log using `comet.ml <https://www.comet.ml>`_.
"""

def __init__(self, api_key=None, save_dir=None, workspace=None,
rest_api_key=None, project_name=None, experiment_name=None,
experiment_key=None, **kwargs):
def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None,
workspace: Optional[str] = None, project_name: Optional[str] = None,
rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None, **kwargs):
r"""

Requires either an API Key (online mode) or a local directory path (offline mode)
Expand Down Expand Up @@ -77,8 +79,8 @@ def __init__(self, api_key=None, save_dir=None, workspace=None,
If project name does not already exists Comet.ml will create a new project.
rest_api_key (str): Optional. Rest API key found in Comet.ml settings.
This is used to determine version number
experiment_name (str): Optional. String representing the name for this particular experiment on Comet.ml

experiment_name (str): Optional. String representing the name for this particular experiment on Comet.ml.
experiment_key (str): Optional. If set, restores from existing experiment.
"""
super().__init__()
self._experiment = None
Expand Down Expand Up @@ -120,7 +122,7 @@ def __init__(self, api_key=None, save_dir=None, workspace=None,
logger.exception("Failed to set experiment name for comet.ml logger")

@property
def experiment(self):
def experiment(self) -> Union[CometOfflineExperiment, CometExistingExperiment, CometExperiment]:
ethanwharris marked this conversation as resolved.
Show resolved Hide resolved
r"""

Actual comet object. To use comet features do the following.
Expand Down Expand Up @@ -161,11 +163,11 @@ def experiment(self):
return self._experiment

@rank_zero_only
def log_hyperparams(self, params):
def log_hyperparams(self, params: argparse.Namespace):
self.experiment.log_parameters(vars(params))

@rank_zero_only
def log_metrics(self, metrics, step=None):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
Expand All @@ -177,7 +179,7 @@ def reset_experiment(self):
self._experiment = None

@rank_zero_only
def finalize(self, status):
def finalize(self, status: str):
r"""
When calling self.experiment.end(), that experiment won't log any more data to Comet. That's why, if you need
to log any more data you need to create an ExistingCometExperiment. For example, to log data when testing your
Expand All @@ -190,13 +192,13 @@ def finalize(self, status):
self.reset_experiment()

@property
def name(self):
def name(self) -> str:
return self.experiment.project_name

@name.setter
def name(self, value):
def name(self, value: str):
self.experiment.set_name(value)

@property
def version(self):
def version(self) -> str:
return self.experiment.id
18 changes: 10 additions & 8 deletions pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def any_lightning_module_function_or_hook(...):
self.logger.experiment.whatever_ml_flow_supports(...)

"""

import argparse
from logging import getLogger
from time import time
from typing import Optional, Dict, Any

try:
import mlflow
Expand All @@ -38,7 +39,8 @@ def any_lightning_module_function_or_hook(...):


class MLFlowLogger(LightningLoggerBase):
def __init__(self, experiment_name, tracking_uri=None, tags=None):
def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None,
tags: Dict[str, Any] = None):
r"""

Logs using MLFlow
Expand All @@ -55,7 +57,7 @@ def __init__(self, experiment_name, tracking_uri=None, tags=None):
self.tags = tags

@property
def experiment(self):
def experiment(self) -> mlflow.tracking.MlflowClient:
r"""

Actual mlflow object. To use mlflow features do the following.
Expand Down Expand Up @@ -85,12 +87,12 @@ def run_id(self):
return self._run_id

@rank_zero_only
def log_hyperparams(self, params):
def log_hyperparams(self, params: argparse.Namespace):
for k, v in vars(params).items():
self.experiment.log_param(self.run_id, k, v)

@rank_zero_only
def log_metrics(self, metrics, step=None):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
timestamp_ms = int(time() * 1000)
for k, v in metrics.items():
if isinstance(v, str):
Expand All @@ -104,15 +106,15 @@ def save(self):
pass

@rank_zero_only
def finalize(self, status="FINISHED"):
def finalize(self, status: str = "FINISHED"):
if status == 'success':
status = 'FINISHED'
self.experiment.set_terminated(self.run_id, status)

@property
def name(self):
def name(self) -> str:
return self.experiment_name

@property
def version(self):
def version(self) -> str:
return self._run_id
Loading