Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Comet Logger pickled behavior #2553

Merged
merged 15 commits into from
Sep 18, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 95 additions & 51 deletions pytorch_lightning/loggers/comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
-----
"""

import os

from argparse import Namespace
from typing import Optional, Dict, Union, Any

Expand All @@ -11,12 +13,13 @@
from comet_ml import ExistingExperiment as CometExistingExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml import BaseExperiment as CometBaseExperiment
from comet_ml import generate_guid

try:
from comet_ml.api import API
except ImportError: # pragma: no-cover
# For more information, see: https://www.comet.ml/docs/python-sdk/releases/#release-300
from comet_ml.papi import API # pragma: no-cover

_COMET_AVAILABLE = True
except ImportError: # pragma: no-cover
CometExperiment = None
Expand All @@ -25,7 +28,7 @@
CometBaseExperiment = None
API = None
_COMET_AVAILABLE = False

generate_guid = None
Lothiraldan marked this conversation as resolved.
Show resolved Hide resolved

import torch
from torch import is_tensor
Expand Down Expand Up @@ -90,19 +93,23 @@ class CometLogger(LightningLoggerBase):
experiment_key: Optional. If set, restores from existing experiment.
"""

def __init__(self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
**kwargs):
def __init__(
self,
api_key: Optional[str] = None,
save_dir: Optional[str] = None,
workspace: Optional[str] = None,
project_name: Optional[str] = None,
rest_api_key: Optional[str] = None,
experiment_name: Optional[str] = None,
experiment_key: Optional[str] = None,
**kwargs,
):

