From f16b4cfc522db579b053e902add07aba59ce0ac5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 9 Jul 2020 13:15:41 +0200
Subject: [PATCH] save_dir fix for MLflowLogger + save_dir tests for others (#2502)

* mlflow rework
* logger save_dir
* folder
* mlflow
* simplify
* fix test
* add a test for file dir contents
* new line
* changelog
* docs
* Update CHANGELOG.md

Co-authored-by: Jirka Borovec

* test for comet logger
* improve mlflow checkpoint test
* prevent commet logger error on pytest exit
* test tensorboard save dir structure
* wandb save dir test
* skip test on windows
* add mlflow to pickle tests
* wandb
* code factor
* remove unused imports
* remove unused setter
* wandb mock
* wip mock
* wip mock
* wandb tests with mocking
* clean up
* clean up
* comments
* include wandblogger in test
* clean up
* missing argument

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                             |  2 +
 .../callbacks/model_checkpoint.py        |  9 +--
 pytorch_lightning/loggers/base.py        |  8 ++
 pytorch_lightning/loggers/comet.py       |  9 +--
 pytorch_lightning/loggers/mlflow.py      | 81 ++++++++++++-------
 pytorch_lightning/loggers/tensorboard.py |  6 +-
 pytorch_lightning/loggers/test_tube.py   |  6 +-
 pytorch_lightning/loggers/wandb.py       | 11 ++-
 tests/loggers/test_all.py                | 43 ++++++----
 tests/loggers/test_comet.py              | 27 +++++++
 tests/loggers/test_mlflow.py             | 36 ++++++++-
 tests/loggers/test_tensorboard.py        | 30 ++++---
 tests/loggers/test_wandb.py              | 35 +++++++-
 13 files changed, 219 insertions(+), 84 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 61f44603c7de0..c7b30e900f343 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Made `TensorBoardLogger` and `CometLogger` pickleable ([#2518](https://github.com/PyTorchLightning/pytorch-lightning/pull/2518))
 
+- Fixed a problem with `MLflowLogger` creating multiple run folders ([#2502](https://github.com/PyTorchLightning/pytorch-lightning/pull/2502))
+
 ## [0.8.4] - 2020-07-01
 
 ### Added
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index a3e3174fadeb0..f70d8d8d0a5e1 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -239,12 +239,9 @@ def on_train_start(self, trainer, pl_module):
 
         if trainer.logger is not None:
             # weights_save_path overrides anything
-            if getattr(trainer, 'weights_save_path', None) is not None:
-                save_dir = trainer.weights_save_path
-            else:
-                save_dir = (getattr(trainer.logger, 'save_dir', None)
-                            or getattr(trainer.logger, '_save_dir', None)
-                            or trainer.default_root_dir)
+            save_dir = (getattr(trainer, 'weights_save_path', None)
+                        or getattr(trainer.logger, 'save_dir', None)
+                        or trainer.default_root_dir)
 
             version = trainer.logger.version if isinstance(
                 trainer.logger.version, str) else f'version_{trainer.logger.version}'
diff --git a/pytorch_lightning/loggers/base.py b/pytorch_lightning/loggers/base.py
index 17b08c7df915d..c88cbce27139e 100644
--- a/pytorch_lightning/loggers/base.py
+++ b/pytorch_lightning/loggers/base.py
@@ -237,6 +237,14 @@ def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
         self.save()
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        """
+        Return the root directory where experiment logs get saved, or `None` if the logger does not
+        save data locally.
+        """
+        return None
+
     @property
     @abstractmethod
     def name(self) -> str:
diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py
index 5c5cc29fa7778..a328d4e3a0d53 100644
--- a/pytorch_lightning/loggers/comet.py
+++ b/pytorch_lightning/loggers/comet.py
@@ -134,10 +134,7 @@ def __init__(self,
         self.comet_api = None
 
         if experiment_name:
-            try:
-                self.name = experiment_name
-            except TypeError:
-                log.exception("Failed to set experiment name for comet.ml logger")
+            self.experiment.set_name(experiment_name)
         self._kwargs = kwargs
 
     @property
@@ -228,10 +225,6 @@ def save_dir(self) -> Optional[str]:
     def name(self) -> str:
         return str(self.experiment.project_name)
 
-    @name.setter
-    def name(self, value: str) -> None:
-        self.experiment.set_name(value)
-
     @property
     def version(self) -> str:
         return self.experiment.id
diff --git a/pytorch_lightning/loggers/mlflow.py b/pytorch_lightning/loggers/mlflow.py
index f61700da38614..0465e58278d2e 100644
--- a/pytorch_lightning/loggers/mlflow.py
+++ b/pytorch_lightning/loggers/mlflow.py
@@ -2,7 +2,6 @@
 MLflow
 ------
 """
-import os
 from argparse import Namespace
 from time import time
 from typing import Optional, Dict, Any, Union
@@ -11,16 +10,20 @@
     import mlflow
     from mlflow.tracking import MlflowClient
     _MLFLOW_AVAILABLE = True
-except ImportError:  # pragma: no-cover
+except ModuleNotFoundError:  # pragma: no-cover
     mlflow = None
     MlflowClient = None
     _MLFLOW_AVAILABLE = False
+
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
 
 
+LOCAL_FILE_URI_PREFIX = "file:"
+
+
 class MLFlowLogger(LightningLoggerBase):
     """
     Log using `MLflow <https://mlflow.org>`_. Install it with pip:
@@ -52,8 +55,11 @@ class MLFlowLogger(LightningLoggerBase):
     Args:
         experiment_name: The name of the experiment
        tracking_uri: Address of local or remote tracking server.
-            If not provided, defaults to the service set by ``mlflow.tracking.set_tracking_uri``.
+            If not provided, defaults to `file:<save_dir>`.
         tags: A dictionary tags for the experiment.
+        save_dir: A path to a local directory where the MLflow runs get saved.
+            Defaults to `./mlflow` if `tracking_uri` is not provided.
+            Has no effect if `tracking_uri` is provided.
 
     """
 
@@ -61,24 +67,27 @@ def __init__(self,
                  experiment_name: str = 'default',
                  tracking_uri: Optional[str] = None,
                  tags: Optional[Dict[str, Any]] = None,
-                 save_dir: Optional[str] = None):
+                 save_dir: Optional[str] = './mlruns'):
         if not _MLFLOW_AVAILABLE:
             raise ImportError('You want to use `mlflow` logger which is not installed yet,'
                               ' install it with `pip install mlflow`.')
         super().__init__()
-        if not tracking_uri and save_dir:
-            tracking_uri = f'file:{os.sep * 2}{save_dir}'
-        self._mlflow_client = MlflowClient(tracking_uri)
-        self.experiment_name = experiment_name
+        if not tracking_uri:
+            tracking_uri = f'{LOCAL_FILE_URI_PREFIX}{save_dir}'
+
+        self._experiment_name = experiment_name
+        self._experiment_id = None
+        self._tracking_uri = tracking_uri
         self._run_id = None
         self.tags = tags
+        self._mlflow_client = MlflowClient(tracking_uri)
 
     @property
     @rank_zero_experiment
     def experiment(self) -> MlflowClient:
         r"""
         Actual MLflow object. To use MLflow features in your
         :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.
 
         Example::
 
@@ -86,25 +95,31 @@ def experiment(self) -> MlflowClient:
             self.logger.experiment.some_mlflow_function()
 
         """
-        return self._mlflow_client
-
-    @property
-    def run_id(self):
-        if self._run_id is not None:
-            return self._run_id
-
-        expt = self._mlflow_client.get_experiment_by_name(self.experiment_name)
+        expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
 
         if expt:
-            self._expt_id = expt.experiment_id
+            self._experiment_id = expt.experiment_id
         else:
-            log.warning(f'Experiment with name {self.experiment_name} not found. Creating it.')
-            self._expt_id = self._mlflow_client.create_experiment(name=self.experiment_name)
+            log.warning(f'Experiment with name {self._experiment_name} not found. Creating it.')
+            self._experiment_id = self._mlflow_client.create_experiment(name=self._experiment_name)
 
-        run = self._mlflow_client.create_run(experiment_id=self._expt_id, tags=self.tags)
-        self._run_id = run.info.run_id
+        if not self._run_id:
+            run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=self.tags)
+            self._run_id = run.info.run_id
+        return self._mlflow_client
+
+    @property
+    def run_id(self):
+        # create the experiment if it does not exist to get the run id
+        _ = self.experiment
         return self._run_id
 
+    @property
+    def experiment_id(self):
+        # create the experiment if it does not exist to get the experiment id
+        _ = self.experiment
+        return self._experiment_id
+
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = self._convert_params(params)
@@ -126,14 +141,26 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
     @rank_zero_only
     def finalize(self, status: str = 'FINISHED') -> None:
         super().finalize(status)
-        if status == 'success':
-            status = 'FINISHED'
-        self.experiment.set_terminated(self.run_id, status)
+        status = 'FINISHED' if status == 'success' else status
+        if self.experiment.get_run(self.run_id):
+            self.experiment.set_terminated(self.run_id, status)
+
+    @property
+    def save_dir(self) -> Optional[str]:
+        """
+        The root file directory in which MLflow experiments are saved.
+
+        Return:
+            Local path to the root experiment directory if the tracking uri is local.
+            Otherwhise returns `None`.
+        """
+        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
+            return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)
 
     @property
     def name(self) -> str:
-        return self.experiment_name
+        return self.experiment_id
 
     @property
     def version(self) -> str:
-        return self._run_id
+        return self.run_id
diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py
index 29cdd49c0efbe..23394b93f7882 100644
--- a/pytorch_lightning/loggers/tensorboard.py
+++ b/pytorch_lightning/loggers/tensorboard.py
@@ -51,7 +51,7 @@ def __init__(self,
                  **kwargs):
         super().__init__()
         self._save_dir = save_dir
-        self._name = name
+        self._name = name or ''
         self._version = version
 
         self._experiment = None
@@ -106,10 +106,6 @@ def experiment(self) -> SummaryWriter:
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
-    @experiment.setter
-    def experiment(self, exp):
-        self._experiment = exp
-
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
                         metrics: Optional[Dict[str, Any]] = None) -> None:
diff --git a/pytorch_lightning/loggers/test_tube.py b/pytorch_lightning/loggers/test_tube.py
index 38e3d091f7e6a..292b5cd1a1ee6 100644
--- a/pytorch_lightning/loggers/test_tube.py
+++ b/pytorch_lightning/loggers/test_tube.py
@@ -68,7 +68,7 @@ def __init__(self,
             raise ImportError('You want to use `test_tube` logger which is not installed yet,'
                               ' install it with `pip install test-tube`.')
         super().__init__()
-        self.save_dir = save_dir
+        self._save_dir = save_dir
         self._name = name
         self.description = description
         self.debug = debug
@@ -141,6 +141,10 @@ def close(self) -> None:
         exp = self.experiment
         exp.close()
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def name(self) -> str:
         if self._experiment is None:
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index d0a139266ca06..ad0b9c587101f 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -116,7 +116,7 @@ def experiment(self) -> Run:
                 group=self._group)
             # save checkpoints in wandb dir to upload on W&B servers
             if self._log_model:
-                self.save_dir = self._experiment.dir
+                self._save_dir = self._experiment.dir
         return self._experiment
 
     def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
@@ -134,13 +134,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
 
         self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
 
+    @property
+    def save_dir(self) -> Optional[str]:
+        return self._save_dir
+
     @property
     def name(self) -> Optional[str]:
         # don't create an experiment if we don't have one
-        name = self._experiment.project_name() if self._experiment else None
-        return name
+        return self._experiment.project_name() if self._experiment else self._name
 
     @property
     def version(self) -> Optional[str]:
         # don't create an experiment if we don't have one
-        return self._experiment.id if self._experiment else None
+        return self._experiment.id if self._experiment else self._id
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index f74e815086d6f..b64119078c6dd 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -1,6 +1,8 @@
+import atexit
 import inspect
 import pickle
 import platform
+from unittest import mock
 
 import pytest
 
@@ -35,14 +37,15 @@ def _get_logger_args(logger_class, save_dir):
     MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
-    # WandbLogger,  # TODO: add this one
+    WandbLogger,
 ])
-def test_loggers_fit_test(tmpdir, monkeypatch, logger_class):
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_loggers_fit_test(wandb, tmpdir, monkeypatch, logger_class):
     """Verify that basic functionality of all loggers."""
-    # prevent comet logger from trying to print at exit, since
-    # pytest's stdout/stderr redirection breaks it
-    import atexit
-    monkeypatch.setattr(atexit, 'register', lambda _: None)
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
 
     model = EvalModelTemplate()
 
@@ -58,6 +61,11 @@ def log_metrics(self, metrics, step):
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = StoreHistoryLogger(**logger_args)
 
+    if logger_class == WandbLogger:
+        # required mocks for Trainer
+        logger.experiment.id = 'foo'
+        logger.experiment.project_name.return_value = 'bar'
+
     trainer = Trainer(
         max_epochs=1,
         logger=logger,
@@ -66,7 +74,6 @@ def log_metrics(self, metrics, step):
         fast_dev_run=True,
     )
     trainer.fit(model)
-    trainer.test()
 
     log_metric_names = [(s, sorted(m.keys())) for s, m in logger.history]
 
@@ -78,17 +85,17 @@ def log_metrics(self, metrics, step):
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
-    # MLFlowLogger,
+    MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
-    # WandbLogger,  # TODO: add this one
+    # The WandbLogger gets tested for pickling in its own test.
 ])
 def test_loggers_pickle(tmpdir, monkeypatch, logger_class):
     """Verify that pickling trainer with logger works."""
-    # prevent comet logger from trying to print at exit, since
-    # pytest's stdout/stderr redirection breaks it
-    import atexit
-    monkeypatch.setattr(atexit, 'register', lambda _: None)
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
 
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
@@ -156,12 +163,18 @@ def on_batch_start(self, trainer, pl_module):
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
-    # MLFlowLogger,
+    MLFlowLogger,
     NeptuneLogger,
     TestTubeLogger,
     WandbLogger,
 ])
-def test_logger_created_on_rank_zero_only(tmpdir, logger_class):
+def test_logger_created_on_rank_zero_only(tmpdir, monkeypatch, logger_class):
+    """ Test that loggers get replaced by dummy logges on global rank > 0"""
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
+
     logger_args = _get_logger_args(logger_class, tmpdir)
     logger = logger_class(**logger_args)
     model = EvalModelTemplate()
diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py
index aeab10cd0fbd9..a89840163fe7a 100644
--- a/tests/loggers/test_comet.py
+++ b/tests/loggers/test_comet.py
@@ -1,9 +1,12 @@
+import os
 from unittest.mock import patch
 
 import pytest
 
+from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import CometLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.base import EvalModelTemplate
 
 
 def test_comet_logger_online():
@@ -78,3 +81,27 @@ def test_comet_logger_online():
         )
 
         api.assert_called_once_with('rest')
+
+
+def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
+    """ Test that the logger creates the folders and files in the right place. """
+    # prevent comet logger from trying to print at exit, since
+    # pytest's stdout/stderr redirection breaks it
+    import atexit
+    monkeypatch.setattr(atexit, 'register', lambda _: None)
+
+    logger = CometLogger(project_name='test', save_dir=tmpdir)
+    assert not os.listdir(tmpdir)
+    assert logger.mode == 'offline'
+    assert logger.save_dir == tmpdir
+
+    _ = logger.experiment
+    version = logger.version
+    assert set(os.listdir(tmpdir)) == {f'{logger.experiment.id}.zip'}
+
+    model = EvalModelTemplate()
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
+    trainer.fit(model)
+
+    assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints')
+    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py
index 81ce25ca6347e..ec9bc8db332a4 100644
--- a/tests/loggers/test_mlflow.py
+++ b/tests/loggers/test_mlflow.py
@@ -1,9 +1,41 @@
+import os
+
+from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import MLFlowLogger
+from tests.base import EvalModelTemplate
 
 
 def test_mlflow_logger_exists(tmpdir):
-    """Verify that basic functionality of mlflow logger works."""
+    """ Test launching two independent loggers. """
     logger = MLFlowLogger('test', save_dir=tmpdir)
-    # Test already exists
+    # same name leads to same experiment id, but different runs get recorded
     logger2 = MLFlowLogger('test', save_dir=tmpdir)
+    assert logger.experiment_id == logger2.experiment_id
     assert logger.run_id != logger2.run_id
+    logger3 = MLFlowLogger('new', save_dir=tmpdir)
+    assert logger3.experiment_id != logger.experiment_id
+
+
+def test_mlflow_logger_dirs_creation(tmpdir):
+    """ Test that the logger creates the folders and files in the right place. """
+    assert not os.listdir(tmpdir)
+    logger = MLFlowLogger('test', save_dir=tmpdir)
+    assert logger.save_dir == tmpdir
+    assert set(os.listdir(tmpdir)) == {'.trash'}
+    run_id = logger.run_id
+    exp_id = logger.experiment_id
+
+    # multiple experiment calls should not lead to new experiment folders
+    for i in range(2):
+        _ = logger.experiment
+        assert set(os.listdir(tmpdir)) == {'.trash', exp_id}
+        assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
+
+    model = EvalModelTemplate()
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
+    trainer.fit(model)
+    assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
+    assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
+    assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
+    assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints')
+    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py
index d3362abc9ad44..44009a2ddf658 100644
--- a/tests/loggers/test_tensorboard.py
+++ b/tests/loggers/test_tensorboard.py
@@ -26,7 +26,7 @@ def test_tensorboard_hparams_reload(tmpdir):
     with open(os.path.join(folder_path, 'hparams.yaml')) as file:
         # The FullLoader parameter handles the conversion from YAML
         # scalar values to Python the dictionary format
-        yaml_params = yaml.load(file, Loader=yaml.FullLoader)
+        yaml_params = yaml.safe_load(file)
         assert yaml_params['b1'] == 0.5
         assert len(yaml_params.keys()) == 10
 
@@ -48,22 +48,23 @@
 def test_tensorboard_automatic_versioning(tmpdir):
     """Verify that automatic versioning works"""
-    root_dir = tmpdir.mkdir("tb_versioning")
-    root_dir.mkdir("version_0")
-    root_dir.mkdir("version_1")
+    root_dir = tmpdir / "tb_versioning"
+    root_dir.mkdir()
+    (root_dir / "version_0").mkdir()
+    (root_dir / "version_1").mkdir()
 
     logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning")
-
     assert logger.version == 2
 
 
 def test_tensorboard_manual_versioning(tmpdir):
     """Verify that manual versioning works"""
-    root_dir = tmpdir.mkdir("tb_versioning")
-    root_dir.mkdir("version_0")
-    root_dir.mkdir("version_1")
-    root_dir.mkdir("version_2")
+    root_dir = tmpdir / "tb_versioning"
+    root_dir.mkdir()
+    (root_dir / "version_0").mkdir()
+    (root_dir / "version_1").mkdir()
+    (root_dir / "version_2").mkdir()
 
     logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=1)
@@ -73,22 +74,25 @@ def test_tensorboard_named_version(tmpdir):
     """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """
-    tmpdir.mkdir("tb_versioning")
+    name = "tb_versioning"
+    (tmpdir / name).mkdir()
     expected_version = "2020-02-05-162402"
 
-    logger = TensorBoardLogger(save_dir=tmpdir, name="tb_versioning", version=expected_version)
+    logger = TensorBoardLogger(save_dir=tmpdir, name=name, version=expected_version)
     logger.log_hyperparams({"a": 1, "b": 2})  # Force data to be written
     assert logger.version == expected_version
-    # Could also test existence of the directory but this fails
-    # in the "minimum requirements" test setup
+    assert os.listdir(tmpdir / name) == [expected_version]
+    assert os.listdir(tmpdir / name / expected_version)
 
 
 @pytest.mark.parametrize("name", ['', None])
 def test_tensorboard_no_name(tmpdir, name):
     """Verify that None or empty name works"""
     logger = TensorBoardLogger(save_dir=tmpdir, name=name)
+    logger.log_hyperparams({"a": 1, "b": 2})  # Force data to be written
     assert logger.root_dir == tmpdir
+    assert os.listdir(tmpdir / 'version_0')
 
 
 @pytest.mark.parametrize("step_idx", [10, None])
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 182b518d1e52d..843a539bc921f 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -4,6 +4,7 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
+from tests.base import EvalModelTemplate
 
 
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
@@ -34,18 +35,18 @@ def test_wandb_logger(wandb):
 
 @mock.patch('pytorch_lightning.loggers.wandb.wandb')
 def test_wandb_pickle(wandb, tmpdir):
-    """Verify that pickling trainer with wandb logger works.
-
+    """
+    Verify that pickling trainer with wandb logger works.
     Wandb doesn't work well with pytest so we have to mock it out here.
     """
     class Experiment:
+        """ """
         id = 'the_id'
 
         def project_name(self):
             return 'the_project_name'
 
     wandb.init.return_value = Experiment()
-
     logger = WandbLogger(id='the_id', offline=True)
 
     trainer = Trainer(
@@ -67,3 +68,31 @@ def project_name(self):
     assert wandb.init.call_args[1]['id'] == 'the_id'
 
     del os.environ['WANDB_MODE']
+
+
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_wandb_logger_dirs_creation(wandb, tmpdir):
+    """ Test that the logger creates the folders and files in the right place. """
+    logger = WandbLogger(save_dir=str(tmpdir), offline=True)
+    assert logger.version is None
+    assert logger.name is None
+
+    # mock return values of experiment
+    logger.experiment.id = '1'
+    logger.experiment.project_name.return_value = 'project'
+
+    for _ in range(2):
+        _ = logger.experiment
+
+    assert logger.version == '1'
+    assert logger.name == 'project'
+    assert str(tmpdir) == logger.save_dir
+    assert not os.listdir(tmpdir)
+
+    version = logger.version
+    model = EvalModelTemplate()
+    trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
+    trainer.fit(model)
+
+    assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints')
+    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}