diff --git a/.run_local_tests.sh b/.run_local_tests.sh index 57d40c26b0ee5..0b06d104f5f70 100644 --- a/.run_local_tests.sh +++ b/.run_local_tests.sh @@ -2,6 +2,7 @@ rm -rf _ckpt_* rm -rf tests/save_dir* rm -rf tests/mlruns_* +rm -rf tests/cometruns* rm -rf tests/tests/* rm -rf lightning_logs coverage run --source pytorch_lightning -m py.test pytorch_lightning tests pl_examples -v --doctest-modules diff --git a/docs/Trainer/Logging.md b/docs/Trainer/Logging.md index 5b44384711ea2..7dc0fde355e5b 100644 --- a/docs/Trainer/Logging.md +++ b/docs/Trainer/Logging.md @@ -87,12 +87,30 @@ def any_lightning_module_function_or_hook(...): Log using [comet](https://www.comet.ml) +Comet logger can be used in either online or offline mode. +To log in online mode, CometLogger requries an API key: ```{.python} from pytorch_lightning.logging import CometLogger # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( api_key=os.environ["COMET_KEY"], - workspace=os.environ["COMET_KEY"], + workspace=os.environ["COMET_WORKSPACE"], # Optional + project_name="default_project", # Optional + rest_api_key=os.environ["COMET_REST_KEY"], # Optional + experiment_name="default" # Optional +) +trainer = Trainer(logger=comet_logger) +``` +To log in offline mode, CometLogger requires a path to a local directory: +```{.python} +from pytorch_lightning.logging import CometLogger +# arguments made to CometLogger are passed on to the comet_ml.Experiment class +comet_logger = CometLogger( + save_dir=".", + workspace=os.environ["COMET_WORKSPACE"], # Optional + project_name="default_project", # Optional + rest_api_key=os.environ["COMET_REST_KEY"], # Optional + experiment_name="default" # Optional ) trainer = Trainer(logger=comet_logger) ``` diff --git a/pytorch_lightning/logging/comet_logger.py b/pytorch_lightning/logging/comet_logger.py index 5e1d281c8d9a2..21dad421b7520 100644 --- a/pytorch_lightning/logging/comet_logger.py +++ b/pytorch_lightning/logging/comet_logger.py @@ -1,25 +1,126 @@ +from logging import getLogger + try: from comet_ml import Experiment as CometExperiment + from comet_ml import OfflineExperiment as CometOfflineExperiment + from comet_ml.papi import API except ImportError: raise ImportError('Missing comet_ml package.') +from torch import is_tensor + from .base import LightningLoggerBase, rank_zero_only +from ..utilities.debugging import MisconfigurationException + +logger = getLogger(__name__) class CometLogger(LightningLoggerBase): - def __init__(self, *args, **kwargs): - super(CometLogger, self).__init__() - self.experiment = CometExperiment(*args, **kwargs) + def __init__(self, api_key=None, save_dir=None, workspace=None, + rest_api_key=None, project_name=None, experiment_name=None, **kwargs): + """ + Initialize a Comet.ml logger. Requires either an API Key (online mode) or a local directory path (offline mode) + + :param str api_key: Required in online mode. API key, found on Comet.ml + :param str save_dir: Required in offline mode. The path for the directory to save local comet logs + :param str workspace: Optional. Name of workspace for this user + :param str project_name: Optional. Send your experiment to a specific project. + Otherwise will be sent to Uncategorized Experiments. + If project name does not already exists Comet.ml will create a new project. + :param str rest_api_key: Optional. Rest API key found in Comet.ml settings. + This is used to determine version number + :param str experiment_name: Optional. String representing the name for this particular experiment on Comet.ml + """ + super().__init__() + self._experiment = None + + # Determine online or offline mode based on which arguments were passed to CometLogger + if save_dir is not None and api_key is not None: + # If arguments are passed for both save_dir and api_key, preference is given to online mode + self.mode = "online" + self.api_key = api_key + elif api_key is not None: + self.mode = "online" + self.api_key = api_key + elif save_dir is not None: + self.mode = "offline" + self.save_dir = save_dir + else: + # If neither api_key nor save_dir are passed as arguments, raise an exception + raise MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") + + logger.info(f"CometLogger will be initialized in {self.mode} mode") + + self.workspace = workspace + self.project_name = project_name + self._kwargs = kwargs + + if rest_api_key is not None: + # Comet.ml rest API, used to determine version number + self.rest_api_key = rest_api_key + self.comet_api = API(self.rest_api_key) + else: + self.rest_api_key = None + self.comet_api = None + + if experiment_name: + try: + self.name = experiment_name + except TypeError as e: + logger.exception("Failed to set experiment name for comet.ml logger") + + @property + def experiment(self): + if self._experiment is not None: + return self._experiment + + if self.mode == "online": + self._experiment = CometExperiment( + api_key=self.api_key, + workspace=self.workspace, + project_name=self.project_name, + **self._kwargs + ) + else: + self._experiment = CometOfflineExperiment( + offline_directory=self.save_dir, + workspace=self.workspace, + project_name=self.project_name, + **self._kwargs + ) + + return self._experiment @rank_zero_only def log_hyperparams(self, params): self.experiment.log_parameters(vars(params)) @rank_zero_only - def log_metrics(self, metrics, step_num): - # self.experiment.set_epoch(self, metrics.get('epoch', 0)) - self.experiment.log_metrics(metrics) + def log_metrics(self, metrics, step_num=None): + # Comet.ml expects metrics to be a dictionary of detached tensors on CPU + for key, val in metrics.items(): + if is_tensor(val): + metrics[key] = val.cpu().detach() + + self.experiment.log_metrics(metrics, step=step_num) @rank_zero_only def finalize(self, status): self.experiment.end() + + @property + def name(self): + return self.experiment.project_name + + @name.setter + def name(self, value): + self.experiment.set_name(value) + + @property + def version(self): + if self.project_name and self.rest_api_key: + # Determines the number of experiments in this project, and returns the next integer as the version number + nb_exps = len(self.comet_api.get_experiments(self.workspace, self.project_name)) + return nb_exps + 1 + else: + return None diff --git a/tests/test_y_logging.py b/tests/test_y_logging.py index c44a9e2c8b987..cfa979956c6b0 100644 --- a/tests/test_y_logging.py +++ b/tests/test_y_logging.py @@ -137,8 +137,83 @@ def test_mlflow_pickle(): testing_utils.clear_save_dir() -def test_custom_logger(tmpdir): +def test_comet_logger(): + """ + verify that basic functionality of Comet.ml logger works + """ + reset_seed() + + try: + from pytorch_lightning.logging import CometLogger + except ModuleNotFoundError: + return + + hparams = testing_utils.get_hparams() + model = LightningTestModel(hparams) + + root_dir = os.path.dirname(os.path.realpath(__file__)) + comet_dir = os.path.join(root_dir, "cometruns") + + # We test CometLogger in offline mode with local saves + logger = CometLogger( + save_dir=comet_dir, + project_name="general", + workspace="dummy-test", + ) + + trainer_options = dict( + max_nb_epochs=1, + train_percent_check=0.01, + logger=logger + ) + + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + print('result finished') + assert result == 1, "Training failed" + + testing_utils.clear_save_dir() + +def test_comet_pickle(): + """ + verify that pickling trainer with comet logger works + """ + reset_seed() + + try: + from pytorch_lightning.logging import CometLogger + except ModuleNotFoundError: + return + + hparams = testing_utils.get_hparams() + model = LightningTestModel(hparams) + + root_dir = os.path.dirname(os.path.realpath(__file__)) + comet_dir = os.path.join(root_dir, "cometruns") + + # We test CometLogger in offline mode with local saves + logger = CometLogger( + save_dir=comet_dir, + project_name="general", + workspace="dummy-test", + ) + + trainer_options = dict( + max_nb_epochs=1, + logger=logger + ) + + trainer = Trainer(**trainer_options) + pkl_bytes = pickle.dumps(trainer) + trainer2 = pickle.loads(pkl_bytes) + trainer2.logger.log_metrics({"acc": 1.0}) + + testing_utils.clear_save_dir() + + +def test_custom_logger(tmpdir): class CustomLogger(LightningLoggerBase): def __init__(self): super().__init__()