Skip to content

Commit

Permalink
make loggers pickleable (#2518)
Browse files Browse the repository at this point in the history
* state updates to logger

* change log

* changelog
  • Loading branch information
awaelchli committed Jul 5, 2020
1 parent 6bfcfa8 commit 1098a0d
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed model summary input type conversion for models that have input dtype different from model parameters ([#2510](https://github.com/PyTorchLightning/pytorch-lightning/pull/2510))

- Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518))

## [0.8.4] - 2020-07-01

### Added
Expand Down
12 changes: 11 additions & 1 deletion pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,15 @@ def __init__(self,
' install it with `pip install comet-ml`.')
super().__init__()
self._experiment = None
self._save_dir = save_dir

# Determine online or offline mode based on which arguments were passed to CometLogger
if api_key is not None:
self.mode = "online"
self.api_key = api_key
elif save_dir is not None:
self.mode = "offline"
self.save_dir = save_dir
self._save_dir = save_dir
else:
# If neither api_key nor save_dir are passed as arguments, raise an exception
raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.")
Expand Down Expand Up @@ -219,6 +220,10 @@ def finalize(self, status: str) -> None:
self.experiment.end()
self.reset_experiment()

@property
def save_dir(self) -> Optional[str]:
return self._save_dir

@property
def name(self) -> str:
return str(self.experiment.project_name)
Expand All @@ -230,3 +235,8 @@ def name(self, value: str) -> None:
@property
def version(self) -> str:
return self.experiment.id

def __getstate__(self):
state = self.__dict__.copy()
state["_experiment"] = None
return state
5 changes: 5 additions & 0 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def finalize(self, status: str) -> None:
if self.close_after_fit:
self.experiment.stop()

@property
def save_dir(self) -> Optional[str]:
# Neptune does not save any local files
return None

@property
def name(self) -> str:
if self.offline_mode:
Expand Down
11 changes: 10 additions & 1 deletion pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self,
version: Optional[Union[int, str]] = None,
**kwargs):
super().__init__()
self.save_dir = save_dir
self._save_dir = save_dir
self._name = name
self._version = version

Expand Down Expand Up @@ -82,6 +82,10 @@ def log_dir(self) -> str:
log_dir = os.path.join(self.root_dir, version)
return log_dir

@property
def save_dir(self) -> Optional[str]:
return self._save_dir

@property
@rank_zero_experiment
def experiment(self) -> SummaryWriter:
Expand Down Expand Up @@ -187,3 +191,8 @@ def _get_next_version(self):
return 0

return max(existing_versions) + 1

def __getstate__(self):
state = self.__dict__.copy()
state["_experiment"] = None
return state
10 changes: 9 additions & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def log_metrics(self, metrics, step):
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
MLFlowLogger,
# MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
# WandbLogger, # TODO: add this one
Expand All @@ -93,6 +93,10 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)

# this can cause pickle error if the experiment object is not picklable
# the logger needs to remove it from the state before pickle
_ = logger.experiment

# test pickling loggers
pickle.dumps(logger)

Expand All @@ -105,6 +109,10 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})

# make sure we restord properly
assert trainer2.logger.name == logger.name
assert trainer2.logger.save_dir == logger.save_dir


@pytest.mark.parametrize("extra_params", [
pytest.param(dict(max_epochs=1, auto_scale_batch_size=True), id='Batch-size-Finder'),
Expand Down
10 changes: 10 additions & 0 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
from typing import Optional
from unittest.mock import MagicMock

import numpy as np
Expand Down Expand Up @@ -49,6 +50,15 @@ def log_metrics(self, metrics, step):
def finalize(self, status):
self.finalized_status = status

@property
def save_dir(self) -> Optional[str]:
"""
Return the root directory where experiment logs get saved, or `None` if the logger does not
save data locally.
"""
return None


@property
def name(self):
return "name"
Expand Down

0 comments on commit 1098a0d

Please sign in to comment.