if not _COMET_AVAILABLE:
raise ImportError('You want to use `comet_ml` logger which is not installed yet,'
' install it with `pip install comet-ml`.')
raise ImportError(
"You want to use `comet_ml` logger which is not installed yet,"
" install it with `pip install comet-ml`."
)
super().__init__()
self._experiment = None
self._save_dir = save_dir
Expand All @@ -121,9 +128,11 @@ def __init__(self,
log.info(f"CometLogger will be initialized in {self.mode} mode")

self.workspace = workspace
self.project_name = project_name
self.experiment_key = experiment_key
self._project_name = project_name
self._experiment_key = experiment_key
self._experiment_name = experiment_name
self._kwargs = kwargs
self._future_experiment_key = None

if rest_api_key is not None:
# Comet.ml rest API, used to determine version number
Expand All @@ -133,11 +142,6 @@ def __init__(self,
self.rest_api_key = None
self.comet_api = None

if experiment_name:
try:
self.name = experiment_name
except TypeError:
log.exception("Failed to set experiment name for comet.ml logger")
self._kwargs = kwargs

@property
Expand All @@ -155,30 +159,40 @@ def experiment(self) -> CometBaseExperiment:
if self._experiment is not None:
return self._experiment

if self.mode == "online":
if self.experiment_key is None:
self._experiment = CometExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
self.experiment_key = self._experiment.get_key()
if self._future_experiment_key is not None:
os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key
self._future_experiment_key = None

try:
Borda marked this conversation as resolved.
Show resolved Hide resolved
if self.mode == "online":
if self._experiment_key is None:
self._experiment = CometExperiment(
api_key=self.api_key, workspace=self.workspace, project_name=self._project_name, **self._kwargs
)
self._experiment_key = self._experiment.get_key()
else:
self._experiment = CometExistingExperiment(
api_key=self.api_key,
workspace=self.workspace,
project_name=self._project_name,
previous_experiment=self._experiment_key,
**self._kwargs,
)
else:
self._experiment = CometExistingExperiment(
api_key=self.api_key,
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
previous_experiment=self.experiment_key,
**self._kwargs
project_name=self._project_name,
**self._kwargs,
)
else:
self._experiment = CometOfflineExperiment(
offline_directory=self.save_dir,
workspace=self.workspace,
project_name=self.project_name,
**self._kwargs
)
finally:
try:
del os.environ["COMET_EXPERIMENT_KEY"]
except KeyError:
pass
Lothiraldan marked this conversation as resolved.
Show resolved Hide resolved

if self._experiment_name:
self._experiment.set_name(self._experiment_name)

return self._experiment

Expand All @@ -189,13 +203,8 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
self.experiment.log_parameters(params)

@rank_zero_only
def log_metrics(
self,
metrics: Dict[str, Union[torch.Tensor, float]],
step: Optional[int] = None
) -> None:
assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'

def log_metrics(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None) -> None:
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
# Comet.ml expects metrics to be a dictionary of detached tensors on CPU
for key, val in metrics.items():
if is_tensor(val):
Expand Down Expand Up @@ -226,17 +235,52 @@ def save_dir(self) -> Optional[str]:

@property
def name(self) -> str:
return str(self.experiment.project_name)
# don't create an experiment if we don't have one
Lothiraldan marked this conversation as resolved.
Show resolved Hide resolved
if self._experiment is not None and self._experiment.project_name is not None:
return self._experiment.project_name

if self._project_name is not None:
return self._project_name

return "comet-default"

@name.setter
def name(self, value: str) -> None:
self.experiment.set_name(value)
self._experiment_name = value

# Only set the experiment object name if it already exists as we don't
# want to create an experiment object as soon as we create a Comet
# Logger
if self._experiment is not None:
self._experiment.set_name(value)

@property
def version(self) -> str:
return self.experiment.id
# Don't create an experiment if we don't have one
if self._experiment is not None:
return self._experiment.id

if self._experiment_key is not None:
return self._experiment_key

if self._future_experiment_key is not None:
return self._future_experiment_key

# Pre-generate an experiment key
self._future_experiment_key = generate_guid()

return self._future_experiment_key

def __getstate__(self):
state = self.__dict__.copy()

# Save the experiment id in case an experiment object already exists,
# this way we could create an ExistingExperiment poiting to the same
Lothiraldan marked this conversation as resolved.
Show resolved Hide resolved
# experiment
state["_experiment_key"] = self._experiment.id if self._experiment is not None else None

# Remove the experiment object as it contains hard to pickle objects
# (like network connections), the experiment object will be recreated if
# needed later
state["_experiment"] = None
return state
64 changes: 28 additions & 36 deletions tests/loggers/test_comet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,43 +10,23 @@ def test_comet_logger_online():
"""Test comet online with mocks."""
# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(
api_key='key',
workspace='dummy-test',
project_name='general'
)
logger = CometLogger(api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(
api_key='key',
workspace='dummy-test',
project_name='general'
)
comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test both given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(
save_dir='test',
api_key='key',
workspace='dummy-test',
project_name='general'
)
logger = CometLogger(save_dir='test', api_key='key', workspace='dummy-test', project_name='general')

_ = logger.experiment

comet.assert_called_once_with(
api_key='key',
workspace='dummy-test',
project_name='general'
)
comet.assert_called_once_with(api_key='key', workspace='dummy-test', project_name='general')

# Test neither given
with pytest.raises(MisconfigurationException):
CometLogger(
workspace='dummy-test',
project_name='general'
)
CometLogger(workspace='dummy-test', project_name='general')

# Test already exists
with patch('pytorch_lightning.loggers.comet.CometExistingExperiment') as comet_existing:
Expand All @@ -55,26 +35,38 @@ def test_comet_logger_online():
experiment_name='experiment',
api_key='key',
workspace='dummy-test',
project_name='general'
project_name='general',
)

_ = logger.experiment

comet_existing.assert_called_once_with(
api_key='key',
workspace='dummy-test',
project_name='general',
previous_experiment='test'
api_key='key', workspace='dummy-test', project_name='general', previous_experiment='test'
)

comet_existing().set_name.assert_called_once_with('experiment')

with patch('pytorch_lightning.loggers.comet.API') as api:
CometLogger(
api_key='key',
workspace='dummy-test',
project_name='general',
rest_api_key='rest'
)
CometLogger(api_key='key', workspace='dummy-test', project_name='general', rest_api_key='rest')

api.assert_called_once_with('rest')


def test_comet_logger_experiment_name():
"""Test that Comet Logger experiment name works correctly."""

api_key = "key"
experiment_name = "My Name"

# Test api_key given
with patch('pytorch_lightning.loggers.comet.CometExperiment') as comet:
logger = CometLogger(api_key=api_key, experiment_name=experiment_name,)

# The experiment object should not exists
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you remove these type of comments, also in the other places below? the assertion below makes it very clear

Lothiraldan marked this conversation as resolved.
Show resolved Hide resolved
assert logger._experiment is None

_ = logger.experiment

comet.assert_called_once_with(api_key=api_key, project_name=None, workspace=None)

comet().set_name.assert_called_once_with(experiment_name)