From a5f159b2c7cdecce01a7da5445c89c5ee331310a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 25 Feb 2020 19:52:39 +0000 Subject: [PATCH] Add support for multiple loggers (#903) * Add support for multiple loggers * Fix PEP * Cleanup * Cleanup * Add typing to loggers * Update base.py * Replace duck typing with isinstance check * Update CHANGELOG.md * Update comet experiment type, Switch to abstractmethod in logging.py * Fix test * Add passes to LightningLoggerBase * Update experiment_logging.rst --- CHANGELOG.md | 1 + docs/source/experiment_logging.rst | 23 +++++ pytorch_lightning/loggers/__init__.py | 4 +- pytorch_lightning/loggers/base.py | 104 ++++++++++++++++++----- pytorch_lightning/loggers/comet.py | 29 ++++--- pytorch_lightning/loggers/mlflow.py | 18 ++-- pytorch_lightning/loggers/neptune.py | 92 ++++++++++---------- pytorch_lightning/loggers/tensorboard.py | 26 +++--- pytorch_lightning/loggers/test_tube.py | 27 +++--- pytorch_lightning/loggers/wandb.py | 22 +++-- pytorch_lightning/trainer/logging.py | 18 ++-- pytorch_lightning/trainer/trainer.py | 6 +- tests/test_logging.py | 93 +++++++++++++++----- 13 files changed, 310 insertions(+), 153 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18f10757768c2..ee1349806af6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868)) - Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876)) - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849)) +- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903)) ### Changed diff --git a/docs/source/experiment_logging.rst b/docs/source/experiment_logging.rst index 853f2505e1df8..0ea8930026f05 100644 --- a/docs/source/experiment_logging.rst +++ b/docs/source/experiment_logging.rst @@ -137,3 +137,26 @@ The Wandb logger is available anywhere in your LightningModule some_img = fake_image() self.logger.experiment.add_image('generated_images', some_img, 0) + +Multiple Loggers +^^^^^^^^^^^^^^^^^ + +PyTorch-Lightning supports use of multiple loggers, just pass a list to the `Trainer`. + +.. code-block:: python + + from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger + + logger1 = TensorBoardLogger("tb_logs", name="my_model") + logger2 = TestTubeLogger("tt_logs", name="my_model") + trainer = Trainer(logger=[logger1, logger2]) + +The loggers are available as a list anywhere in your LightningModule + +.. code-block:: python + + class MyModule(pl.LightningModule): + + def __init__(self, ...): + some_img = fake_image() + self.logger.experiment[0].add_image('generated_images', some_img, 0) diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py index 23c4b722229c0..b0acb5c1e25cc 100644 --- a/pytorch_lightning/loggers/__init__.py +++ b/pytorch_lightning/loggers/__init__.py @@ -74,10 +74,10 @@ def any_lightning_module_function_or_hook(...): """ from os import environ -from .base import LightningLoggerBase, rank_zero_only +from .base import LightningLoggerBase, LoggerCollection, rank_zero_only from .tensorboard import TensorBoardLogger -__all__ = ['TensorBoardLogger'] +__all__ = ['TensorBoardLogger', 'LoggerCollection'] try: # needed to prevent ImportError and duplicated logs. diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 1835bba2bf1d5..2bf24acad55ea 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -1,11 +1,14 @@ -from abc import ABC +import argparse +from abc import ABC, abstractmethod from functools import wraps +from typing import Union, Optional, Dict, Iterable, Any, Callable, List -def rank_zero_only(fn): +def rank_zero_only(fn: Callable): """Decorate a logger method to run it only on the process with rank 0. - :param fn: Function to decorate + Args: + fn: Function to decorate """ @wraps(fn) @@ -23,52 +26,113 @@ def __init__(self): self._rank = 0 @property - def experiment(self): - raise NotImplementedError() + @abstractmethod + def experiment(self) -> Any: + """Return the experiment object associated with this logger""" + pass - def log_metrics(self, metrics, step): + @abstractmethod + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Record metrics. - :param float metric: Dictionary with metric names as keys and measured quanties as values - :param int|None step: Step number at which the metrics should be recorded + Args: + metrics: Dictionary with metric names as keys and measured quantities as values + step: Step number at which the metrics should be recorded """ - raise NotImplementedError() + pass - def log_hyperparams(self, params): + @abstractmethod + def log_hyperparams(self, params: argparse.Namespace): """Record hyperparameters. - :param params: argparse.Namespace containing the hyperparameters + Args: + params: argparse.Namespace containing the hyperparameters """ - raise NotImplementedError() + pass def save(self): """Save log data.""" + pass - def finalize(self, status): + def finalize(self, status: str): """Do any processing that is necessary to finalize an experiment. - :param status: Status that the experiment finished with (e.g. success, failed, aborted) + Args: + status: Status that the experiment finished with (e.g. success, failed, aborted) """ + pass def close(self): """Do any cleanup that is necessary to close an experiment.""" + pass @property - def rank(self): + def rank(self) -> int: """Process rank. In general, metrics should only be logged by the process with rank 0.""" return self._rank @rank.setter - def rank(self, value): + def rank(self, value: int): """Set the process rank.""" self._rank = value @property - def name(self): + @abstractmethod + def name(self) -> str: """Return the experiment name.""" - raise NotImplementedError("Sub-classes must provide a name property") + pass @property - def version(self): + @abstractmethod + def version(self) -> Union[int, str]: """Return the experiment version.""" - raise NotImplementedError("Sub-classes must provide a version property") + pass + + +class LoggerCollection(LightningLoggerBase): + """The `LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`. + + Args: + logger_iterable: An iterable collection of loggers + """ + + def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): + super().__init__() + self._logger_iterable = logger_iterable + + @property + def experiment(self) -> List[Any]: + return [logger.experiment() for logger in self._logger_iterable] + + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + [logger.log_metrics(metrics, step) for logger in self._logger_iterable] + + def log_hyperparams(self, params: argparse.Namespace): + [logger.log_hyperparams(params) for logger in self._logger_iterable] + + def save(self): + [logger.save() for logger in self._logger_iterable] + + def finalize(self, status: str): + [logger.finalize(status) for logger in self._logger_iterable] + + def close(self): + [logger.close() for logger in self._logger_iterable] + + @property + def rank(self) -> int: + return self._rank + + @rank.setter + def rank(self, value: int): + self._rank = value + for logger in self._logger_iterable: + logger.rank = value + + @property + def name(self) -> str: + return '_'.join([str(logger.name) for logger in self._logger_iterable]) + + @property + def version(self) -> str: + return '_'.join([str(logger.version) for logger in self._logger_iterable]) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 112b89a11b7fb..54467cd4f63c4 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -5,13 +5,15 @@ CometLogger ------------- """ - +import argparse from logging import getLogger +from typing import Optional, Union, Dict try: from comet_ml import Experiment as CometExperiment from comet_ml import ExistingExperiment as CometExistingExperiment from comet_ml import OfflineExperiment as CometOfflineExperiment + from comet_ml import BaseExperiment as CometBaseExperiment try: from comet_ml.api import API except ImportError: @@ -33,9 +35,10 @@ class CometLogger(LightningLoggerBase): Log using `comet.ml `_. """ - def __init__(self, api_key=None, save_dir=None, workspace=None, - rest_api_key=None, project_name=None, experiment_name=None, - experiment_key=None, **kwargs): + def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None, + workspace: Optional[str] = None, project_name: Optional[str] = None, + rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None, + experiment_key: Optional[str] = None, **kwargs): r""" Requires either an API Key (online mode) or a local directory path (offline mode) @@ -77,8 +80,8 @@ def __init__(self, api_key=None, save_dir=None, workspace=None, If project name does not already exists Comet.ml will create a new project. rest_api_key (str): Optional. Rest API key found in Comet.ml settings. This is used to determine version number - experiment_name (str): Optional. String representing the name for this particular experiment on Comet.ml - + experiment_name (str): Optional. String representing the name for this particular experiment on Comet.ml. + experiment_key (str): Optional. If set, restores from existing experiment. """ super().__init__() self._experiment = None @@ -120,7 +123,7 @@ def __init__(self, api_key=None, save_dir=None, workspace=None, logger.exception("Failed to set experiment name for comet.ml logger") @property - def experiment(self): + def experiment(self) -> CometBaseExperiment: r""" Actual comet object. To use comet features do the following. @@ -161,11 +164,11 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): self.experiment.log_parameters(vars(params)) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): # Comet.ml expects metrics to be a dictionary of detached tensors on CPU for key, val in metrics.items(): if is_tensor(val): @@ -177,7 +180,7 @@ def reset_experiment(self): self._experiment = None @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): r""" When calling self.experiment.end(), that experiment won't log any more data to Comet. That's why, if you need to log any more data you need to create an ExistingCometExperiment. For example, to log data when testing your @@ -190,13 +193,13 @@ def finalize(self, status): self.reset_experiment() @property - def name(self): + def name(self) -> str: return self.experiment.project_name @name.setter - def name(self, value): + def name(self, value: str): self.experiment.set_name(value) @property - def version(self): + def version(self) -> str: return self.experiment.id diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 652e32f675aae..63a5d0850947d 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -23,9 +23,10 @@ def any_lightning_module_function_or_hook(...): self.logger.experiment.whatever_ml_flow_supports(...) """ - +import argparse from logging import getLogger from time import time +from typing import Optional, Dict, Any try: import mlflow @@ -38,7 +39,8 @@ def any_lightning_module_function_or_hook(...): class MLFlowLogger(LightningLoggerBase): - def __init__(self, experiment_name, tracking_uri=None, tags=None): + def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None, + tags: Dict[str, Any] = None): r""" Logs using MLFlow @@ -55,7 +57,7 @@ def __init__(self, experiment_name, tracking_uri=None, tags=None): self.tags = tags @property - def experiment(self): + def experiment(self) -> mlflow.tracking.MlflowClient: r""" Actual mlflow object. To use mlflow features do the following. @@ -85,12 +87,12 @@ def run_id(self): return self._run_id @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): for k, v in vars(params).items(): self.experiment.log_param(self.run_id, k, v) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): timestamp_ms = int(time() * 1000) for k, v in metrics.items(): if isinstance(v, str): @@ -104,15 +106,15 @@ def save(self): pass @rank_zero_only - def finalize(self, status="FINISHED"): + def finalize(self, status: str = "FINISHED"): if status == 'success': status = 'FINISHED' self.experiment.set_terminated(self.run_id, status) @property - def name(self): + def name(self) -> str: return self.experiment_name @property - def version(self): + def version(self) -> str: return self._run_id diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index f81730b60bb4d..e4ebba3f36b5f 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -6,11 +6,13 @@ NeptuneLogger -------------- """ - +import argparse from logging import getLogger +from typing import Optional, List, Dict, Any, Union, Iterable try: import neptune + from neptune.experiments import Experiment except ImportError: raise ImportError('You want to use `neptune` logger which is not installed yet,' ' please install it e.g. `pip install neptune-client`.') @@ -29,9 +31,10 @@ class NeptuneLogger(LightningLoggerBase): To log experiment data in online mode, NeptuneLogger requries an API key: """ - def __init__(self, api_key=None, project_name=None, offline_mode=False, - experiment_name=None, upload_source_files=None, - params=None, properties=None, tags=None, **kwargs): + def __init__(self, api_key: Optional[str] = None, project_name: Optional[str] = None, + offline_mode: bool = False, experiment_name: Optional[str] = None, + upload_source_files: Optional[List[str]] = None, params: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, **kwargs): r""" Initialize a neptune.ml logger. @@ -136,7 +139,7 @@ def any_lightning_module_function_or_hook(...): logger.info(f"NeptuneLogger was initialized in {self.mode} mode") @property - def experiment(self): + def experiment(self) -> Experiment: r""" Actual neptune object. To use neptune features do the following. @@ -159,17 +162,17 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): for key, val in vars(params).items(): self.experiment.set_property(f"param__{key}", val) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Log metrics (numeric values) in Neptune experiments - :param float metric: Dictionary with metric names as keys and measured quanties as values - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + metrics: Dictionary with metric names as keys and measured quantities as values + step: Step number at which the metrics should be recorded, must be strictly increasing """ for key, val in metrics.items(): @@ -182,31 +185,31 @@ def log_metrics(self, metrics, step=None): self.experiment.log_metric(key, x=step, y=val) @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): self.experiment.stop() @property - def name(self): + def name(self) -> str: if self.mode == "offline": return "offline-name" else: return self.experiment.name @property - def version(self): + def version(self) -> str: if self.mode == "offline": return "offline-id-1234" else: return self.experiment.id @rank_zero_only - def log_metric(self, metric_name, metric_value, step=None): + def log_metric(self, metric_name: str, metric_value: float, step: Optional[int] = None): """Log metrics (numeric values) in Neptune experiments - :param str metric_name: The name of log, i.e. mse, loss, accuracy. - :param str metric_value: The value of the log (data-point). - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + metric_name: The name of log, i.e. mse, loss, accuracy. + metric_value: The value of the log (data-point). + step: Step number at which the metrics should be recorded, must be strictly increasing """ if step is None: self.experiment.log_metric(metric_name, metric_value) @@ -214,13 +217,13 @@ def log_metric(self, metric_name, metric_value, step=None): self.experiment.log_metric(metric_name, x=step, y=metric_value) @rank_zero_only - def log_text(self, log_name, text, step=None): + def log_text(self, log_name: str, text: str, step: Optional[int] = None): """Log text data in Neptune experiment - :param str log_name: The name of log, i.e. mse, my_text_data, timing_info. - :param str text: The value of the log (data-point). - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + log_name: The name of log, i.e. mse, my_text_data, timing_info. + text: The value of the log (data-point). + step: Step number at which the metrics should be recorded, must be strictly increasing """ if step is None: self.experiment.log_metric(log_name, text) @@ -228,14 +231,14 @@ def log_text(self, log_name, text, step=None): self.experiment.log_metric(log_name, x=step, y=text) @rank_zero_only - def log_image(self, log_name, image, step=None): + def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None): """Log image data in Neptune experiment - :param str log_name: The name of log, i.e. bboxes, visualisations, sample_images. - :param str|PIL.Image|matplotlib.figure.Figure image: The value of the log (data-point). - Can be one of the following types: PIL image, matplotlib.figure.Figure, path to image file (str) - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + log_name: The name of log, i.e. bboxes, visualisations, sample_images. + image (str|PIL.Image|matplotlib.figure.Figure): The value of the log (data-point). + Can be one of the following types: PIL image, matplotlib.figure.Figure, path to image file (str) + step: Step number at which the metrics should be recorded, must be strictly increasing """ if step is None: self.experiment.log_image(log_name, image) @@ -243,36 +246,35 @@ def log_image(self, log_name, image, step=None): self.experiment.log_image(log_name, x=step, y=image) @rank_zero_only - def log_artifact(self, artifact, destination=None): + def log_artifact(self, artifact: str, destination: Optional[str] = None): """Save an artifact (file) in Neptune experiment storage. - :param str artifact: A path to the file in local filesystem. - :param str|None destination: Optional default None. - A destination path. If None is passed, an artifact file name will be used. - + Args: + artifact: A path to the file in local filesystem. + destination: Optional default None. A destination path. + If None is passed, an artifact file name will be used. """ self.experiment.log_artifact(artifact, destination) @rank_zero_only - def set_property(self, key, value): + def set_property(self, key: str, value: Any): """Set key-value pair as Neptune experiment property. - :param str key: Property key. - :param obj value: New value of a property. - + Args: + key: Property key. + value: New value of a property. """ self.experiment.set_property(key, value) @rank_zero_only - def append_tags(self, tags): + def append_tags(self, tags: Union[str, Iterable[str]]): """appends tags to neptune experiment - :param str|tuple|list(str) tags: Tags to add to the current experiment. - If str is passed, singe tag is added. - If multiple - comma separated - str are passed, all of them are added as tags. - If list of str is passed, all elements of the list are added as tags. - + Args: + tags: Tags to add to the current experiment. If str is passed, singe tag is added. + If multiple - comma separated - str are passed, all of them are added as tags. + If list of str is passed, all elements of the list are added as tags. """ - if not isinstance(tags, (list, set, tuple)): + if not isinstance(tags, Iterable): tags = [tags] # make it as an iterable is if it is not yet self.experiment.append_tags(*tags) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index d7222ee80f0ab..83be246c3a712 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -1,10 +1,12 @@ +import argparse +import csv import os -from warnings import warn from argparse import Namespace -from pkg_resources import parse_version +from typing import Optional, Dict, Union +from warnings import warn import torch -import csv +from pkg_resources import parse_version from torch.utils.tensorboard import SummaryWriter from .base import LightningLoggerBase, rank_zero_only @@ -42,7 +44,7 @@ class TensorBoardLogger(LightningLoggerBase): """ NAME_CSV_TAGS = 'meta_tags.csv' - def __init__(self, save_dir, name="default", version=None, **kwargs): + def __init__(self, save_dir: str, name: str = "default", version: Optional[Union[int, str]] = None, **kwargs): super().__init__() self.save_dir = save_dir self._name = name @@ -53,7 +55,7 @@ def __init__(self, save_dir, name="default", version=None, **kwargs): self.kwargs = kwargs @property - def root_dir(self): + def root_dir(self) -> str: """ Parent directory for all tensorboard checkpoint subdirectories. If the experiment name parameter is None or the empty string, no experiment subdirectory is used @@ -65,7 +67,7 @@ def root_dir(self): return os.path.join(self.save_dir, self.name) @property - def log_dir(self): + def log_dir(self) -> str: """ The directory for this run's tensorboard checkpoint. By default, it is named 'version_${self.version}' but it can be overridden by passing a string value for the constructor's version parameter @@ -77,7 +79,7 @@ def log_dir(self): return log_dir @property - def experiment(self): + def experiment(self) -> SummaryWriter: r""" Actual tensorboard object. To use tensorboard features do the following. @@ -95,7 +97,7 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): if params is None: return @@ -121,7 +123,7 @@ def log_hyperparams(self, params): self.tags.update(params) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() @@ -151,15 +153,15 @@ def save(self): writer.writerow({'key': k, 'value': v}) @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): self.save() @property - def name(self): + def name(self) -> str: return self._name @property - def version(self): + def version(self) -> int: if self._version is None: self._version = self._get_next_version() return self._version diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 9247efbcb179e..7774c04f356ce 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -1,3 +1,6 @@ +import argparse +from typing import Optional, Dict, Any + try: from test_tube import Experiment except ImportError: @@ -15,8 +18,8 @@ class TestTubeLogger(LightningLoggerBase): __test__ = False def __init__( - self, save_dir, name="default", description=None, debug=False, - version=None, create_git_tag=False + self, save_dir: str, name: str = "default", description: Optional[str] = None, + debug: bool = False, version: Optional[int] = None, create_git_tag: bool = False ): r""" @@ -62,7 +65,7 @@ def any_lightning_module_function_or_hook(...): self._experiment = None @property - def experiment(self): + def experiment(self) -> Experiment: r""" Actual test-tube object. To use test-tube features do the following. @@ -88,13 +91,13 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug self.experiment.argparse(params) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug self.experiment.log(metrics, global_step=step) @@ -106,7 +109,7 @@ def save(self): self.experiment.save() @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug self.save() @@ -121,24 +124,24 @@ def close(self): exp.close() @property - def rank(self): + def rank(self) -> int: return self._rank @rank.setter - def rank(self, value): + def rank(self, value: int): self._rank = value if self._experiment is not None: self.experiment.rank = value @property - def name(self): + def name(self) -> str: if self._experiment is None: return self._name else: return self.experiment.name @property - def version(self): + def version(self) -> int: if self._experiment is None: return self._version else: @@ -148,12 +151,12 @@ def version(self): # methods to get DDP working. See # https://docs.python.org/3/library/pickle.html#handling-stateful-objects # for more info. - def __getstate__(self): + def __getstate__(self) -> Dict[Any, Any]: state = self.__dict__.copy() state["_experiment"] = self.experiment.get_meta_copy() return state - def __setstate__(self, state): + def __setstate__(self, state: Dict[Any, Any]): self._experiment = state["_experiment"].get_non_ddp_exp() del state["_experiment"] self.__dict__.update(state) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index f3ddede7ff45d..e2d77068a4eca 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -5,11 +5,13 @@ WandbLogger ------------- """ - +import argparse import os +from typing import Optional, List, Dict try: import wandb + from wandb.wandb_run import Run except ImportError: raise ImportError('You want to use `wandb` logger which is not installed yet,' ' please install it e.g. `pip install wandb`.') @@ -41,8 +43,10 @@ class WandbLogger(LightningLoggerBase): trainer = Trainer(logger=wandb_logger) """ - def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=False, - version=None, project=None, tags=None, experiment=None, entity=None): + def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None, + offline: bool = False, id: Optional[str] = None, anonymous: bool = False, + version: Optional[str] = None, project: Optional[str] = None, + tags: Optional[List[str]] = None, experiment=None, entity=None): super().__init__() self._name = name self._save_dir = save_dir @@ -63,7 +67,7 @@ def __getstate__(self): return state @property - def experiment(self): + def experiment(self) -> Run: r""" Actual wandb object. To use wandb features do the following. @@ -85,11 +89,11 @@ def watch(self, model, log="gradients", log_freq=100): wandb.watch(model, log, log_freq) @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): self.experiment.config.update(params) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): metrics["global_step"] = step self.experiment.log(metrics) @@ -97,7 +101,7 @@ def save(self): pass @rank_zero_only - def finalize(self, status='success'): + def finalize(self, status: str = 'success'): try: exit_code = 0 if status == 'success' else 1 wandb.join(exit_code) @@ -105,9 +109,9 @@ def finalize(self, status='success'): wandb.join() @property - def name(self): + def name(self) -> str: return self.experiment.project_name() @property - def version(self): + def version(self) -> str: return self.experiment.id diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 34b1c114b338a..20a6673d69aa6 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -1,9 +1,10 @@ from abc import ABC +from typing import Iterable import torch from pytorch_lightning.core import memory -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection class TrainerLoggingMixin(ABC): @@ -34,16 +35,21 @@ def configure_logger(self, logger): elif logger is False: self.logger = None else: - self.logger = logger + if isinstance(logger, Iterable): + self.logger = LoggerCollection(logger) + else: + self.logger = logger self.logger.rank = 0 def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step - :param metrics (dict): Metric values - :param grad_norm_dic (dict): Gradient norms - :param step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` + + Args: + metrics (dict): Metric values + grad_norm_dic (dict): Gradient norms + step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` """ # add gpu memory if self.on_gpu and self.log_gpu_memory: @@ -91,8 +97,6 @@ def process_output(self, output, train=False): """Reduces output according to the training mode. Separates loss from logging and tqdm metrics - :param output: - :return: """ # --------------- # EXTRACT CALLBACK KEYS diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 038cdc80b8383..815f2d9deebfa 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2,7 +2,7 @@ import sys import warnings import logging as log -from typing import Union, Optional, List, Dict, Tuple +from typing import Union, Optional, List, Dict, Tuple, Iterable import torch import torch.distributed as dist @@ -66,7 +66,7 @@ class Trainer(TrainerIOMixin, def __init__( self, - logger: Union[LightningLoggerBase, bool] = True, + logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, checkpoint_callback: Union[ModelCheckpoint, bool] = True, early_stop_callback: Optional[Union[EarlyStopping, bool]] = None, default_save_path: Optional[str] = None, @@ -117,7 +117,7 @@ def __init__( Customize every aspect of training via flags Args: - logger: Logger for experiment tracking. + logger: Logger (or iterable collection of loggers) for experiment tracking. Example:: from pytorch_lightning.loggers import TensorBoardLogger diff --git a/tests/test_logging.py b/tests/test_logging.py index 93f0b46f2b197..1ad487a75ad9d 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -330,34 +330,39 @@ def test_tensorboard_log_hyperparams(tmpdir): logger.log_hyperparams(hparams) -def test_custom_logger(tmpdir): - class CustomLogger(LightningLoggerBase): - def __init__(self): - super().__init__() - self.hparams_logged = None - self.metrics_logged = None - self.finalized = False +class CustomLogger(LightningLoggerBase): + def __init__(self): + super().__init__() + self.hparams_logged = None + self.metrics_logged = None + self.finalized = False + + @property + def experiment(self): + return 'test' - @rank_zero_only - def log_hyperparams(self, params): - self.hparams_logged = params + @rank_zero_only + def log_hyperparams(self, params): + self.hparams_logged = params - @rank_zero_only - def log_metrics(self, metrics, step): - self.metrics_logged = metrics + @rank_zero_only + def log_metrics(self, metrics, step): + self.metrics_logged = metrics - @rank_zero_only - def finalize(self, status): - self.finalized_status = status + @rank_zero_only + def finalize(self, status): + self.finalized_status = status - @property - def name(self): - return "name" + @property + def name(self): + return "name" - @property - def version(self): - return "1" + @property + def version(self): + return "1" + +def test_custom_logger(tmpdir): hparams = tutils.get_hparams() model = LightningTestModel(hparams) @@ -378,6 +383,50 @@ def version(self): assert logger.finalized_status == "success" +def test_multiple_loggers(tmpdir): + hparams = tutils.get_hparams() + model = LightningTestModel(hparams) + + logger1 = CustomLogger() + logger2 = CustomLogger() + + trainer_options = dict( + max_epochs=1, + train_percent_check=0.05, + logger=[logger1, logger2], + default_save_path=tmpdir + ) + + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + assert result == 1, "Training failed" + + assert logger1.hparams_logged == hparams + assert logger1.metrics_logged != {} + assert logger1.finalized_status == "success" + + assert logger2.hparams_logged == hparams + assert logger2.metrics_logged != {} + assert logger2.finalized_status == "success" + + +def test_multiple_loggers_pickle(tmpdir): + """Verify that pickling trainer with multiple loggers works.""" + + logger1 = CustomLogger() + logger2 = CustomLogger() + + trainer_options = dict(max_epochs=1, logger=[logger1, logger2]) + + trainer = Trainer(**trainer_options) + pkl_bytes = pickle.dumps(trainer) + trainer2 = pickle.loads(pkl_bytes) + trainer2.logger.log_metrics({"acc": 1.0}, 0) + + assert logger1.metrics_logged != {} + assert logger2.metrics_logged != {} + + def test_adding_step_key(tmpdir): logged_step = 0