Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comet fix #481

Merged
merged 8 commits into from
Nov 12, 2019
Merged
1 change: 1 addition & 0 deletions .run_local_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
rm -rf _ckpt_*
rm -rf tests/save_dir*
rm -rf tests/mlruns_*
rm -rf tests/cometruns*
rm -rf tests/tests/*
rm -rf lightning_logs
coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules
Expand Down
20 changes: 19 additions & 1 deletion docs/Trainer/Logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,30 @@ def any_lightning_module_function_or_hook(...):

Log using [comet](https://www.comet.ml)

Comet logger can be used in either online or offline mode.
To log in online mode, CometLogger requires an API key:
```{.python}
from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
api_key=os.environ["COMET_KEY"],
workspace=os.environ["COMET_KEY"],
workspace=os.environ["COMET_WORKSPACE"], # Optional
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
)
trainer = Trainer(logger=comet_logger)
```
To log in offline mode, CometLogger requires a path to a local directory:
```{.python}
from pytorch_lightning.logging import CometLogger
# arguments made to CometLogger are passed on to the comet_ml.Experiment class
comet_logger = CometLogger(
save_dir=".",
workspace=os.environ["COMET_WORKSPACE"], # Optional
project_name="default_project", # Optional
rest_api_key=os.environ["COMET_REST_KEY"], # Optional
experiment_name="default" # Optional
)
trainer = Trainer(logger=comet_logger)
```
Expand Down
113 changes: 107 additions & 6 deletions pytorch_lightning/logging/comet_logger.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,126 @@
from logging import getLogger

try:
from comet_ml import Experiment as CometExperiment
from comet_ml import OfflineExperiment as CometOfflineExperiment
from comet_ml.papi import API
except ImportError:
raise ImportError('Missing comet_ml package.')

from torch import is_tensor

from .base import LightningLoggerBase, rank_zero_only
from ..utilities.debugging import MisconfigurationException

logger = getLogger(__name__)


class CometLogger(LightningLoggerBase):
def __init__(self, *args, **kwargs):
super(CometLogger, self).__init__()
self.experiment = CometExperiment(*args, **kwargs)
def __init__(self, api_key=None, save_dir=None, workspace=None,
             rest_api_key=None, project_name=None, experiment_name=None, **kwargs):
    """
    Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode)

    :param str api_key: Required in online mode. API key, found on Comet.ml
    :param str save_dir: Required in offline mode. The path for the directory to save local comet logs
    :param str workspace: Optional. Name of workspace for this user
    :param str project_name: Optional. Send your experiment to a specific project.
        Otherwise will be sent to Uncategorized Experiments.
        If project name does not already exist Comet.ml will create a new project.
    :param str rest_api_key: Optional. Rest API key found in Comet.ml settings.
        This is used to determine version number
    :param str experiment_name: Optional. String representing the name for this particular experiment on Comet.ml
    :param kwargs: Extra keyword arguments, forwarded unchanged to the Comet experiment
        constructor when the experiment is created
    :raises MisconfigurationException: If neither api_key nor save_dir is given
    """
    super().__init__()
    self._experiment = None

    if api_key is None and save_dir is None:
        # If neither api_key nor save_dir are passed as arguments, raise an exception
        raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.")

    # An API key always selects online mode, even when save_dir is also given.
    self.mode = "online" if api_key is not None else "offline"
    # Always define both attributes so that reading the one that is unused
    # in the current mode yields None instead of raising AttributeError.
    self.api_key = api_key
    self.save_dir = save_dir

    logger.info(f"CometLogger will be initialized in {self.mode} mode")

    self.workspace = workspace
    self.project_name = project_name
    self._kwargs = kwargs

    # Comet.ml rest API, used to determine version number
    self.rest_api_key = rest_api_key
    self.comet_api = API(rest_api_key) if rest_api_key is not None else None

    if experiment_name:
        # Assigning to ``name`` goes through the property setter, which calls
        # ``set_name`` on ``self.experiment`` and therefore creates the experiment now.
        try:
            self.name = experiment_name
        except TypeError:
            logger.exception("Failed to set experiment name for comet.ml logger")

@property
def experiment(self):
    """
    Return the Comet experiment object, creating it on first access.

    In "online" mode a ``CometExperiment`` is built from the stored API key;
    in any other mode a ``CometOfflineExperiment`` is built that writes to
    ``save_dir``. Workspace, project name and the extra keyword arguments
    captured at construction time are passed in both cases. The created
    object is cached in ``self._experiment`` and reused afterwards.
    """
    if self._experiment is None:
        if self.mode == "online":
            factory, primary = CometExperiment, {"api_key": self.api_key}
        else:
            factory, primary = CometOfflineExperiment, {"offline_directory": self.save_dir}

        self._experiment = factory(
            workspace=self.workspace,
            project_name=self.project_name,
            **primary,
            **self._kwargs
        )

    return self._experiment

@rank_zero_only
def log_hyperparams(self, params):
    """
    Record hyperparameters on the Comet experiment.

    :param params: Object whose attribute dictionary (as returned by ``vars``) is logged
    """
    record_parameters = self.experiment.log_parameters
    record_parameters(vars(params))

@rank_zero_only
def log_metrics(self, metrics, step_num):
# self.experiment.set_epoch(self, metrics.get('epoch', 0))
self.experiment.log_metrics(metrics)
def log_metrics(self, metrics, step_num=None):
    """
    Send a dictionary of metrics to the Comet experiment.

    :param dict metrics: Mapping of metric name to value. Tensor values are moved to
        CPU and detached before logging; all other values are passed through unchanged
    :param step_num: Optional. Forwarded to the experiment as ``step``
    """
    # Comet.ml expects metrics to be a dictionary of detached tensors on CPU.
    # Build a fresh dict rather than overwriting entries in place, so the
    # caller's ``metrics`` (and the tensors it still holds) stay untouched.
    prepared = {
        key: val.cpu().detach() if is_tensor(val) else val
        for key, val in metrics.items()
    }

    self.experiment.log_metrics(prepared, step=step_num)

@rank_zero_only
def finalize(self, status):
    """
    End the Comet experiment.

    :param status: Accepted for interface compatibility; not used by this method
    """
    experiment = self.experiment
    experiment.end()

@property
def name(self):
    """Return the ``project_name`` attribute of the underlying Comet experiment."""
    experiment = self.experiment
    return experiment.project_name

@name.setter
def name(self, value):
    """
    Forward ``value`` to the Comet experiment's ``set_name``.

    :param value: Name to pass to the experiment
    """
    experiment = self.experiment
    experiment.set_name(value)

@property
def version(self):
    """
    Return the version number for this run, or ``None`` when it cannot be determined.

    A version is only computed when both a project name and a REST API key are
    set: the number of experiments the REST API reports for the workspace and
    project is counted, and the next integer is returned.
    """
    if not (self.project_name and self.rest_api_key):
        return None

    existing = self.comet_api.get_experiments(self.workspace, self.project_name)
    return len(existing) + 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if returning None as a version is a good idea

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From what I can tell, logger.version is only being used in the tqdm progress bar. If version is None, it is ignored, otherwise the version is displayed in the progress bar.

If there are future plans to use this version elsewhere, I'm open to suggestions for a different default return value.

77 changes: 76 additions & 1 deletion tests/test_y_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,83 @@ def test_mlflow_pickle():
testing_utils.clear_save_dir()


def test_custom_logger(tmpdir):
def test_comet_logger():
    """
    verify that basic functionality of Comet.ml logger works
    """
    reset_seed()

    # The comet logger module raises a plain ImportError when comet_ml is not
    # installed, so catch ImportError here; ModuleNotFoundError is a subclass
    # of it and is therefore still covered.
    try:
        from pytorch_lightning.logging import CometLogger
    except ImportError:
        return

    hparams = testing_utils.get_hparams()
    model = LightningTestModel(hparams)

    root_dir = os.path.dirname(os.path.realpath(__file__))
    comet_dir = os.path.join(root_dir, "cometruns")

    # We test CometLogger in offline mode with local saves
    logger = CometLogger(
        save_dir=comet_dir,
        project_name="general",
        workspace="dummy-test",
    )

    trainer_options = dict(
        max_nb_epochs=1,
        train_percent_check=0.01,
        logger=logger
    )

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    print('result finished')
    assert result == 1, "Training failed"

    testing_utils.clear_save_dir()


def test_comet_pickle():
    """
    verify that pickling trainer with comet logger works
    """
    reset_seed()

    # The comet logger module raises a plain ImportError when comet_ml is not
    # installed, so catch ImportError here; ModuleNotFoundError is a subclass
    # of it and is therefore still covered.
    try:
        from pytorch_lightning.logging import CometLogger
    except ImportError:
        return

    hparams = testing_utils.get_hparams()
    model = LightningTestModel(hparams)

    root_dir = os.path.dirname(os.path.realpath(__file__))
    comet_dir = os.path.join(root_dir, "cometruns")

    # We test CometLogger in offline mode with local saves
    logger = CometLogger(
        save_dir=comet_dir,
        project_name="general",
        workspace="dummy-test",
    )

    trainer_options = dict(
        max_nb_epochs=1,
        logger=logger
    )

    # Round-trip the trainer through pickle and check that the restored
    # logger can still log a metric.
    trainer = Trainer(**trainer_options)
    pkl_bytes = pickle.dumps(trainer)
    trainer2 = pickle.loads(pkl_bytes)
    trainer2.logger.log_metrics({"acc": 1.0})

    testing_utils.clear_save_dir()


def test_custom_logger(tmpdir):
class CustomLogger(LightningLoggerBase):
def __init__(self):
super().__init__()
Expand Down