diff --git a/CHANGELOG.md b/CHANGELOG.md index 43471b71ec23a..eb9fdfbc5ba06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed model summary input type conversion for models that have input dtype different from model parameters ([#2510](https://github.com/PyTorchLightning/pytorch-lightning/pull/2510)) +- Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518)) + ## [0.8.4] - 2020-07-01 ### Added diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index a7d783c36d195..5c5cc29fa7778 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -105,6 +105,7 @@ def __init__(self, ' install it with `pip install comet-ml`.') super().__init__() self._experiment = None + self._save_dir = save_dir # Determine online or offline mode based on which arguments were passed to CometLogger if api_key is not None: @@ -112,7 +113,7 @@ def __init__(self, self.api_key = api_key elif save_dir is not None: self.mode = "offline" - self.save_dir = save_dir + self._save_dir = save_dir else: # If neither api_key nor save_dir are passed as arguments, raise an exception raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") @@ -219,6 +220,10 @@ def finalize(self, status: str) -> None: self.experiment.end() self.reset_experiment() + @property + def save_dir(self) -> Optional[str]: + return self._save_dir + @property def name(self) -> str: return str(self.experiment.project_name) @@ -230,3 +235,8 @@ def name(self, value: str) -> None: @property def version(self) -> str: return self.experiment.id + + def __getstate__(self): + state = self.__dict__.copy() + state["_experiment"] = None + return state diff --git a/pytorch_lightning/loggers/neptune.py b/pytorch_lightning/loggers/neptune.py index 45d4f65f484af..cb88dceee34eb 100644 --- a/pytorch_lightning/loggers/neptune.py +++ b/pytorch_lightning/loggers/neptune.py @@ -259,6 +259,11 @@ def finalize(self, status: str) -> None: if self.close_after_fit: self.experiment.stop() + @property + def save_dir(self) -> Optional[str]: + # Neptune does not save any local files + return None + @property def name(self) -> str: if self.offline_mode: diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index edb136f538df5..61da82ac7731b 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -50,7 +50,7 @@ def __init__(self, version: Optional[Union[int, str]] = None, **kwargs): super().__init__() - self.save_dir = save_dir + self._save_dir = save_dir self._name = name self._version = version @@ -82,6 +82,10 @@ def log_dir(self) -> str: log_dir = os.path.join(self.root_dir, version) return log_dir + @property + def save_dir(self) -> Optional[str]: + return self._save_dir + @property @rank_zero_experiment def experiment(self) -> SummaryWriter: @@ -187,3 +191,8 @@ def _get_next_version(self): return 0 return max(existing_versions) + 1 + + def __getstate__(self): + state = self.__dict__.copy() + state["_experiment"] = None + return state diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index fdd172325d4d1..094bcbf1956f6 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -78,7 +78,7 @@ def log_metrics(self, metrics, step): @pytest.mark.parametrize("logger_class", [ TensorBoardLogger, CometLogger, - MLFlowLogger, + # MLFlowLogger, NeptuneLogger, TestTubeLogger, # WandbLogger, # TODO: add this one @@ -93,6 +93,10 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class): logger_args = _get_logger_args(logger_class, tmpdir) logger = logger_class(**logger_args) + # this can cause pickle error if the experiment object is not picklable + # the logger needs to remove it from the state before pickle + _ = logger.experiment + # test pickling loggers pickle.dumps(logger) @@ -105,6 +109,10 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class): trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({'acc': 1.0}) + # make sure we restord properly + assert trainer2.logger.name == logger.name + assert trainer2.logger.save_dir == logger.save_dir + @pytest.mark.parametrize("extra_params", [ pytest.param(dict(max_epochs=1, auto_scale_batch_size=True), id='Batch-size-Finder'), diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index dfe9ffc6437fe..085368af105ef 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -1,4 +1,5 @@ import pickle +from typing import Optional from unittest.mock import MagicMock import numpy as np @@ -49,6 +50,15 @@ def log_metrics(self, metrics, step): def finalize(self, status): self.finalized_status = status + @property + def save_dir(self) -> Optional[str]: + """ + Return the root directory where experiment logs get saved, or `None` if the logger does not + save data locally. + """ + return None + + @property def name(self): return "name"