Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make loggers pickleable #2518

Merged
merged 3 commits into from
Jul 5, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,15 @@ def __init__(self,
' install it with `pip install comet-ml`.')
super().__init__()
self._experiment = None
self._save_dir = save_dir

# Determine online or offline mode based on which arguments were passed to CometLogger
if api_key is not None:
self.mode = "online"
self.api_key = api_key
elif save_dir is not None:
self.mode = "offline"
self.save_dir = save_dir
self._save_dir = save_dir
else:
# If neither api_key nor save_dir are passed as arguments, raise an exception
raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.")
Expand Down Expand Up @@ -219,6 +220,10 @@ def finalize(self, status: str) -> None:
self.experiment.end()
self.reset_experiment()

@property
def save_dir(self) -> Optional[str]:
return self._save_dir

@property
def name(self) -> str:
return str(self.experiment.project_name)
Expand All @@ -230,3 +235,8 @@ def name(self, value: str) -> None:
@property
def version(self) -> str:
return self.experiment.id

def __getstate__(self):
state = self.__dict__.copy()
state["_experiment"] = None
return state
5 changes: 5 additions & 0 deletions pytorch_lightning/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,11 @@ def finalize(self, status: str) -> None:
if self.close_after_fit:
self.experiment.stop()

@property
def save_dir(self) -> Optional[str]:
# Neptune does not save any local files
return None

@property
def name(self) -> str:
if self.offline_mode:
Expand Down
11 changes: 10 additions & 1 deletion pytorch_lightning/loggers/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self,
version: Optional[Union[int, str]] = None,
**kwargs):
super().__init__()
self.save_dir = save_dir
self._save_dir = save_dir
self._name = name
self._version = version

Expand Down Expand Up @@ -82,6 +82,10 @@ def log_dir(self) -> str:
log_dir = os.path.join(self.root_dir, version)
return log_dir

@property
def save_dir(self) -> Optional[str]:
return self._save_dir

@property
@rank_zero_experiment
def experiment(self) -> SummaryWriter:
Expand Down Expand Up @@ -187,3 +191,8 @@ def _get_next_version(self):
return 0

return max(existing_versions) + 1

def __getstate__(self):
state = self.__dict__.copy()
state["_experiment"] = None
return state
10 changes: 9 additions & 1 deletion tests/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def log_metrics(self, metrics, step):
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
MLFlowLogger,
# MLFlowLogger,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm fixing this logger in #2502 so for now, it will not be picklable.

NeptuneLogger,
TestTubeLogger,
# WandbLogger, # TODO: add this one
Expand All @@ -93,6 +93,10 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)

# this can cause pickle error if the experiment object is not picklable
# the logger needs to remove it from the state before pickle
_ = logger.experiment

# test pickling loggers
pickle.dumps(logger)

Expand All @@ -105,6 +109,10 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
trainer2 = pickle.loads(pkl_bytes)
trainer2.logger.log_metrics({'acc': 1.0})

# make sure we restord properly
assert trainer2.logger.name == logger.name
assert trainer2.logger.save_dir == logger.save_dir


@pytest.mark.parametrize("extra_params", [
pytest.param(dict(max_epochs=1, auto_scale_batch_size=True), id='Batch-size-Finder'),
Expand Down
10 changes: 10 additions & 0 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pickle
from typing import Optional
from unittest.mock import MagicMock

import numpy as np
Expand Down Expand Up @@ -49,6 +50,15 @@ def log_metrics(self, metrics, step):
def finalize(self, status):
self.finalized_status = status

@property
def save_dir(self) -> Optional[str]:
"""
Return the root directory where experiment logs get saved, or `None` if the logger does not
save data locally.
"""
return None


@property
def name(self):
return "name"
Expand Down