From 7f6eceaf427250c80575d2ab70196a6053ff5c98 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 20 Feb 2020 08:07:54 +0000 Subject: [PATCH 01/12] Add support for multiple loggers --- pytorch_lightning/loggers/__init__.py | 2 +- pytorch_lightning/loggers/base.py | 64 ++++++++++++++++++- pytorch_lightning/trainer/logging.py | 8 ++- tests/test_logging.py | 89 ++++++++++++++++++++------- 4 files changed, 136 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py index 23c4b722229c0..df248f2c92aec 100644 --- a/pytorch_lightning/loggers/__init__.py +++ b/pytorch_lightning/loggers/__init__.py @@ -74,7 +74,7 @@ def any_lightning_module_function_or_hook(...): """ from os import environ -from .base import LightningLoggerBase, rank_zero_only +from .base import LightningLoggerBase, LightningLoggerList, rank_zero_only from .tensorboard import TensorBoardLogger __all__ = ['TensorBoardLogger'] diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 1835bba2bf1d5..d08fa60d1f636 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,8 +19,9 @@ def wrapped_fn(self, *args, **kwargs): class LightningLoggerBase(ABC): """Base class for experiment loggers.""" - def __init__(self): + def __init__(self, priority=1000): self._rank = 0 + self._priority = priority @property def experiment(self): @@ -29,7 +30,7 @@ def experiment(self): def log_metrics(self, metrics, step): """Record metrics. - :param float metric: Dictionary with metric names as keys and measured quanties as values + :param float metrics: Dictionary with metric names as keys and measured quantities as values :param int|None step: Step number at which the metrics should be recorded """ raise NotImplementedError() @@ -53,6 +54,17 @@ def finalize(self, status): def close(self): """Do any cleanup that is necessary to close an experiment.""" + @property + def priority(self): + return self._priority + + @priority.setter + def priority(self, value): + self._priority = value + + def as_main_logger(self): + self.priority = 1000 + @property def rank(self): """Process rank. In general, metrics should only be logged by the process with rank 0.""" @@ -72,3 +84,51 @@ def name(self): def version(self): """Return the experiment version.""" raise NotImplementedError("Sub-classes must provide a version property") + + +class LightningLoggerList(LightningLoggerBase): + """The `LoggerList` class is used to iterate all logging actions over the given `logger_list`. 
+ + :param logger_list: An iterable collection of loggers + """ + + def __init__(self, logger_list): + super().__init__() + self._logger_list = logger_list + + @property + def experiment(self): + return [logger.experiment() for logger in self._logger_list] + + def log_metrics(self, metrics, step): + return [logger.log_metrics(metrics, step) for logger in self._logger_list] + + def log_hyperparams(self, params): + return [logger.log_hyperparams(params) for logger in self._logger_list] + + def save(self): + return [logger.save() for logger in self._logger_list] + + def finalize(self, status): + return [logger.finalize(status) for logger in self._logger_list] + + def close(self): + return [logger.close() for logger in self._logger_list] + + @property + def rank(self): + return self._rank + + @rank.setter + def rank(self, value): + self._rank = value + for logger in self._logger_list: + logger.rank = value + + @property + def name(self): + return '_'.join([str(logger.name) for logger in self._logger_list]) + + @property + def version(self): + return '_'.join([str(logger.version) for logger in self._logger_list]) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 34b1c114b338a..b8125b389b235 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -3,7 +3,7 @@ import torch from pytorch_lightning.core import memory -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerList class TrainerLoggingMixin(ABC): @@ -34,7 +34,11 @@ def configure_logger(self, logger): elif logger is False: self.logger = None else: - self.logger = logger + try: + _ = iter(logger) + self.logger = LightningLoggerList(logger) # can call iter on logger, make it a logger list + except TypeError: + self.logger = logger # can't call iter, must just be a regular logger self.logger.rank = 0 def log_metrics(self, metrics, grad_norm_dic, step=None): diff --git a/tests/test_logging.py b/tests/test_logging.py index 0d4104ef7a34b..1bbeb5bc887f8 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -330,34 +330,35 @@ def test_tensorboard_log_hyperparams(tmpdir): logger.log_hyperparams(hparams) -def test_custom_logger(tmpdir): - class CustomLogger(LightningLoggerBase): - def __init__(self): - super().__init__() - self.hparams_logged = None - self.metrics_logged = None - self.finalized = False +class CustomLogger(LightningLoggerBase): + def __init__(self): + super().__init__() + self.hparams_logged = None + self.metrics_logged = None + self.finalized = False + + @rank_zero_only + def log_hyperparams(self, params): + self.hparams_logged = params - @rank_zero_only - def log_hyperparams(self, params): - self.hparams_logged = params + @rank_zero_only + def log_metrics(self, metrics, step): + self.metrics_logged = metrics - @rank_zero_only - def log_metrics(self, metrics, step): - self.metrics_logged = metrics + @rank_zero_only + def finalize(self, status): + self.finalized_status = status - @rank_zero_only - def finalize(self, status): - self.finalized_status = status + @property + def name(self): + return "name" - @property - def name(self): - return "name" + @property + def version(self): + return "1" - @property - def version(self): - return "1" +def test_custom_logger(tmpdir): hparams = tutils.get_hparams() model = LightningTestModel(hparams) @@ -378,6 +379,50 @@ def version(self): assert logger.finalized_status == "success" +def test_multiple_loggers(tmpdir): + hparams = 
tutils.get_hparams() + model = LightningTestModel(hparams) + + logger1 = CustomLogger() + logger2 = CustomLogger() + + trainer_options = dict( + max_epochs=1, + train_percent_check=0.05, + logger=[logger1, logger2], + default_save_path=tmpdir + ) + + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + assert result == 1, "Training failed" + + assert logger1.hparams_logged == hparams + assert logger1.metrics_logged != {} + assert logger1.finalized_status == "success" + + assert logger2.hparams_logged == hparams + assert logger2.metrics_logged != {} + assert logger2.finalized_status == "success" + + +def test_multiple_loggers_pickle(tmpdir): + """Verify that pickling trainer with multiple loggers works.""" + + logger1 = CustomLogger() + logger2 = CustomLogger() + + trainer_options = dict(max_epochs=1, logger=[logger1, logger2]) + + trainer = Trainer(**trainer_options) + pkl_bytes = pickle.dumps(trainer) + trainer2 = pickle.loads(pkl_bytes) + trainer2.logger.log_metrics({"acc": 1.0}, 0) + + assert logger1.metrics_logged != {} + assert logger2.metrics_logged != {} + + def test_adding_step_key(tmpdir): logged_step = 0 From 0ca4bd14dca30b43950713a58cd42e970d17bac8 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 20 Feb 2020 08:11:53 +0000 Subject: [PATCH 02/12] Fix PEP --- pytorch_lightning/trainer/logging.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index b8125b389b235..153c8544dba33 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -36,9 +36,11 @@ def configure_logger(self, logger): else: try: _ = iter(logger) - self.logger = LightningLoggerList(logger) # can call iter on logger, make it a logger list + # can call iter on logger, make it a logger list + self.logger = LightningLoggerList(logger) except TypeError: - self.logger = logger # can't call iter, must just be a regular logger + # can't call iter, must just be a regular logger + self.logger = logger self.logger.rank = 0 def log_metrics(self, metrics, grad_norm_dic, step=None): From 7a10da1cd19b779222d6d0c5501500ff313496cf Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 20 Feb 2020 08:13:25 +0000 Subject: [PATCH 03/12] Cleanup --- pytorch_lightning/loggers/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index d08fa60d1f636..c45ea79412da3 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -19,9 +19,8 @@ def wrapped_fn(self, *args, **kwargs): class LightningLoggerBase(ABC): """Base class for experiment loggers.""" - def __init__(self, priority=1000): + def __init__(self): self._rank = 0 - self._priority = priority @property def experiment(self): From fbf7d7ec176c21c823f021496768a8ed7659ae27 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 20 Feb 2020 08:13:54 +0000 Subject: [PATCH 04/12] Cleanup --- pytorch_lightning/loggers/base.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index c45ea79412da3..15138835d9e16 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -53,17 +53,6 @@ def finalize(self, status): def close(self): """Do any cleanup that is necessary to close an experiment.""" - @property - def priority(self): - return self._priority - - @priority.setter - def priority(self, value): - 
self._priority = value - - def as_main_logger(self): - self.priority = 1000 - @property def rank(self): """Process rank. In general, metrics should only be logged by the process with rank 0.""" From 065bce1079cee82f13d3247afa42e8f19890ab6f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 24 Feb 2020 13:39:55 +0000 Subject: [PATCH 05/12] Add typing to loggers --- pytorch_lightning/loggers/__init__.py | 4 +- pytorch_lightning/loggers/base.py | 79 ++++++++++---------- pytorch_lightning/loggers/comet.py | 28 ++++---- pytorch_lightning/loggers/mlflow.py | 18 ++--- pytorch_lightning/loggers/neptune.py | 92 ++++++++++++------------ pytorch_lightning/loggers/tensorboard.py | 26 +++---- pytorch_lightning/loggers/test_tube.py | 27 +++---- pytorch_lightning/loggers/wandb.py | 22 +++--- pytorch_lightning/trainer/logging.py | 14 ++-- pytorch_lightning/trainer/trainer.py | 6 +- 10 files changed, 169 insertions(+), 147 deletions(-) diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py index df248f2c92aec..b0acb5c1e25cc 100644 --- a/pytorch_lightning/loggers/__init__.py +++ b/pytorch_lightning/loggers/__init__.py @@ -74,10 +74,10 @@ def any_lightning_module_function_or_hook(...): """ from os import environ -from .base import LightningLoggerBase, LightningLoggerList, rank_zero_only +from .base import LightningLoggerBase, LoggerCollection, rank_zero_only from .tensorboard import TensorBoardLogger -__all__ = ['TensorBoardLogger'] +__all__ = ['TensorBoardLogger', 'LoggerCollection'] try: # needed to prevent ImportError and duplicated logs. diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 15138835d9e16..f9c44475ef478 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -1,11 +1,14 @@ +import argparse from abc import ABC from functools import wraps +from typing import Union, Optional, Dict, Iterable, Any, Callable -def rank_zero_only(fn): +def rank_zero_only(fn: Callable): """Decorate a logger method to run it only on the process with rank 0. - :param fn: Function to decorate + Args: + fn: Function to decorate """ @wraps(fn) @@ -23,100 +26,104 @@ def __init__(self): self._rank = 0 @property - def experiment(self): + def experiment(self) -> Any: raise NotImplementedError() - def log_metrics(self, metrics, step): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Record metrics. - :param float metrics: Dictionary with metric names as keys and measured quantities as values - :param int|None step: Step number at which the metrics should be recorded + Args: + metrics: Dictionary with metric names as keys and measured quantities as values + step: Step number at which the metrics should be recorded """ raise NotImplementedError() - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): """Record hyperparameters. - :param params: argparse.Namespace containing the hyperparameters + Args: + params: argparse.Namespace containing the hyperparameters """ raise NotImplementedError() def save(self): """Save log data.""" - def finalize(self, status): + def finalize(self, status: str): """Do any processing that is necessary to finalize an experiment. - :param status: Status that the experiment finished with (e.g. success, failed, aborted) + Args: + status: Status that the experiment finished with (e.g. 
success, failed, aborted) """ def close(self): """Do any cleanup that is necessary to close an experiment.""" @property - def rank(self): + def rank(self) -> int: """Process rank. In general, metrics should only be logged by the process with rank 0.""" return self._rank @rank.setter - def rank(self, value): + def rank(self, value: int): """Set the process rank.""" self._rank = value @property - def name(self): + def name(self) -> str: """Return the experiment name.""" raise NotImplementedError("Sub-classes must provide a name property") @property - def version(self): + def version(self) -> Union[int, str]: """Return the experiment version.""" raise NotImplementedError("Sub-classes must provide a version property") -class LightningLoggerList(LightningLoggerBase): - """The `LoggerList` class is used to iterate all logging actions over the given `logger_list`. +class LoggerCollection(LightningLoggerBase): + """The `LoggerCollection` class is used to iterate all logging actions over the given `logger_iterable`. - :param logger_list: An iterable collection of loggers + Args: + logger_iterable: An iterable collection of loggers """ - def __init__(self, logger_list): + def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): super().__init__() - self._logger_list = logger_list + self._logger_iterable = logger_iterable @property - def experiment(self): - return [logger.experiment() for logger in self._logger_list] + def experiment(self) -> Iterable[Any]: + return [logger.experiment() for logger in self._logger_iterable] - def log_metrics(self, metrics, step): - return [logger.log_metrics(metrics, step) for logger in self._logger_list] + def log_metrics(self, metrics: Dict[str, float], step: Optional[int]): + [logger.log_metrics(metrics, step) for logger in self._logger_iterable] - def log_hyperparams(self, params): - return [logger.log_hyperparams(params) for logger in self._logger_list] + def log_hyperparams(self, params: argparse.Namespace): + [logger.log_hyperparams(params) for logger in self._logger_iterable] def save(self): - return [logger.save() for logger in self._logger_list] + [logger.save() for logger in self._logger_iterable] - def finalize(self, status): - return [logger.finalize(status) for logger in self._logger_list] + def finalize(self, status: str): + [logger.finalize(status) for logger in self._logger_iterable] def close(self): - return [logger.close() for logger in self._logger_list] + [logger.close() for logger in self._logger_iterable] @property - def rank(self): + def rank(self) -> int: return self._rank @rank.setter - def rank(self, value): + def rank(self, value: int): self._rank = value - for logger in self._logger_list: + for logger in self._logger_iterable: logger.rank = value @property - def name(self): - return '_'.join([str(logger.name) for logger in self._logger_list]) + def name(self) -> str: + return '_'.join([str(logger.name) for logger in self._logger_iterable]) @property - def version(self): - return '_'.join([str(logger.version) for logger in self._logger_list]) + def version(self) -> str: + return '_'.join([str(logger.version) for logger in self._logger_iterable]) diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index fd8d456f34f41..cf56493a687bf 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -5,8 +5,9 @@ CometLogger ------------- """ - +import argparse from logging import getLogger +from typing import Optional, Union, Dict try: from comet_ml import Experiment as 
CometExperiment @@ -33,9 +34,10 @@ class CometLogger(LightningLoggerBase): Log using `comet.ml `_. """ - def __init__(self, api_key=None, save_dir=None, workspace=None, - rest_api_key=None, project_name=None, experiment_name=None, - experiment_key=None, **kwargs): + def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None, + workspace: Optional[str] = None, project_name: Optional[str] = None, + rest_api_key: Optional[str] = None, experiment_name: Optional[str] = None, + experiment_key: Optional[str] = None, **kwargs): r""" Requires either an API Key (online mode) or a local directory path (offline mode) @@ -77,8 +79,8 @@ def __init__(self, api_key=None, save_dir=None, workspace=None, If project name does not already exists Comet.ml will create a new project. rest_api_key (str): Optional. Rest API key found in Comet.ml settings. This is used to determine version number - experiment_name (str): Optional. String representing the name for this particular experiment on Comet.ml - + experiment_name (str): Optional. String representing the name for this particular experiment on Comet.ml. + experiment_key (str): Optional. If set, restores from existing experiment. """ super().__init__() self._experiment = None @@ -120,7 +122,7 @@ def __init__(self, api_key=None, save_dir=None, workspace=None, logger.exception("Failed to set experiment name for comet.ml logger") @property - def experiment(self): + def experiment(self) -> Union[CometOfflineExperiment, CometExistingExperiment, CometExperiment]: r""" Actual comet object. To use comet features do the following. @@ -161,11 +163,11 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): self.experiment.log_parameters(vars(params)) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): # Comet.ml expects metrics to be a dictionary of detached tensors on CPU for key, val in metrics.items(): if is_tensor(val): @@ -177,7 +179,7 @@ def reset_experiment(self): self._experiment = None @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): r""" When calling self.experiment.end(), that experiment won't log any more data to Comet. That's why, if you need to log any more data you need to create an ExistingCometExperiment. For example, to log data when testing your @@ -190,13 +192,13 @@ def finalize(self, status): self.reset_experiment() @property - def name(self): + def name(self) -> str: return self.experiment.project_name @name.setter - def name(self, value): + def name(self, value: str): self.experiment.set_name(value) @property - def version(self): + def version(self) -> str: return self.experiment.id diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py index 652e32f675aae..63a5d0850947d 100644 --- a/pytorch_lightning/loggers/mlflow.py +++ b/pytorch_lightning/loggers/mlflow.py @@ -23,9 +23,10 @@ def any_lightning_module_function_or_hook(...): self.logger.experiment.whatever_ml_flow_supports(...) 
""" - +import argparse from logging import getLogger from time import time +from typing import Optional, Dict, Any try: import mlflow @@ -38,7 +39,8 @@ def any_lightning_module_function_or_hook(...): class MLFlowLogger(LightningLoggerBase): - def __init__(self, experiment_name, tracking_uri=None, tags=None): + def __init__(self, experiment_name: str, tracking_uri: Optional[str] = None, + tags: Dict[str, Any] = None): r""" Logs using MLFlow @@ -55,7 +57,7 @@ def __init__(self, experiment_name, tracking_uri=None, tags=None): self.tags = tags @property - def experiment(self): + def experiment(self) -> mlflow.tracking.MlflowClient: r""" Actual mlflow object. To use mlflow features do the following. @@ -85,12 +87,12 @@ def run_id(self): return self._run_id @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): for k, v in vars(params).items(): self.experiment.log_param(self.run_id, k, v) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): timestamp_ms = int(time() * 1000) for k, v in metrics.items(): if isinstance(v, str): @@ -104,15 +106,15 @@ def save(self): pass @rank_zero_only - def finalize(self, status="FINISHED"): + def finalize(self, status: str = "FINISHED"): if status == 'success': status = 'FINISHED' self.experiment.set_terminated(self.run_id, status) @property - def name(self): + def name(self) -> str: return self.experiment_name @property - def version(self): + def version(self) -> str: return self._run_id diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index f81730b60bb4d..e4ebba3f36b5f 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -6,11 +6,13 @@ NeptuneLogger -------------- """ - +import argparse from logging import getLogger +from typing import Optional, List, Dict, Any, Union, Iterable try: import neptune + from neptune.experiments import Experiment except ImportError: raise ImportError('You want to use `neptune` logger which is not installed yet,' ' please install it e.g. `pip install neptune-client`.') @@ -29,9 +31,10 @@ class NeptuneLogger(LightningLoggerBase): To log experiment data in online mode, NeptuneLogger requries an API key: """ - def __init__(self, api_key=None, project_name=None, offline_mode=False, - experiment_name=None, upload_source_files=None, - params=None, properties=None, tags=None, **kwargs): + def __init__(self, api_key: Optional[str] = None, project_name: Optional[str] = None, + offline_mode: bool = False, experiment_name: Optional[str] = None, + upload_source_files: Optional[List[str]] = None, params: Optional[Dict[str, Any]] = None, + properties: Optional[Dict[str, Any]] = None, tags: Optional[List[str]] = None, **kwargs): r""" Initialize a neptune.ml logger. @@ -136,7 +139,7 @@ def any_lightning_module_function_or_hook(...): logger.info(f"NeptuneLogger was initialized in {self.mode} mode") @property - def experiment(self): + def experiment(self) -> Experiment: r""" Actual neptune object. To use neptune features do the following. 
@@ -159,17 +162,17 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): for key, val in vars(params).items(): self.experiment.set_property(f"param__{key}", val) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Log metrics (numeric values) in Neptune experiments - :param float metric: Dictionary with metric names as keys and measured quanties as values - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + metrics: Dictionary with metric names as keys and measured quantities as values + step: Step number at which the metrics should be recorded, must be strictly increasing """ for key, val in metrics.items(): @@ -182,31 +185,31 @@ def log_metrics(self, metrics, step=None): self.experiment.log_metric(key, x=step, y=val) @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): self.experiment.stop() @property - def name(self): + def name(self) -> str: if self.mode == "offline": return "offline-name" else: return self.experiment.name @property - def version(self): + def version(self) -> str: if self.mode == "offline": return "offline-id-1234" else: return self.experiment.id @rank_zero_only - def log_metric(self, metric_name, metric_value, step=None): + def log_metric(self, metric_name: str, metric_value: float, step: Optional[int] = None): """Log metrics (numeric values) in Neptune experiments - :param str metric_name: The name of log, i.e. mse, loss, accuracy. - :param str metric_value: The value of the log (data-point). - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + metric_name: The name of log, i.e. mse, loss, accuracy. + metric_value: The value of the log (data-point). + step: Step number at which the metrics should be recorded, must be strictly increasing """ if step is None: self.experiment.log_metric(metric_name, metric_value) @@ -214,13 +217,13 @@ def log_metric(self, metric_name, metric_value, step=None): self.experiment.log_metric(metric_name, x=step, y=metric_value) @rank_zero_only - def log_text(self, log_name, text, step=None): + def log_text(self, log_name: str, text: str, step: Optional[int] = None): """Log text data in Neptune experiment - :param str log_name: The name of log, i.e. mse, my_text_data, timing_info. - :param str text: The value of the log (data-point). - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + log_name: The name of log, i.e. mse, my_text_data, timing_info. + text: The value of the log (data-point). + step: Step number at which the metrics should be recorded, must be strictly increasing """ if step is None: self.experiment.log_metric(log_name, text) @@ -228,14 +231,14 @@ def log_text(self, log_name, text, step=None): self.experiment.log_metric(log_name, x=step, y=text) @rank_zero_only - def log_image(self, log_name, image, step=None): + def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None): """Log image data in Neptune experiment - :param str log_name: The name of log, i.e. bboxes, visualisations, sample_images. - :param str|PIL.Image|matplotlib.figure.Figure image: The value of the log (data-point). 
- Can be one of the following types: PIL image, matplotlib.figure.Figure, path to image file (str) - :param int|None step: Step number at which the metrics should be recorded, must be strictly increasing - + Args: + log_name: The name of log, i.e. bboxes, visualisations, sample_images. + image (str|PIL.Image|matplotlib.figure.Figure): The value of the log (data-point). + Can be one of the following types: PIL image, matplotlib.figure.Figure, path to image file (str) + step: Step number at which the metrics should be recorded, must be strictly increasing """ if step is None: self.experiment.log_image(log_name, image) @@ -243,36 +246,35 @@ def log_image(self, log_name, image, step=None): self.experiment.log_image(log_name, x=step, y=image) @rank_zero_only - def log_artifact(self, artifact, destination=None): + def log_artifact(self, artifact: str, destination: Optional[str] = None): """Save an artifact (file) in Neptune experiment storage. - :param str artifact: A path to the file in local filesystem. - :param str|None destination: Optional default None. - A destination path. If None is passed, an artifact file name will be used. - + Args: + artifact: A path to the file in local filesystem. + destination: Optional default None. A destination path. + If None is passed, an artifact file name will be used. """ self.experiment.log_artifact(artifact, destination) @rank_zero_only - def set_property(self, key, value): + def set_property(self, key: str, value: Any): """Set key-value pair as Neptune experiment property. - :param str key: Property key. - :param obj value: New value of a property. - + Args: + key: Property key. + value: New value of a property. """ self.experiment.set_property(key, value) @rank_zero_only - def append_tags(self, tags): + def append_tags(self, tags: Union[str, Iterable[str]]): """appends tags to neptune experiment - :param str|tuple|list(str) tags: Tags to add to the current experiment. - If str is passed, singe tag is added. - If multiple - comma separated - str are passed, all of them are added as tags. - If list of str is passed, all elements of the list are added as tags. - + Args: + tags: Tags to add to the current experiment. If str is passed, singe tag is added. + If multiple - comma separated - str are passed, all of them are added as tags. + If list of str is passed, all elements of the list are added as tags. 
""" - if not isinstance(tags, (list, set, tuple)): + if not isinstance(tags, Iterable): tags = [tags] # make it as an iterable is if it is not yet self.experiment.append_tags(*tags) diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index d7222ee80f0ab..83be246c3a712 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -1,10 +1,12 @@ +import argparse +import csv import os -from warnings import warn from argparse import Namespace -from pkg_resources import parse_version +from typing import Optional, Dict, Union +from warnings import warn import torch -import csv +from pkg_resources import parse_version from torch.utils.tensorboard import SummaryWriter from .base import LightningLoggerBase, rank_zero_only @@ -42,7 +44,7 @@ class TensorBoardLogger(LightningLoggerBase): """ NAME_CSV_TAGS = 'meta_tags.csv' - def __init__(self, save_dir, name="default", version=None, **kwargs): + def __init__(self, save_dir: str, name: str = "default", version: Optional[Union[int, str]] = None, **kwargs): super().__init__() self.save_dir = save_dir self._name = name @@ -53,7 +55,7 @@ def __init__(self, save_dir, name="default", version=None, **kwargs): self.kwargs = kwargs @property - def root_dir(self): + def root_dir(self) -> str: """ Parent directory for all tensorboard checkpoint subdirectories. If the experiment name parameter is None or the empty string, no experiment subdirectory is used @@ -65,7 +67,7 @@ def root_dir(self): return os.path.join(self.save_dir, self.name) @property - def log_dir(self): + def log_dir(self) -> str: """ The directory for this run's tensorboard checkpoint. By default, it is named 'version_${self.version}' but it can be overridden by passing a string value for the constructor's version parameter @@ -77,7 +79,7 @@ def log_dir(self): return log_dir @property - def experiment(self): + def experiment(self) -> SummaryWriter: r""" Actual tensorboard object. To use tensorboard features do the following. 
@@ -95,7 +97,7 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): if params is None: return @@ -121,7 +123,7 @@ def log_hyperparams(self, params): self.tags.update(params) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): for k, v in metrics.items(): if isinstance(v, torch.Tensor): v = v.item() @@ -151,15 +153,15 @@ def save(self): writer.writerow({'key': k, 'value': v}) @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): self.save() @property - def name(self): + def name(self) -> str: return self._name @property - def version(self): + def version(self) -> int: if self._version is None: self._version = self._get_next_version() return self._version diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py index 9247efbcb179e..7774c04f356ce 100644 --- a/pytorch_lightning/loggers/test_tube.py +++ b/pytorch_lightning/loggers/test_tube.py @@ -1,3 +1,6 @@ +import argparse +from typing import Optional, Dict, Any + try: from test_tube import Experiment except ImportError: @@ -15,8 +18,8 @@ class TestTubeLogger(LightningLoggerBase): __test__ = False def __init__( - self, save_dir, name="default", description=None, debug=False, - version=None, create_git_tag=False + self, save_dir: str, name: str = "default", description: Optional[str] = None, + debug: bool = False, version: Optional[int] = None, create_git_tag: bool = False ): r""" @@ -62,7 +65,7 @@ def any_lightning_module_function_or_hook(...): self._experiment = None @property - def experiment(self): + def experiment(self) -> Experiment: r""" Actual test-tube object. To use test-tube features do the following. @@ -88,13 +91,13 @@ def experiment(self): return self._experiment @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug self.experiment.argparse(params) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug self.experiment.log(metrics, global_step=step) @@ -106,7 +109,7 @@ def save(self): self.experiment.save() @rank_zero_only - def finalize(self, status): + def finalize(self, status: str): # TODO: HACK figure out where this is being set to true self.experiment.debug = self.debug self.save() @@ -121,24 +124,24 @@ def close(self): exp.close() @property - def rank(self): + def rank(self) -> int: return self._rank @rank.setter - def rank(self, value): + def rank(self, value: int): self._rank = value if self._experiment is not None: self.experiment.rank = value @property - def name(self): + def name(self) -> str: if self._experiment is None: return self._name else: return self.experiment.name @property - def version(self): + def version(self) -> int: if self._experiment is None: return self._version else: @@ -148,12 +151,12 @@ def version(self): # methods to get DDP working. See # https://docs.python.org/3/library/pickle.html#handling-stateful-objects # for more info. 
- def __getstate__(self): + def __getstate__(self) -> Dict[Any, Any]: state = self.__dict__.copy() state["_experiment"] = self.experiment.get_meta_copy() return state - def __setstate__(self, state): + def __setstate__(self, state: Dict[Any, Any]): self._experiment = state["_experiment"].get_non_ddp_exp() del state["_experiment"] self.__dict__.update(state) diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index f3ddede7ff45d..e2d77068a4eca 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -5,11 +5,13 @@ WandbLogger ------------- """ - +import argparse import os +from typing import Optional, List, Dict try: import wandb + from wandb.wandb_run import Run except ImportError: raise ImportError('You want to use `wandb` logger which is not installed yet,' ' please install it e.g. `pip install wandb`.') @@ -41,8 +43,10 @@ class WandbLogger(LightningLoggerBase): trainer = Trainer(logger=wandb_logger) """ - def __init__(self, name=None, save_dir=None, offline=False, id=None, anonymous=False, - version=None, project=None, tags=None, experiment=None, entity=None): + def __init__(self, name: Optional[str] = None, save_dir: Optional[str] = None, + offline: bool = False, id: Optional[str] = None, anonymous: bool = False, + version: Optional[str] = None, project: Optional[str] = None, + tags: Optional[List[str]] = None, experiment=None, entity=None): super().__init__() self._name = name self._save_dir = save_dir @@ -63,7 +67,7 @@ def __getstate__(self): return state @property - def experiment(self): + def experiment(self) -> Run: r""" Actual wandb object. To use wandb features do the following. @@ -85,11 +89,11 @@ def watch(self, model, log="gradients", log_freq=100): wandb.watch(model, log, log_freq) @rank_zero_only - def log_hyperparams(self, params): + def log_hyperparams(self, params: argparse.Namespace): self.experiment.config.update(params) @rank_zero_only - def log_metrics(self, metrics, step=None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): metrics["global_step"] = step self.experiment.log(metrics) @@ -97,7 +101,7 @@ def save(self): pass @rank_zero_only - def finalize(self, status='success'): + def finalize(self, status: str = 'success'): try: exit_code = 0 if status == 'success' else 1 wandb.join(exit_code) @@ -105,9 +109,9 @@ def finalize(self, status='success'): wandb.join() @property - def name(self): + def name(self) -> str: return self.experiment.project_name() @property - def version(self): + def version(self) -> str: return self.experiment.id diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 153c8544dba33..1caa68e02fbea 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -3,7 +3,7 @@ import torch from pytorch_lightning.core import memory -from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerList +from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection class TrainerLoggingMixin(ABC): @@ -37,7 +37,7 @@ def configure_logger(self, logger): try: _ = iter(logger) # can call iter on logger, make it a logger list - self.logger = LightningLoggerList(logger) + self.logger = LoggerCollection(logger) except TypeError: # can't call iter, must just be a regular logger self.logger = logger @@ -47,9 +47,11 @@ def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. 
If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step - :param metrics (dict): Metric values - :param grad_norm_dic (dict): Gradient norms - :param step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` + + Args: + metrics (dict): Metric values + grad_norm_dic (dict): Gradient norms + step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` """ # add gpu memory if self.on_gpu and self.log_gpu_memory: @@ -97,8 +99,6 @@ def process_output(self, output, train=False): """Reduces output according to the training mode. Separates loss from logging and tqdm metrics - :param output: - :return: """ # --------------- # EXTRACT CALLBACK KEYS diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4d204bd287e36..aef65edf957cd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2,7 +2,7 @@ import sys import warnings import logging as log -from typing import Union, Optional, List, Dict, Tuple +from typing import Union, Optional, List, Dict, Tuple, Iterable import torch import torch.distributed as dist @@ -66,7 +66,7 @@ class Trainer(TrainerIOMixin, def __init__( self, - logger: Union[LightningLoggerBase, bool] = True, + logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, checkpoint_callback: Union[ModelCheckpoint, bool] = True, early_stop_callback: Optional[Union[EarlyStopping, bool]] = None, default_save_path: Optional[str] = None, @@ -115,7 +115,7 @@ def __init__( Customize every aspect of training via flags Args: - logger: Logger for experiment tracking. + logger: Logger (or iterable collection of loggers) for experiment tracking. 
Example:: from pytorch_lightning.loggers import TensorBoardLogger From 43c7dcbe4bc2fe5a57014d72b6dfbd8e182c20bd Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 24 Feb 2020 13:42:11 +0000 Subject: [PATCH 06/12] Update base.py --- pytorch_lightning/loggers/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index f9c44475ef478..aeabcb6adeaa8 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -1,7 +1,7 @@ import argparse from abc import ABC from functools import wraps -from typing import Union, Optional, Dict, Iterable, Any, Callable +from typing import Union, Optional, Dict, Iterable, Any, Callable, List def rank_zero_only(fn: Callable): @@ -92,10 +92,10 @@ def __init__(self, logger_iterable: Iterable[LightningLoggerBase]): self._logger_iterable = logger_iterable @property - def experiment(self) -> Iterable[Any]: + def experiment(self) -> List[Any]: return [logger.experiment() for logger in self._logger_iterable] - def log_metrics(self, metrics: Dict[str, float], step: Optional[int]): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): [logger.log_metrics(metrics, step) for logger in self._logger_iterable] def log_hyperparams(self, params: argparse.Namespace): From 16788bed4d917696586876bec95870aff2817fcf Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 24 Feb 2020 14:04:00 +0000 Subject: [PATCH 07/12] Replace duck typing with isinstance check --- pytorch_lightning/trainer/logging.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 1caa68e02fbea..20a6673d69aa6 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -1,4 +1,5 @@ from abc import ABC +from typing import Iterable import torch @@ -34,12 +35,9 @@ def configure_logger(self, logger): elif logger is False: self.logger = None else: - try: - _ = iter(logger) - # can call iter on logger, make it a logger list + if isinstance(logger, Iterable): self.logger = LoggerCollection(logger) - except TypeError: - # can't call iter, must just be a regular logger + else: self.logger = logger self.logger.rank = 0 From b1a88c52ff2ec70ce00a70854956e98a263cb85e Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Mon, 24 Feb 2020 14:05:11 +0000 Subject: [PATCH 08/12] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b96da2d8d891f..2ae3ebd0ced66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added Tensor Processing Unit (TPU) support ([#868](https://github.com/PyTorchLightning/pytorch-lightning/pull/868)) - Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876)) - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849)) +- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) 
([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903)) ### Changed From 195edd57654bdf4d08ac6f3d25f01ef168f9681b Mon Sep 17 00:00:00 2001 From: Ethan Haris Date: Tue, 25 Feb 2020 12:55:23 +0000 Subject: [PATCH 09/12] Update comet experiment type, Switch to abstractmethod in logging.py --- pytorch_lightning/loggers/base.py | 13 +++++++------ pytorch_lightning/loggers/comet.py | 3 ++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index aeabcb6adeaa8..237ca75a01c81 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -1,5 +1,5 @@ import argparse -from abc import ABC +from abc import ABC, abstractmethod from functools import wraps from typing import Union, Optional, Dict, Iterable, Any, Callable, List @@ -26,9 +26,11 @@ def __init__(self): self._rank = 0 @property + @abstractmethod def experiment(self) -> Any: - raise NotImplementedError() + """Return the experiment object associated with this logger""" + @abstractmethod def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): """Record metrics. @@ -36,15 +38,14 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): metrics: Dictionary with metric names as keys and measured quantities as values step: Step number at which the metrics should be recorded """ - raise NotImplementedError() + @abstractmethod def log_hyperparams(self, params: argparse.Namespace): """Record hyperparameters. Args: params: argparse.Namespace containing the hyperparameters """ - raise NotImplementedError() def save(self): """Save log data.""" @@ -70,14 +71,14 @@ def rank(self, value: int): self._rank = value @property + @abstractmethod def name(self) -> str: """Return the experiment name.""" - raise NotImplementedError("Sub-classes must provide a name property") @property + @abstractmethod def version(self) -> Union[int, str]: """Return the experiment version.""" - raise NotImplementedError("Sub-classes must provide a version property") class LoggerCollection(LightningLoggerBase): diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index cf56493a687bf..ce98a39c03453 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -13,6 +13,7 @@ from comet_ml import Experiment as CometExperiment from comet_ml import ExistingExperiment as CometExistingExperiment from comet_ml import OfflineExperiment as CometOfflineExperiment + from comet_ml import BaseExperiment as CometBaseExperiment try: from comet_ml.api import API except ImportError: @@ -122,7 +123,7 @@ def __init__(self, api_key: Optional[str] = None, save_dir: Optional[str] = None logger.exception("Failed to set experiment name for comet.ml logger") @property - def experiment(self) -> Union[CometOfflineExperiment, CometExistingExperiment, CometExperiment]: + def experiment(self) -> CometBaseExperiment: r""" Actual comet object. To use comet features do the following. 
From 2393a43d370331e01bef9ba6b2493b42288a6b4c Mon Sep 17 00:00:00 2001 From: Ethan Haris Date: Tue, 25 Feb 2020 13:06:10 +0000 Subject: [PATCH 10/12] Fix test --- tests/test_logging.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_logging.py b/tests/test_logging.py index f3942948a5b48..1ad487a75ad9d 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -337,6 +337,10 @@ def __init__(self): self.metrics_logged = None self.finalized = False + @property + def experiment(self): + return 'test' + @rank_zero_only def log_hyperparams(self, params): self.hparams_logged = params From 99e0a8abc481930157b6f3c5b41ae655aace182f Mon Sep 17 00:00:00 2001 From: Ethan Haris Date: Tue, 25 Feb 2020 13:38:10 +0000 Subject: [PATCH 11/12] Add passes to LightningLoggerBase --- pytorch_lightning/loggers/base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py index 237ca75a01c81..2bf24acad55ea 100644 --- a/pytorch_lightning/loggers/base.py +++ b/pytorch_lightning/loggers/base.py @@ -29,6 +29,7 @@ def __init__(self): @abstractmethod def experiment(self) -> Any: """Return the experiment object associated with this logger""" + pass @abstractmethod def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): @@ -38,6 +39,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): metrics: Dictionary with metric names as keys and measured quantities as values step: Step number at which the metrics should be recorded """ + pass @abstractmethod def log_hyperparams(self, params: argparse.Namespace): @@ -46,9 +48,11 @@ def log_hyperparams(self, params: argparse.Namespace): Args: params: argparse.Namespace containing the hyperparameters """ + pass def save(self): """Save log data.""" + pass def finalize(self, status: str): """Do any processing that is necessary to finalize an experiment. @@ -56,9 +60,11 @@ def finalize(self, status: str): Args: status: Status that the experiment finished with (e.g. success, failed, aborted) """ + pass def close(self): """Do any cleanup that is necessary to close an experiment.""" + pass @property def rank(self) -> int: @@ -74,11 +80,13 @@ def rank(self, value: int): @abstractmethod def name(self) -> str: """Return the experiment name.""" + pass @property @abstractmethod def version(self) -> Union[int, str]: """Return the experiment version.""" + pass class LoggerCollection(LightningLoggerBase): From 63b62679e6d3c7933dc7863b724432476089445a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Tue, 25 Feb 2020 18:54:53 +0000 Subject: [PATCH 12/12] Update experiment_logging.rst --- docs/source/experiment_logging.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/source/experiment_logging.rst b/docs/source/experiment_logging.rst index 853f2505e1df8..0ea8930026f05 100644 --- a/docs/source/experiment_logging.rst +++ b/docs/source/experiment_logging.rst @@ -137,3 +137,26 @@ The Wandb logger is available anywhere in your LightningModule some_img = fake_image() self.logger.experiment.add_image('generated_images', some_img, 0) + +Multiple Loggers +^^^^^^^^^^^^^^^^^ + +PyTorch-Lightning supports use of multiple loggers, just pass a list to the `Trainer`. + +.. 
code-block:: python
+
+    from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
+
+    logger1 = TensorBoardLogger("tb_logs", name="my_model")
+    logger2 = TestTubeLogger("tt_logs", name="my_model")
+    trainer = Trainer(logger=[logger1, logger2])
+
+The experiment objects of the individual loggers are then available as a list anywhere in your LightningModule.
+
+.. code-block:: python
+
+    class MyModule(pl.LightningModule):
+
+        def __init__(self, ...):
+            some_img = fake_image()
+            self.logger.experiment[0].add_image('generated_images', some_img, 0)
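+
+When a list is given, the ``Trainer`` wraps it in a ``LoggerCollection``, which fans each
+logging call out to every logger it holds and joins the individual ``name`` and ``version``
+properties with underscores. A minimal sketch of that behaviour (the printed values are
+illustrative and depend on the loggers used):
+
+.. code-block:: python
+
+    from pytorch_lightning.loggers import LoggerCollection
+
+    # what Trainer(logger=[logger1, logger2]) does under the hood
+    collection = LoggerCollection([logger1, logger2])
+
+    # a single call is forwarded to every wrapped logger
+    collection.log_metrics({'train_loss': 0.42}, step=7)
+
+    print(collection.name)     # underscore-joined names, e.g. 'my_model_my_model'
+    print(collection.version)  # underscore-joined versions, e.g. '0_1'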
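+
+Any object satisfying the ``LightningLoggerBase`` interface can appear in the list. A minimal
+in-memory logger, assuming only the members that ``LightningLoggerBase`` declares abstract
+(``experiment``, ``log_metrics``, ``log_hyperparams``, ``name``, ``version``), might look as
+follows; the class itself is a sketch for illustration, not part of the library:
+
+.. code-block:: python
+
+    from pytorch_lightning.loggers import LightningLoggerBase, rank_zero_only
+
+    class InMemoryLogger(LightningLoggerBase):
+        """Sketch of a logger that records everything to plain Python objects."""
+
+        def __init__(self):
+            super().__init__()
+            self.history = []
+            self.hparams = None
+
+        @property
+        def experiment(self):
+            return self.history
+
+        @rank_zero_only
+        def log_hyperparams(self, params):
+            self.hparams = vars(params)
+
+        @rank_zero_only
+        def log_metrics(self, metrics, step=None):
+            self.history.append((step, dict(metrics)))
+
+        @property
+        def name(self):
+            return 'in_memory'
+
+        @property
+        def version(self):
+            return '0'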