save_dir fix for MLflowLogger + save_dir tests for others #2502

Merged (37 commits, Jul 9, 2020)
Commits (37):
0c0d564  mlflow rework (Jul 4, 2020)
2d7223f  logger save_dir (Jul 4, 2020)
33a93c2  folder (Jul 4, 2020)
5699e82  mlflow (Jul 4, 2020)
75a50ff  simplify (Jul 4, 2020)
7c1bcc4  fix test (Jul 4, 2020)
4a5d81a  add a test for file dir contents (Jul 4, 2020)
0ae884d  new line (Jul 4, 2020)
4bcc117  changelog (Jul 4, 2020)
cbaeb02  docs (Jul 4, 2020)
a4279af  Update CHANGELOG.md (awaelchli, Jul 4, 2020)
c883d46  test for comet logger (awaelchli, Jul 5, 2020)
986b807  improve mlflow checkpoint test (awaelchli, Jul 5, 2020)
70d1ede  prevent commet logger error on pytest exit (awaelchli, Jul 5, 2020)
66b0182  test tensorboard save dir structure (awaelchli, Jul 5, 2020)
520bbd5  wandb save dir test (awaelchli, Jul 5, 2020)
1cc0afa  skip test on windows (awaelchli, Jul 5, 2020)
84f9cf4  Merge branch 'master' into bugfix/mlflow-fixes (awaelchli, Jul 5, 2020)
f6179b5  Merge branch 'master' into bugfix/mlflow-rework (Jul 6, 2020)
a57f7db  add mlflow to pickle tests (Jul 6, 2020)
c19d155  wandb (Jul 6, 2020)
af8e2ac  code factor (Jul 6, 2020)
61eaa71  remove unused imports (Jul 6, 2020)
d93d6c9  Merge branch 'master' into bugfix/mlflow-rework (Jul 7, 2020)
c98df78  remove unused setter (Jul 7, 2020)
4ed9f0d  wandb mock (awaelchli, Jul 7, 2020)
8d2e46f  Merge remote-tracking branch 'PyTorchLightning/bugfix/mlflow-fixes' i… (awaelchli, Jul 7, 2020)
ab8c5e2  Merge remote-tracking branch 'original/bugfix/mlflow-fixes' into bugf… (Jul 7, 2020)
ae963ff  wip mock (Jul 8, 2020)
522838f  wip mock (Jul 8, 2020)
7039384  wandb tests with mocking (Jul 8, 2020)
66d2600  clean up (Jul 8, 2020)
e0794b0  clean up (Jul 8, 2020)
fcc7f08  comments (Jul 8, 2020)
8773634  include wandblogger in test (Jul 8, 2020)
000f531  clean up (Jul 8, 2020)
d375056  missing argument (awaelchli, Jul 8, 2020)
2 changes: 2 additions & 0 deletions CHANGELOG.md

@@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518))

+- Fixed a problem with `MLflowLogger` creating multiple run folders ([#2502](https://github.com/PyTorchLightning/pytorch-lightning/pull/2502))
+
 ## [0.8.4] - 2020-07-01

 ### Added
9 changes: 3 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py

@@ -239,12 +239,9 @@ def on_train_start(self, trainer, pl_module):

         if trainer.logger is not None:
             # weights_save_path overrides anything
-            if getattr(trainer, 'weights_save_path', None) is not None:
-                save_dir = trainer.weights_save_path
-            else:
-                save_dir = (getattr(trainer.logger, 'save_dir', None)
-                            or getattr(trainer.logger, '_save_dir', None)
-                            or trainer.default_root_dir)
+            save_dir = (getattr(trainer, 'weights_save_path', None)
+                        or getattr(trainer.logger, 'save_dir', None)
+                        or trainer.default_root_dir)

             version = trainer.logger.version if isinstance(
                 trainer.logger.version, str) else f'version_{trainer.logger.version}'
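The nested if/else collapses into a single `or` chain, so the checkpoint root now resolves by one precedence rule: an explicit `weights_save_path` wins, then the logger's public `save_dir` (which the base class now always defines), then the trainer's `default_root_dir`. A minimal sketch of that resolution, using hypothetical stand-in objects rather than the real Trainer and logger classes:

    # Hypothetical stand-ins, only to illustrate the precedence above.
    class StubLogger:
        save_dir = './mlruns'            # what e.g. MLFlowLogger would report

    class StubTrainer:
        weights_save_path = None         # unset, so the logger's dir wins
        logger = StubLogger()
        default_root_dir = './lightning_logs'

    trainer = StubTrainer()
    save_dir = (getattr(trainer, 'weights_save_path', None)
                or getattr(trainer.logger, 'save_dir', None)
                or trainer.default_root_dir)
    print(save_dir)  # -> './mlruns'; setting weights_save_path would override it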
8 changes: 8 additions & 0 deletions pytorch_lightning/loggers/base.py

@@ -237,6 +237,14 @@ def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
         self.save()

+    @property
+    def save_dir(self) -> Optional[str]:
+        """
+        Return the root directory where experiment logs get saved, or `None` if the logger does not
+        save data locally.
+        """
+        return None
+
     @property
     @abstractmethod
     def name(self) -> str:
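With the default defined on `LightningLoggerBase`, a custom logger only needs to override `save_dir` when it actually writes to the local filesystem. A sketch of such an override; `DiskLogger` and its `root` argument are invented for illustration, and the remaining members are just the minimal stubs the abstract base requires:

    from argparse import Namespace
    from typing import Any, Dict, Optional, Union

    from pytorch_lightning.loggers.base import LightningLoggerBase
    from pytorch_lightning.utilities import rank_zero_only


    class DiskLogger(LightningLoggerBase):
        """Hypothetical logger that writes everything below a local root dir."""

        def __init__(self, root: str):
            super().__init__()
            self._root = root

        @property
        def save_dir(self) -> Optional[str]:
            # override the new base-class default (None): this logger saves locally
            return self._root

        @property
        def experiment(self):
            return None  # stub: no underlying experiment object

        @rank_zero_only
        def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
            pass  # stub

        @rank_zero_only
        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            pass  # stub

        @property
        def name(self) -> str:
            return 'disk_logger'

        @property
        def version(self) -> str:
            return '0'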
9 changes: 1 addition & 8 deletions pytorch_lightning/loggers/comet.py

@@ -134,10 +134,7 @@ def __init__(self,
         self.comet_api = None

         if experiment_name:
-            try:
-                self.name = experiment_name
-            except TypeError:
-                log.exception("Failed to set experiment name for comet.ml logger")
+            self.experiment.set_name(experiment_name)
         self._kwargs = kwargs

     @property
@@ -228,10 +225,6 @@ def save_dir(self) -> Optional[str]:
     def name(self) -> str:
         return str(self.experiment.project_name)

-    @name.setter
-    def name(self, value: str) -> None:
-        self.experiment.set_name(value)
-
     @property
     def version(self) -> str:
         return self.experiment.id
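With the setter and its try/except gone, the run name is applied once, directly on the underlying Comet experiment. A small usage sketch, assuming offline mode so no API key is needed (the directory and run name are examples):

    from pytorch_lightning.loggers import CometLogger

    # offline mode: experiment data is written to save_dir instead of comet.ml
    logger = CometLogger(save_dir='./comet-logs', experiment_name='my-run')
    # __init__ called logger.experiment.set_name('my-run') directly;
    # `name` is read-only now and reports the Comet project name
    print(logger.name)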
81 changes: 54 additions & 27 deletions pytorch_lightning/loggers/mlflow.py

@@ -2,7 +2,6 @@
 MLflow
 ------
 """
-import os
 from argparse import Namespace
 from time import time
 from typing import Optional, Dict, Any, Union
@@ -11,16 +10,20 @@
     import mlflow
     from mlflow.tracking import MlflowClient
     _MLFLOW_AVAILABLE = True
-except ImportError:  # pragma: no-cover
+except ModuleNotFoundError:  # pragma: no-cover
     mlflow = None
     MlflowClient = None
     _MLFLOW_AVAILABLE = False


 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only


+LOCAL_FILE_URI_PREFIX = "file:"
+
+
 class MLFlowLogger(LightningLoggerBase):
     """
     Log using `MLflow <https://mlflow.org>`_. Install it with pip:
@@ -52,59 +55,71 @@ class MLFlowLogger(LightningLoggerBase):
     Args:
         experiment_name: The name of the experiment
         tracking_uri: Address of local or remote tracking server.
-            If not provided, defaults to the service set by ``mlflow.tracking.set_tracking_uri``.
+            If not provided, defaults to `file:<save_dir>`.
         tags: A dictionary of tags for the experiment.
+        save_dir: A path to a local directory where the MLflow runs get saved.
+            Defaults to `./mlruns` if `tracking_uri` is not provided.
+            Has no effect if `tracking_uri` is provided.

     """

     def __init__(self,
                  experiment_name: str = 'default',
                  tracking_uri: Optional[str] = None,
                  tags: Optional[Dict[str, Any]] = None,
-                 save_dir: Optional[str] = None):
+                 save_dir: Optional[str] = './mlruns'):

         if not _MLFLOW_AVAILABLE:
             raise ImportError('You want to use `mlflow` logger which is not installed yet,'
                               ' install it with `pip install mlflow`.')
         super().__init__()
-        if not tracking_uri and save_dir:
-            tracking_uri = f'file:{os.sep * 2}{save_dir}'
-        self._mlflow_client = MlflowClient(tracking_uri)
-        self.experiment_name = experiment_name
+        if not tracking_uri:
+            tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}'
+
+        self._experiment_name = experiment_name
+        self._experiment_id = None
+        self._tracking_uri = tracking_uri
         self._run_id = None
         self.tags = tags
+        self._mlflow_client = MlflowClient(tracking_uri)

     @property
+    @rank_zero_experiment
     def experiment(self) -> MlflowClient:
         r"""
-        Actual MLflow object. To use mlflow features in your
+        Actual MLflow object. To use MLflow features in your
         :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

         Example::

             self.logger.experiment.some_mlflow_function()

         """
-        return self._mlflow_client
-
-    @property
-    def run_id(self):
-        if self._run_id is not None:
-            return self._run_id
-
-        expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)
+        expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)

         if expt:
-            self._expt_id = expt.experiment_id
+            self._experiment_id = expt.experiment_id
         else:
-            log.warning(f'Experiment with name {self.experiment_name} not found. Creating it.')
-            self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
+            log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.')
+            self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name)

-        run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
-        self._run_id = run.info.run_id
+        if not self._run_id:
+            run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)
+            self._run_id = run.info.run_id
+        return self._mlflow_client
+
+    @property
+    def run_id(self):
+        # create the experiment if it does not exist to get the run id
+        _ = self.experiment
+        return self._run_id
+
+    @property
+    def experiment_id(self):
+        # create the experiment if it does not exist to get the experiment id
+        _ = self.experiment
+        return self._experiment_id

     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = self._convert_params(params)
@@ -126,14 +141,26 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     @rank_zero_only
     def finalize(self, status: str = 'FINISHED') -> None:
         super().finalize(status)
-        if status == 'success':
-            status = 'FINISHED'
-        self.experiment.set_terminated(self.run_id, status)
+        status = 'FINISHED' if status == 'success' else status
+        if self.experiment.get_run(self.run_id):
+            self.experiment.set_terminated(self.run_id, status)

+    @property
+    def save_dir(self) -> Optional[str]:
+        """
+        The root file directory in which MLflow experiments are saved.
+
+        Return:
+            Local path to the root experiment directory if the tracking uri is local.
+            Otherwise returns `None`.
+        """
+        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
+            return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)
+
     @property
     def name(self) -> str:
-        return self.experiment_name
+        return self.experiment_id

     @property
     def version(self) -> str:
-        return self._run_id
+        return self.run_id
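A short usage sketch of the reworked logger, assuming `mlflow` is installed (paths and printed ids are illustrative). One caveat: `str.lstrip` strips a character set, not a literal prefix, so `self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)` happens to work for `file:./mlruns` but would also eat leading path characters drawn from `f`, `i`, `l`, `e`, `:`; a slice such as `self._tracking_uri[len(LOCAL_FILE_URI_PREFIX):]` would be the stricter variant.

    from pytorch_lightning.loggers import MLFlowLogger

    logger = MLFlowLogger(experiment_name='demo', save_dir='./mlruns')

    # the experiment and run are now created lazily, on first access,
    # and cached so that repeated accesses reuse the same single run
    _ = logger.experiment       # creates experiment 'demo' and exactly one run
    print(logger.save_dir)      # -> './mlruns' (tracking uri is 'file:./mlruns')
    print(logger.name)          # -> the experiment id, e.g. '1'
    print(logger.version)       # -> the id of that single run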
6 changes: 1 addition & 5 deletions pytorch_lightning/loggers/tensorboard.py

@@ -51,7 +51,7 @@ def __init__(self,
                  **kwargs):
         super().__init__()
         self._save_dir = save_dir
-        self._name = name
+        self._name = name or ''
         self._version = version

         self._experiment = None
@@ -106,10 +106,6 @@ def experiment(self) -> SummaryWriter:
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment

-    @experiment.setter
-    def experiment(self, exp):
-        self._experiment = exp
-
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
                         metrics: Optional[Dict[str, Any]] = None) -> None:
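Coercing a missing `name` to `''` keeps the log-dir computation working when no name is given; `os.path.join` treats the empty string as a no-op segment, whereas `None` raises. A quick sketch using `os.path.join` as a stand-in for the logger's internal path logic:

    import os

    name = None  # e.g. TensorBoardLogger(save_dir='logs', name=None)
    log_dir = os.path.join('logs', name or '', 'version_0')
    print(log_dir)  # -> 'logs/version_0'; joining a plain None would TypeError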
6 changes: 5 additions & 1 deletion pytorch_lightning/loggers/test_tube.py

@@ -68,7 +68,7 @@ def __init__(self,
             raise ImportError('You want to use `test_tube` logger which is not installed yet,'
                               ' install it with `pip install test-tube`.')
         super().__init__()
-        self.save_dir = save_dir
+        self._save_dir = save_dir
         self._name = name
         self.description = description
         self.debug = debug
@@ -141,6 +141,10 @@ def close(self) -> None:
         exp = self.experiment
         exp.close()

+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def name(self) -> str:
         if self._experiment is None:
11 changes: 7 additions & 4 deletions pytorch_lightning/loggers/wandb.py

@@ -116,7 +116,7 @@ def experiment(self) -> Run:
                 group=self._group)
         # save checkpoints in wandb dir to upload on W&B servers
         if self._log_model:
-            self.save_dir = self._experiment.dir
+            self._save_dir = self._experiment.dir
         return self._experiment

     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -133,13 +133,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

         self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)

+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def name(self) -> Optional[str]:
         # don't create an experiment if we don't have one
-        name = self._experiment.project_name() if self._experiment else None
-        return name
+        return self._experiment.project_name() if self._experiment else self._name

     @property
     def version(self) -> Optional[str]:
         # don't create an experiment if we don't have one
-        return self._experiment.id if self._experiment else None
+        return self._experiment.id if self._experiment else self._id
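With these fallbacks, `name` and `version` no longer answer `None` before a run exists; they report the constructor arguments instead, without forcing a `wandb.init`. A sketch, assuming `wandb` is installed (`offline=True` avoids network access; the name and id are examples):

    from pytorch_lightning.loggers import WandbLogger

    logger = WandbLogger(name='my-run', id='abc123', offline=True)
    # no experiment has been created yet, so the fallbacks kick in
    print(logger.name, logger.version)  # -> 'my-run' 'abc123'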
43 changes: 28 additions & 15 deletions tests/loggers/test_all.py

@@ -1,6 +1,8 @@
+import atexit
 import inspect
 import pickle
 import platform
+from unittest import mock

 import pytest

@@ -35,14 +37,15 @@ def _get_logger_args(logger_class, save_dir):
     MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
-    # WandbLogger,  # TODO: add this one
+    WandbLogger,
 ])
-def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class):
     """Verify the basic functionality of all loggers."""
-    # prevent comet logger from trying to print at exit, since
-    # pytest's stdout/stderr redirection breaks it
-    import atexit
-    monkeypatch.setattr(atexit, 'register', lambda _: None)
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)

     model = EvalModelTemplate()

@@ -58,6 +61,11 @@ def log_metrics(self, metrics, step):
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = StoreHistoryLogger(**logger_args)

+    if logger_class == WandbLogger:
+        # required mocks for Trainer
+        logger.experiment.id = 'foo'
+        logger.experiment.project_name.return_value = 'bar'
+
     trainer = Trainer(
         max_epochs=1,
         logger=logger,
@@ -66,7 +74,6 @@ def log_metrics(self, metrics, step):
         fast_dev_run=True,
     )
     trainer.fit(model)
-
     trainer.test()

     log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
@@ -78,17 +85,17 @@
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
-    # MLFlowLogger,
+    MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
-    # WandbLogger,  # TODO: add this one
+    # The WandbLogger gets tested for pickling in its own test.
 ])
 def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
     """Verify that pickling trainer with logger works."""
-    # prevent comet logger from trying to print at exit, since
-    # pytest's stdout/stderr redirection breaks it
-    import atexit
-    monkeypatch.setattr(atexit, 'register', lambda _: None)
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)

     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
@@ -156,12 +163,18 @@ def on_batch_start(self, trainer, pl_module):
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
-    # MLFlowLogger,
+    MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
     WandbLogger,
 ])
-def test_logger_created_on_rank_zero_only(tmpdir, logger_class):
+def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
     """ Test that loggers get replaced by dummy loggers on global rank > 0"""
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
+
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
     model = EvalModelTemplate()
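The new tests lean on `mock.patch('pytorch_lightning.loggers.wandb.wandb')`, which swaps the `wandb` module imported by the logger for a `MagicMock` for the duration of a test, so `wandb.init` never touches the network. A stripped-down sketch of the pattern (the test name and the mocked attribute values are illustrative):

    from unittest import mock

    from pytorch_lightning.loggers import WandbLogger


    @mock.patch('pytorch_lightning.loggers.wandb.wandb')
    def test_wandb_logger_is_mocked(wandb):
        logger = WandbLogger(name='demo')
        # wandb.init returns a MagicMock, so any attribute the Trainer
        # reads (id, project_name, ...) must be stubbed out by hand
        logger.experiment.id = 'foo'
        logger.experiment.project_name.return_value = 'bar'
        assert logger.version == 'foo'
        assert logger.name == 'bar'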