save_dir fix for MLflowLogger + save_dir tests for others (#2502)
* mlflow rework

* logger save_dir

* folder

* mlflow

* simplify

* fix test

* add a test for file dir contents

* new line

* changelog

* docs

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* test for comet logger

* improve mlflow checkpoint test

* prevent comet logger error on pytest exit

* test tensorboard save dir structure

* wandb save dir test

* skip test on windows

* add mlflow to pickle tests

* wandb

* code factor

* remove unused imports

* remove unused setter

* wandb mock

* wip mock

* wip mock

* wandb tests with mocking

* clean up

* clean up

* comments

* include wandblogger in test

* clean up

* missing argument

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
awaelchli and Borda authored Jul 9, 2020
1 parent 992a7e2 commit f16b4cf
Showing 13 changed files with 219 additions and 84 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518))

- Fixed a problem with `MLflowLogger` creating multiple run folders ([#2502](https://github.com/PyTorchLightning/pytorch-lightning/pull/2502))

## [0.8.4] - 2020-07-01

### Added
9 changes: 3 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -239,12 +239,9 @@ def on_train_start(self, trainer, pl_module):

if trainer.logger is not None:
# weights_save_path overrides anything
if getattr(trainer, 'weights_save_path', None) is not None:
save_dir = trainer.weights_save_path
else:
save_dir = (getattr(trainer.logger, 'save_dir', None)
or getattr(trainer.logger, '_save_dir', None)
or trainer.default_root_dir)
save_dir = (getattr(trainer, 'weights_save_path', None)
or getattr(trainer.logger, 'save_dir', None)
or trainer.default_root_dir)

version = trainer.logger.version if isinstance(
trainer.logger.version, str) else f'version_{trainer.logger.version}'
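The rewritten hunk collapses the old if/else into a single `or` chain: an explicit `weights_save_path` wins, then the logger's new `save_dir` property, then the trainer's `default_root_dir`. A minimal sketch of that precedence, using hypothetical stand-in objects rather than real Lightning classes:

```python
# Sketch of the save_dir precedence from the hunk above.
# FakeLogger/FakeTrainer are hypothetical stand-ins, not Lightning classes.
class FakeLogger:
    save_dir = None  # e.g. a logger that only writes to a remote service

class FakeTrainer:
    weights_save_path = None
    default_root_dir = './lightning_logs'
    logger = FakeLogger()

trainer = FakeTrainer()
save_dir = (getattr(trainer, 'weights_save_path', None)
            or getattr(trainer.logger, 'save_dir', None)
            or trainer.default_root_dir)
assert save_dir == './lightning_logs'  # both earlier options were None
```

Note that `or` treats any falsy value as absent, so an empty-string path would also fall through to the default.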
8 changes: 8 additions & 0 deletions pytorch_lightning/loggers/base.py
@@ -237,6 +237,14 @@ def close(self) -> None:
"""Do any cleanup that is necessary to close an experiment."""
self.save()

@property
def save_dir(self) -> Optional[str]:
"""
Return the root directory where experiment logs get saved, or `None` if the logger does not
save data locally.
"""
return None

@property
@abstractmethod
def name(self) -> str:
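Defining `save_dir` once on the base class lets callers such as `ModelCheckpoint` query any logger uniformly: loggers without local storage inherit the `None` default, and file-backed loggers override the property. A self-contained sketch of the pattern (toy classes, not the real `LightningLoggerBase`):

```python
from typing import Optional

class LoggerBase:
    """Toy stand-in for LightningLoggerBase, to illustrate the default."""

    @property
    def save_dir(self) -> Optional[str]:
        # default: this logger does not save anything locally
        return None

class RemoteOnlyLogger(LoggerBase):
    pass  # inherits save_dir -> None

class LocalFileLogger(LoggerBase):
    def __init__(self, save_dir: str = './my_logs'):
        self._save_dir = save_dir

    @property
    def save_dir(self) -> Optional[str]:
        return self._save_dir

assert RemoteOnlyLogger().save_dir is None
assert LocalFileLogger('./runs').save_dir == './runs'
```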
9 changes: 1 addition & 8 deletions pytorch_lightning/loggers/comet.py
@@ -134,10 +134,7 @@ def __init__(self,
self.comet_api = None

if experiment_name:
try:
self.name = experiment_name
except TypeError:
log.exception("Failed to set experiment name for comet.ml logger")
self.experiment.set_name(experiment_name)
self._kwargs = kwargs

@property
@@ -228,10 +225,6 @@ def save_dir(self) -> Optional[str]:
def name(self) -> str:
return str(self.experiment.project_name)

@name.setter
def name(self, value: str) -> None:
self.experiment.set_name(value)

@property
def version(self) -> str:
return self.experiment.id
81 changes: 54 additions & 27 deletions pytorch_lightning/loggers/mlflow.py
@@ -2,7 +2,6 @@
MLflow
------
"""
import os
from argparse import Namespace
from time import time
from typing import Optional, Dict, Any, Union
@@ -11,16 +10,20 @@
import mlflow
from mlflow.tracking import MlflowClient
_MLFLOW_AVAILABLE = True
except ImportError: # pragma: no-cover
except ModuleNotFoundError: # pragma: no-cover
mlflow = None
MlflowClient = None
_MLFLOW_AVAILABLE = False


from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only


LOCAL_FILE_URI_PREFIX = "file:"


class MLFlowLogger(LightningLoggerBase):
"""
Log using `MLflow <https://mlflow.org>`_. Install it with pip:
@@ -52,59 +55,71 @@ class MLFlowLogger(LightningLoggerBase):
Args:
experiment_name: The name of the experiment
tracking_uri: Address of local or remote tracking server.
If not provided, defaults to the service set by ``mlflow.tracking.set_tracking_uri``.
If not provided, defaults to `file:<save_dir>`.
tags: A dictionary tags for the experiment.
save_dir: A path to a local directory where the MLflow runs get saved.
Defaults to `./mlruns` if `tracking_uri` is not provided.
Has no effect if `tracking_uri` is provided.
"""

def __init__(self,
experiment_name: str = 'default',
tracking_uri: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
save_dir: Optional[str] = None):
save_dir: Optional[str] = './mlruns'):

if not _MLFLOW_AVAILABLE:
raise ImportError('You want to use `mlflow` logger which is not installed yet,'
' install it with `pip install mlflow`.')
super().__init__()
if not tracking_uri and save_dir:
tracking_uri = f'file:{os.sep * 2}{save_dir}'
self._mlflow_client = MlflowClient(tracking_uri)
self.experiment_name = experiment_name
if not tracking_uri:
tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}'

self._experiment_name = experiment_name
self._experiment_id = None
self._tracking_uri = tracking_uri
self._run_id = None
self.tags = tags
self._mlflow_client = MlflowClient(tracking_uri)
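With this default, constructing the logger without a `tracking_uri` points the `MlflowClient` at the local `save_dir` through a `file:` URI. A quick sketch of just that fallback, with `resolve_tracking_uri` as a hypothetical helper name (the PR inlines this logic in `__init__`):

```python
from typing import Optional

LOCAL_FILE_URI_PREFIX = "file:"

def resolve_tracking_uri(tracking_uri: Optional[str],
                         save_dir: str = './mlruns') -> str:
    # an explicit URI wins; otherwise log to the local save_dir
    return tracking_uri or f'{LOCAL_FILE_URI_PREFIX}{save_dir}'

assert resolve_tracking_uri(None) == 'file:./mlruns'
assert resolve_tracking_uri('http://my.mlflow.server') == 'http://my.mlflow.server'
```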

@property
@rank_zero_experiment
def experiment(self) -> MlflowClient:
r"""
Actual MLflow object. To use mlflow features in your
Actual MLflow object. To use MLflow features in your
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
Example::
self.logger.experiment.some_mlflow_function()
"""
return self._mlflow_client

@property
def run_id(self):
if self._run_id is not None:
return self._run_id

expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)
expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)

if expt:
self._expt_id = expt.experiment_id
self._experiment_id = expt.experiment_id
else:
log.warning(f'Experiment with name {self.experiment_name} not found. Creating it.')
self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.')
self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name)

run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
self._run_id = run.info.run_id
if not self._run_id:
run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)
self._run_id = run.info.run_id
return self._mlflow_client

@property
def run_id(self):
# create the experiment if it does not exist to get the run id
_ = self.experiment
return self._run_id

@property
def experiment_id(self):
# create the experiment if it does not exist to get the experiment id
_ = self.experiment
return self._experiment_id
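The `_ = self.experiment` trick makes both ids lazy: the experiment and run are created at most once, on first access, and the properties then return the cached values. A toy sketch of the idiom:

```python
class LazyRun:
    """Toy model of the lazy-creation idiom used by the properties above."""

    def __init__(self):
        self._run_id = None

    @property
    def experiment(self):
        # create expensive state only on first access
        if self._run_id is None:
            self._run_id = 'run-0'  # stands in for MlflowClient.create_run(...)
        return self  # stands in for the MlflowClient

    @property
    def run_id(self):
        _ = self.experiment  # force creation, then read the cached id
        return self._run_id

assert LazyRun().run_id == 'run-0'
```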

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
params = self._convert_params(params)
@@ -126,14 +141,26 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
@rank_zero_only
def finalize(self, status: str = 'FINISHED') -> None:
super().finalize(status)
if status == 'success':
status = 'FINISHED'
self.experiment.set_terminated(self.run_id, status)
status = 'FINISHED' if status == 'success' else status
if self.experiment.get_run(self.run_id):
self.experiment.set_terminated(self.run_id, status)

@property
def save_dir(self) -> Optional[str]:
"""
The root file directory in which MLflow experiments are saved.
Return:
Local path to the root experiment directory if the tracking URI is local.
Otherwise returns `None`.
"""
if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX):]

@property
def name(self) -> str:
return self.experiment_name
return self.experiment_id

@property
def version(self) -> str:
return self._run_id
return self.run_id
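`save_dir` recovers the local path by stripping the `file:` prefix from the tracking URI. A standalone sketch with a hypothetical helper name; the prefix is removed by slicing, since `str.lstrip('file:')` would strip any of the characters f/i/l/e/: from the left rather than the literal prefix:

```python
from typing import Optional

LOCAL_FILE_URI_PREFIX = "file:"

def save_dir_from_uri(tracking_uri: str) -> Optional[str]:
    """Hypothetical helper mirroring MLFlowLogger.save_dir above."""
    if tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
        # slice off the literal prefix
        return tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
    return None  # remote tracking server: nothing is saved locally

assert save_dir_from_uri('file:./mlruns') == './mlruns'
assert save_dir_from_uri('https://my.mlflow.server') is None
```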
6 changes: 1 addition & 5 deletions pytorch_lightning/loggers/tensorboard.py
@@ -51,7 +51,7 @@ def __init__(self,
**kwargs):
super().__init__()
self._save_dir = save_dir
self._name = name
self._name = name or ''
self._version = version

self._experiment = None
@@ -106,10 +106,6 @@ def experiment(self) -> SummaryWriter:
self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
return self._experiment

@experiment.setter
def experiment(self, exp):
self._experiment = exp

@rank_zero_only
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
metrics: Optional[Dict[str, Any]] = None) -> None:
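The `name or ''` change presumably guards the log-directory computation: TensorBoard's `log_dir` is assembled with `os.path.join(save_dir, name, version)`, which tolerates an empty string but raises on `None`. A short illustration:

```python
import os

save_dir, name, version = './lightning_logs', None, 'version_0'
print(os.path.join(save_dir, name or '', version))  # ./lightning_logs/version_0
# os.path.join(save_dir, name, version) would raise TypeError for name=None
```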
6 changes: 5 additions & 1 deletion pytorch_lightning/loggers/test_tube.py
@@ -68,7 +68,7 @@ def __init__(self,
raise ImportError('You want to use `test_tube` logger which is not installed yet,'
' install it with `pip install test-tube`.')
super().__init__()
self.save_dir = save_dir
self._save_dir = save_dir
self._name = name
self.description = description
self.debug = debug
@@ -141,6 +141,10 @@ def close(self) -> None:
exp = self.experiment
exp.close()

@property
def save_dir(self) -> Optional[str]:
return self._save_dir

@property
def name(self) -> str:
if self._experiment is None:
11 changes: 7 additions & 4 deletions pytorch_lightning/loggers/wandb.py
@@ -116,7 +116,7 @@ def experiment(self) -> Run:
group=self._group)
# save checkpoints in wandb dir to upload on W&B servers
if self._log_model:
self.save_dir = self._experiment.dir
self._save_dir = self._experiment.dir
return self._experiment

def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -134,13 +134,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)

@property
def save_dir(self) -> Optional[str]:
return self._save_dir

@property
def name(self) -> Optional[str]:
# don't create an experiment if we don't have one
name = self._experiment.project_name() if self._experiment else None
return name
return self._experiment.project_name() if self._experiment else self._name

@property
def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else None
return self._experiment.id if self._experiment else self._id
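Returning the stored `_name`/`_id` instead of `None` keeps `name` and `version` meaningful before any run exists, while still never starting a W&B run as a side effect of a metadata query (important on non-zero ranks and in tests). A toy sketch of the pattern:

```python
class LazyWandbLike:
    """Toy model of the lazy-attribute pattern in WandbLogger above."""

    def __init__(self, name=None, id=None):
        self._name, self._id = name, id
        self._experiment = None  # only created on explicit .experiment access

    @property
    def name(self):
        # never start a run just to answer a metadata query
        return self._experiment.project_name() if self._experiment else self._name

    @property
    def version(self):
        return self._experiment.id if self._experiment else self._id

logger = LazyWandbLike(name='my-project', id='abc123')
assert (logger.name, logger.version) == ('my-project', 'abc123')  # no run started
```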
43 changes: 28 additions & 15 deletions tests/loggers/test_all.py
@@ -1,6 +1,8 @@
import atexit
import inspect
import pickle
import platform
from unittest import mock

import pytest

@@ -35,14 +37,15 @@ def _get_logger_args(logger_class, save_dir):
MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
# WandbLogger, # TODO: add this one
WandbLogger,
])
def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class):
"""Verify that basic functionality of all loggers."""
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)

model = EvalModelTemplate()

@@ -58,6 +61,11 @@ def log_metrics(self, metrics, step):
logger_args = _get_logger_args(logger_class, tmpdir)
logger = StoreHistoryLogger(**logger_args)

if logger_class == WandbLogger:
# required mocks for Trainer
logger.experiment.id = 'foo'
logger.experiment.project_name.return_value = 'bar'

trainer = Trainer(
max_epochs=1,
logger=logger,
@@ -66,7 +74,6 @@ def log_metrics(self, metrics, step):
fast_dev_run=True,
)
trainer.fit(model)

trainer.test()

log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
@@ -78,17 +85,17 @@ def log_metrics(self, metrics, step):
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
# MLFlowLogger,
MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
# WandbLogger, # TODO: add this one
# The WandbLogger gets tested for pickling in its own test.
])
def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
"""Verify that pickling trainer with logger works."""
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
import atexit
monkeypatch.setattr(atexit, 'register', lambda _: None)
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)

logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)
@@ -156,12 +163,18 @@ def on_batch_start(self, trainer, pl_module):
@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
# MLFlowLogger,
MLFlowLogger,
NeptuneLogger,
TestTubeLogger,
WandbLogger,
])
def test_logger_created_on_rank_zero_only(tmpdir, logger_class):
def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
""" Test that loggers get replaced by dummy logges on global rank > 0"""
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)

logger_args = _get_logger_args(logger_class, tmpdir)
logger = logger_class(**logger_args)
model = EvalModelTemplate()
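Since the real `wandb` package would try to authenticate and open a run, the tests patch the module wholesale; the only values that need pinning are the ones the `Trainer` reads through the logger (`experiment.id` for `version`, `experiment.project_name()` for `name`). A condensed sketch of the same mocking approach, assuming `wandb` is installed so that `pytorch_lightning.loggers.wandb` imports cleanly:

```python
from unittest import mock

from pytorch_lightning.loggers import WandbLogger

# Patch wandb as imported inside pytorch_lightning.loggers.wandb, so no
# network access or login is required; wandb.init() returns a MagicMock.
with mock.patch('pytorch_lightning.loggers.wandb.wandb'):
    logger = WandbLogger()
    logger.experiment.id = 'foo'                         # read via logger.version
    logger.experiment.project_name.return_value = 'bar'  # read via logger.name
    assert logger.version == 'foo'
    assert logger.name == 'bar'
```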