diff --git a/CHANGELOG.md b/CHANGELOG.md
index 022e3159e0cb7..5a8df1196f7ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Deprecated
 
+- Deprecated Trainer attribute `ckpt_path`, which will now be set by `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
 
 ### Removed
 
@@ -31,6 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
 
+- Fixed `save_dir` in loggers getting ignored by the default value of `weights_save_path` when the user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
+
+- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
+
 ## [0.8.5] - 2020-07-09
 
 ### Added
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index b6a92efc53321..bfade6f024ba8 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -41,8 +41,11 @@ class ModelCheckpoint(Callback):
             ...     filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
             ... )
 
-            Can also be set to `None`, then it will be set to default location
-            during trainer construction.
+            By default, filepath is `None` and will be set at runtime to the location
+            specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
+            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
+            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
+            and if the Trainer uses a logger, the path will also contain logger name and version.
 
         monitor: quantity to monitor.
         verbose: verbosity mode. Default: ``False``.
@@ -233,8 +236,17 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
     @rank_zero_only
     def on_train_start(self, trainer, pl_module):
         """
-        Determine model checkpoint save directory at runtime. References attributes from the
-        Trainer's logger to determine where to save checkpoints.
+        Determines the model checkpoint save directory at runtime. References attributes from the
+        trainer's logger to determine where to save checkpoints.
+        The base path for saving weights is set in this priority:
+
+        1. The checkpoint callback's path (if passed in)
+        2. The trainer's `weights_save_path` (if the user provided it)
+        3. The logger's `save_dir` (if the trainer has a logger)
+        4. The trainer's `default_root_dir`
+
+        The base path gets extended with the logger name and version (if these are available)
+        and the subfolder "checkpoints".
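
The new save-directory resolution in `on_train_start` is spread across three hunks above. As a standalone sketch (not part of the patch), the resolved control flow is equivalent to the following, where `trainer` is a hypothetical stand-in exposing the same attributes as `pytorch_lightning.Trainer`:

    import os

    def resolve_checkpoint_dir(callback_dirpath, trainer):
        # 1. an explicit path on the callback wins outright
        if callback_dirpath is not None:
            return callback_dirpath
        if trainer.logger is not None:
            # 2. a user-set weights_save_path overrides the logger's save_dir
            if trainer.weights_save_path != trainer.default_root_dir:
                save_dir = trainer.weights_save_path
            else:
                # 3./4. fall back to the logger's save_dir, then default_root_dir
                save_dir = trainer.logger.save_dir or trainer.default_root_dir
            version = trainer.logger.version if isinstance(trainer.logger.version, str) \
                else f'version_{trainer.logger.version}'
            return os.path.join(save_dir, trainer.logger.name, version, "checkpoints")
        # no logger: weights_save_path already defaults to default_root_dir
        return os.path.join(trainer.weights_save_path, "checkpoints")
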
""" if self.dirpath is not None: return # short circuit @@ -242,10 +254,11 @@ def on_train_start(self, trainer, pl_module): self.filename = '{epoch}' if trainer.logger is not None: - # weights_save_path overrides anything - save_dir = (getattr(trainer, 'weights_save_path', None) - or getattr(trainer.logger, 'save_dir', None) - or trainer.default_root_dir) + if trainer.weights_save_path != trainer.default_root_dir: + # the user has changed weights_save_path, it overrides anything + save_dir = trainer.weights_save_path + else: + save_dir = trainer.logger.save_dir or trainer.default_root_dir version = trainer.logger.version if isinstance( trainer.logger.version, str) else f'version_{trainer.logger.version}' @@ -256,15 +269,12 @@ def on_train_start(self, trainer, pl_module): "checkpoints" ) else: - ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") + ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints") self.dirpath = ckpt_path assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' - os.makedirs(self.dirpath, exist_ok=True) - trainer.ckpt_path = ckpt_path - trainer.weights_save_path = ckpt_path @rank_zero_only def on_validation_end(self, trainer, pl_module): diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index b8f02a36c6806..7e188ab97492c 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -291,12 +291,12 @@ def on_train_end(self, trainer, pl_module): ) default_root_dir -^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^ Default path for logs and weights when no logger or :class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On certain clusters you might want to separate where logs and checkpoints -are stored. If you don't then use this method for convenience. +are stored. If you don't then use this argument for convenience. Example:: diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index b65dc37ef8b1b..8600449d86a94 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -33,12 +33,6 @@ def is_overridden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" def configure_checkpoint_callback(self, checkpoint_callback): - """ - Weight path set in this priority: - Checkpoint_callback's path (if passed in). 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e92bb27a61cd3..876380ee22cab 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -224,7 +224,8 @@
 
         callbacks: Add a list of callbacks.
 
-        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed
+        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
+            Default: ``os.getcwd()``.
 
         gradient_clip_val: 0 means don't clip.
 
@@ -351,6 +352,7 @@
         weights_save_path: Where to save weights if specified. Will override default_root_dir
             for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
+            Defaults to `default_root_dir`.
 
         amp_level: The optimization level to use (O1, O2, etc...).
 
@@ -437,10 +439,8 @@
         self.should_stop = False
         self.running_sanity_check = False
 
-        # set default save path if user didn't provide one
-        if default_root_dir is None:
-            default_root_dir = os.getcwd()
-        self.default_root_dir = default_root_dir
+        self._default_root_dir = default_root_dir or os.getcwd()
+        self._weights_save_path = weights_save_path or self._default_root_dir
 
         # init callbacks
         self.callbacks = callbacks or []
@@ -454,7 +454,6 @@
         # configure checkpoint callback
         # it is important that this is the last callback to run
         # pass through the required args to figure out defaults
-        self.weights_save_path = weights_save_path
         checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
         if checkpoint_callback:
             self.callbacks.append(checkpoint_callback)
@@ -937,6 +936,22 @@ def enable_validation(self) -> bool:
         val_loop_enabled = self.is_overridden('validation_step') and self.limit_val_batches > 0
         return val_loop_enabled or self.fast_dev_run
 
+    @property
+    def default_root_dir(self) -> str:
+        """
+        The default location to save artifacts of loggers, checkpoints etc.
+        It is used as a fallback if the logger or checkpoint callback does not define a specific save path.
+        """
+        return os.path.normpath(self._default_root_dir)
+
+    @property
+    def weights_save_path(self) -> str:
+        """
+        The default root location to save weights (checkpoints), e.g., when the
+        :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
+        """
+        return os.path.normpath(self._weights_save_path)
+
     # -----------------------------
     # MODEL TRAINING
     # -----------------------------
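
The constructor change plus the two read-only properties above give the following user-facing behavior, sketched here with illustrative POSIX paths (not part of the patch):

    from pytorch_lightning import Trainer

    # weights_save_path falls back to default_root_dir when not given
    trainer = Trainer(default_root_dir='/data/runs')
    assert trainer.weights_save_path == trainer.default_root_dir

    # when given, it redirects checkpoints only; logs stay under default_root_dir
    trainer = Trainer(default_root_dir='/data/runs', weights_save_path='/fast/ckpts')
    assert trainer.weights_save_path == '/fast/ckpts'
    assert trainer.default_root_dir == '/data/runs'

With a logger attached, checkpoints then land under `weights_save_path/<logger name>/<logger version>/checkpoints`, as exercised by the new test in `tests/loggers/test_all.py` below.
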
diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py
index bb575494c3148..4cb52a54610e3 100644
--- a/tests/callbacks/test_model_checkpoint.py
+++ b/tests/callbacks/test_model_checkpoint.py
@@ -28,9 +28,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
         max_epochs=2,
     )
     trainer.fit(model)
-
-    # These should be different if the dirpath has be overridden
-    assert trainer.ckpt_path != trainer.default_root_dir
+    assert checkpoint.dirpath == tmpdir / trainer.logger.name / 'version_0' / 'checkpoints'
 
 
 @pytest.mark.parametrize(
@@ -51,7 +49,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
     )
     trainer.fit(model)
 
-    ckpt_version = Path(trainer.ckpt_path).parent.name
+    ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
     assert ckpt_version == expected
 
 
diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py
index b64119078c6dd..e3d6202d05932 100644
--- a/tests/loggers/test_all.py
+++ b/tests/loggers/test_all.py
@@ -1,5 +1,6 @@
 import atexit
 import inspect
+import os
 import pickle
 import platform
 from unittest import mock
@@ -82,6 +83,68 @@ def log_metrics(self, metrics, step):
                  (1, ['epoch', 'test_acc', 'test_loss'])]
 
 
+@pytest.mark.parametrize("logger_class", [
+    TensorBoardLogger,
+    CometLogger,
+    MLFlowLogger,
+    TestTubeLogger,
+    WandbLogger,
+])
+@mock.patch('pytorch_lightning.loggers.wandb.wandb')
+def test_loggers_save_dir_and_weights_save_path(wandb, tmpdir, monkeypatch, logger_class):
+    """ Test the combinations of save_dir, weights_save_path and default_root_dir. """
+    if logger_class == CometLogger:
+        # prevent comet logger from trying to print at exit, since
+        # pytest's stdout/stderr redirection breaks it
+        monkeypatch.setattr(atexit, 'register', lambda _: None)
+
+    class TestLogger(logger_class):
+        # for this test it does not matter what these attributes are
+        # so we standardize them to make testing easier
+        @property
+        def version(self):
+            return 'version'
+
+        @property
+        def name(self):
+            return 'name'
+
+    model = EvalModelTemplate()
+    trainer_args = dict(
+        default_root_dir=tmpdir,
+        max_steps=1,
+    )
+
+    # no weights_save_path given
+    save_dir = tmpdir / 'logs'
+    weights_save_path = None
+    logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
+    trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
+    trainer.fit(model)
+    assert trainer.weights_save_path == trainer.default_root_dir
+    assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, 'name', 'version', 'checkpoints')
+    assert trainer.default_root_dir == tmpdir
+
+    # with weights_save_path given, the logger path and checkpoint path should be different
+    save_dir = tmpdir / 'logs'
+    weights_save_path = tmpdir / 'weights'
+    logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
+    trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
+    trainer.fit(model)
+    assert trainer.weights_save_path == weights_save_path
+    assert trainer.logger.save_dir == save_dir
+    assert trainer.checkpoint_callback.dirpath == weights_save_path / 'name' / 'version' / 'checkpoints'
+    assert trainer.default_root_dir == tmpdir
+
+    # no logger given
+    weights_save_path = tmpdir / 'weights'
+    trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
+    trainer.fit(model)
+    assert trainer.weights_save_path == weights_save_path
+    assert trainer.checkpoint_callback.dirpath == weights_save_path / 'checkpoints'
+    assert trainer.default_root_dir == tmpdir
+
+
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
     CometLogger,
""" + if logger_class == CometLogger: + # prevent comet logger from trying to print at exit, since + # pytest's stdout/stderr redirection breaks it + monkeypatch.setattr(atexit, 'register', lambda _: None) + + class TestLogger(logger_class): + # for this test it does not matter what these attributes are + # so we standardize them to make testing easier + @property + def version(self): + return 'version' + + @property + def name(self): + return 'name' + + model = EvalModelTemplate() + trainer_args = dict( + default_root_dir=tmpdir, + max_steps=1, + ) + + # no weights_save_path given + save_dir = tmpdir / 'logs' + weights_save_path = None + logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) + trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) + trainer.fit(model) + assert trainer.weights_save_path == trainer.default_root_dir + assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, 'name', 'version', 'checkpoints') + assert trainer.default_root_dir == tmpdir + + # with weights_save_path given, the logger path and checkpoint path should be different + save_dir = tmpdir / 'logs' + weights_save_path = tmpdir / 'weights' + logger = TestLogger(**_get_logger_args(TestLogger, save_dir)) + trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path) + trainer.fit(model) + assert trainer.weights_save_path == weights_save_path + assert trainer.logger.save_dir == save_dir + assert trainer.checkpoint_callback.dirpath == weights_save_path / 'name' / 'version' / 'checkpoints' + assert trainer.default_root_dir == tmpdir + + # no logger given + weights_save_path = tmpdir / 'weights' + trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path) + trainer.fit(model) + assert trainer.weights_save_path == weights_save_path + assert trainer.checkpoint_callback.dirpath == weights_save_path / 'checkpoints' + assert trainer.default_root_dir == tmpdir + + @pytest.mark.parametrize("logger_class", [ TensorBoardLogger, CometLogger, diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index a89840163fe7a..a3ba883a65ae3 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -103,5 +103,5 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch): trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) trainer.fit(model) - assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints') - assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} + assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints') + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index ec9bc8db332a4..31b580f33f6d4 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -37,5 +37,5 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'} assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics') assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() - assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints') - assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'} + assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints') + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == 
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 57b0aff311264..9907ad9d087a2 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -94,5 +94,5 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
     trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
     trainer.fit(model)
 
-    assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints')
-    assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
+    assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
+    assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index a71d7b576ca1f..7138021e8e7e9 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -63,7 +63,6 @@ def run_test_from_config(trainer_options):
     if trainer.global_rank > 0:
         # on higher ranks the checkpoint location is unknown
         # we want to test checkpointing on rank 0 only
-        assert not hasattr(trainer, 'ckpt_path')
         assert not trainer.checkpoint_callback.best_model_path
         return
 
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index e6eb86e42c1fc..9d4d69faa30bf 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -50,6 +50,10 @@ def test_tbd_remove_in_v0_10_0_trainer():
     with pytest.deprecated_call(match='will be removed in v0.10.0'):
         assert trainer.proc_rank == trainer.global_rank
 
+    with pytest.deprecated_call(match='will be removed in v0.10.0'):
+        trainer.ckpt_path = 'foo'
+        assert trainer.ckpt_path == trainer.weights_save_path == 'foo'
+
 
 def test_tbd_remove_in_v0_9_0_trainer():
     # test show_progress_bar set by progress_bar_refresh_rate