Fix weights_save_path when logger is used + simplify path handling + better docs (#2681)

* fix weights_save path and drop ckpt_path

* add tests

* unused import

* update docs

* changelog

* pep8

* fix horovod test

* make backward compatible

* perform same test for all loggers

* fix for when logger=False and weights_save_path is set

* update changelog

* update docs

* update tests

* do not set save dir dynamically

* remove duplicate test

* remove duplicated tests

* update tests

* update tests

* remove remaining ckpt_path references

* move defaults to init as suggested by @Borda

* test deprecation
awaelchli committed Jul 27, 2020
1 parent 3f2c102 commit d039532
Showing 13 changed files with 142 additions and 41 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated Trainer attribute `ckpt_path`, which will now be set by `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

### Removed

@@ -31,6 +32,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

- Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

## [0.8.5] - 2020-07-09

### Added
34 changes: 22 additions & 12 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -41,8 +41,11 @@ class ModelCheckpoint(Callback):
... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
Can also be set to `None`, then it will be set to default location
during trainer construction.
By default, filepath is `None` and will be set at runtime to the location
specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
and if the Trainer uses a logger, the path will also contain logger name and version.
monitor: quantity to monitor.
verbose: verbosity mode. Default: ``False``.
@@ -233,19 +236,29 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
@rank_zero_only
def on_train_start(self, trainer, pl_module):
"""
Determine model checkpoint save directory at runtime. References attributes from the
Trainer's logger to determine where to save checkpoints.
Determines model checkpoint save directory at runtime. References attributes from the
trainer's logger to determine where to save checkpoints.
The base path for saving weights is set in this priority:
1. Checkpoint callback's path (if passed in)
2. The weights_save_path from trainer, if the user set it (it overrides the logger's save_dir)
3. The save_dir of the trainer's logger, if a logger is used
4. The default_root_dir from trainer
The base path gets extended with logger name and version (if these are available)
and subfolder "checkpoints".
"""
if self.dirpath is not None:
return # short circuit

self.filename = '{epoch}'

if trainer.logger is not None:
# weights_save_path overrides anything
save_dir = (getattr(trainer, 'weights_save_path', None)
or getattr(trainer.logger, 'save_dir', None)
or trainer.default_root_dir)
if trainer.weights_save_path != trainer.default_root_dir:
# the user has changed weights_save_path, it overrides anything
save_dir = trainer.weights_save_path
else:
save_dir = trainer.logger.save_dir or trainer.default_root_dir

version = trainer.logger.version if isinstance(
trainer.logger.version, str) else f'version_{trainer.logger.version}'
@@ -256,15 +269,12 @@ def on_train_start(self, trainer, pl_module):
"checkpoints"
)
else:
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

self.dirpath = ckpt_path

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'

os.makedirs(self.dirpath, exist_ok=True)
trainer.ckpt_path = ckpt_path
trainer.weights_save_path = ckpt_path

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
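For readers following the new branching in `on_train_start` above, the resolution can be condensed into a small standalone sketch. The helper below is purely illustrative (its name and flat argument list are not part of this commit); it only mirrors the logic shown in the diff, with the trainer and logger attributes passed in as plain arguments:

    import os

    def resolve_checkpoint_dir(dirpath, weights_save_path, default_root_dir,
                               logger_save_dir=None, logger_name=None, logger_version=None):
        # 1. an explicit path on the ModelCheckpoint always wins
        if dirpath is not None:
            return dirpath
        if logger_name is not None:
            # 2. a user-provided weights_save_path overrides the logger's save_dir
            if weights_save_path != default_root_dir:
                save_dir = weights_save_path
            # 3. otherwise fall back to the logger's save_dir, then default_root_dir
            else:
                save_dir = logger_save_dir or default_root_dir
            version = logger_version if isinstance(logger_version, str) else f'version_{logger_version}'
            return os.path.join(save_dir, logger_name, version, 'checkpoints')
        # 4. no logger: checkpoints go straight under weights_save_path
        return os.path.join(weights_save_path, 'checkpoints')

    # e.g. with a logger, an unchanged weights_save_path and an integer version:
    resolve_checkpoint_dir(None, '/runs', '/runs', '/runs/tb', 'default', 0)
    # -> '/runs/tb/default/version_0/checkpoints'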
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/__init__.py
@@ -291,12 +291,12 @@ def on_train_end(self, trainer, pl_module):
)
default_root_dir
^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^
Default path for logs and weights when no logger
or :class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed.
On certain clusters you might want to separate where logs and checkpoints
are stored. If you don't then use this method for convenience.
are stored. If you don't then use this argument for convenience.
Example::
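The collapsed example aside, a minimal illustration of the documented behaviour might look like the following; the paths are hypothetical and not taken from the docs:

    from pytorch_lightning import Trainer

    # logs and checkpoints both end up under default_root_dir when nothing else is configured
    trainer = Trainer(default_root_dir='/data/experiments')

    # on clusters that separate log and weight storage, redirect checkpoints only
    trainer = Trainer(default_root_dir='/home/me/logs', weights_save_path='/scratch/me/weights')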
10 changes: 0 additions & 10 deletions pytorch_lightning/trainer/callback_config.py
@@ -33,12 +33,6 @@ def is_overridden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def configure_checkpoint_callback(self, checkpoint_callback):
"""
Weight path set in this priority:
Checkpoint_callback's path (if passed in).
User provided weights_saved_path
Otherwise use os.getcwd()
"""
if checkpoint_callback is True:
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overridden('validation_step')
@@ -53,10 +47,6 @@ def configure_checkpoint_callback(self, checkpoint_callback):
if checkpoint_callback:
checkpoint_callback.save_function = self.save_checkpoint

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_root_dir

return checkpoint_callback

def configure_early_stopping(self, early_stop_callback):
17 changes: 17 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -45,6 +45,9 @@ class TrainerDeprecatedAPITillVer0_10(ABC):
limit_test_batches: Union[int, float]
limit_train_batches: Union[int, float]
overfit_batches: Union[int, float]
is_global_zero: bool
_weights_save_path: str
weights_save_path: str

def __init__(self):
super().__init__() # mixin calls super too
@@ -118,3 +121,17 @@ def proc_rank(self, rank):
rank_zero_warn("Attribute `proc_rank` is now set by `global_rank` since v0.8.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
self.global_rank = rank

@property
def ckpt_path(self) -> str:
"""Back compatibility, will be removed in v0.10.0"""
rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
return self.weights_save_path if self.is_global_zero else None

@ckpt_path.setter
def ckpt_path(self, path: str):
"""Back compatibility, will be removed in v0.10.0"""
rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
self._weights_save_path = path
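As a rough sketch of how this back-compat shim behaves (it mirrors the new assertion in tests/test_deprecated.py further down; the path is illustrative):

    import warnings
    from pytorch_lightning import Trainer

    trainer = Trainer()
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        trainer.ckpt_path = '/tmp/weights'   # setter forwards to _weights_save_path
        assert trainer.ckpt_path == trainer.weights_save_path == '/tmp/weights'
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)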
27 changes: 21 additions & 6 deletions pytorch_lightning/trainer/trainer.py
@@ -224,7 +224,8 @@ def __init__(
callbacks: Add a list of callbacks.
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed
default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
Default: ``os.getcwd()``.
gradient_clip_val: 0 means don't clip.
@@ -351,6 +352,7 @@ def __init__(
weights_save_path: Where to save weights if specified. Will override default_root_dir
for checkpoints only. Use this if for whatever reason you need the checkpoints
stored in a different place than the logs written in `default_root_dir`.
Defaults to `default_root_dir`.
amp_level: The optimization level to use (O1, O2, etc...).
@@ -437,10 +439,8 @@ def __init__(
self.should_stop = False
self.running_sanity_check = False

# set default save path if user didn't provide one
if default_root_dir is None:
default_root_dir = os.getcwd()
self.default_root_dir = default_root_dir
self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir

# init callbacks
self.callbacks = callbacks or []
@@ -454,7 +454,6 @@ def __init__(
# configure checkpoint callback
# it is important that this is the last callback to run
# pass through the required args to figure out defaults
self.weights_save_path = weights_save_path
checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
if checkpoint_callback:
self.callbacks.append(checkpoint_callback)
@@ -937,6 +936,22 @@ def enable_validation(self) -> bool:
val_loop_enabled = self.is_overridden('validation_step') and self.limit_val_batches > 0
return val_loop_enabled or self.fast_dev_run

@property
def default_root_dir(self) -> str:
"""
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
"""
return os.path.normpath(self._default_root_dir)

@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
"""
return os.path.normpath(self._weights_save_path)

# -----------------------------
# MODEL TRAINING
# -----------------------------
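The net effect of setting the defaults in __init__ and exposing them as normalized properties can be sketched as follows (paths are illustrative):

    import os
    from pytorch_lightning import Trainer

    trainer = Trainer(default_root_dir='runs/./exp')
    assert trainer.default_root_dir == os.path.normpath('runs/./exp')      # 'runs/exp'
    assert trainer.weights_save_path == trainer.default_root_dir           # falls back when not set

    trainer = Trainer(default_root_dir='runs', weights_save_path='weights')
    assert trainer.weights_save_path == 'weights'                          # redirects checkpoints only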
6 changes: 2 additions & 4 deletions tests/callbacks/test_model_checkpoint.py
@@ -28,9 +28,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
max_epochs=2,
)
trainer.fit(model)

# These should be different if the dirpath has be overridden
assert trainer.ckpt_path != trainer.default_root_dir
assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints'


@pytest.mark.parametrize(
@@ -51,7 +49,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
)
trainer.fit(model)

ckpt_version = Path(trainer.ckpt_path).parent.name
ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
assert ckpt_version == expected


63 changes: 63 additions & 0 deletions tests/loggers/test_all.py
@@ -1,5 +1,6 @@
import atexit
import inspect
import os
import pickle
import platform
from unittest import mock
@@ -82,6 +83,68 @@ def log_metrics(self, metrics, step):
(1, ['epoch', 'test_acc', 'test_loss'])]


@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
MLFlowLogger,
TestTubeLogger,
WandbLogger,
])
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_loggers_save_dir_and_weights_save_path(wandb, tmpdir, monkeypatch, logger_class):
""" Test the combinations of save_dir, weights_save_path and default_root_dir. """
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)

class TestLogger(logger_class):
# for this test it does not matter what these attributes are
# so we standardize them to make testing easier
@property
def version(self):
return 'version'

@property
def name(self):
return 'name'

model = EvalModelTemplate()
trainer_args = dict(
default_root_dir=tmpdir,
max_steps=1,
)

# no weights_save_path given
save_dir = tmpdir / 'logs'
weights_save_path = None
logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
trainer.fit(model)
assert trainer.weights_save_path == trainer.default_root_dir
assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, 'name', 'version', 'checkpoints')
assert trainer.default_root_dir == tmpdir

# with weights_save_path given, the logger path and checkpoint path should be different
save_dir = tmpdir / 'logs'
weights_save_path = tmpdir / 'weights'
logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
trainer.fit(model)
assert trainer.weights_save_path == weights_save_path
assert trainer.logger.save_dir == save_dir
assert trainer.checkpoint_callback.dirpath == weights_save_path / 'name' / 'version' / 'checkpoints'
assert trainer.default_root_dir == tmpdir

# no logger given
weights_save_path = tmpdir / 'weights'
trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
trainer.fit(model)
assert trainer.weights_save_path == weights_save_path
assert trainer.checkpoint_callback.dirpath == weights_save_path / 'checkpoints'
assert trainer.default_root_dir == tmpdir


@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
4 changes: 2 additions & 2 deletions tests/loggers/test_comet.py
@@ -103,5 +103,5 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)

assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints')
assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
4 changes: 2 additions & 2 deletions tests/loggers/test_mlflow.py
@@ -37,5 +37,5 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints')
assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
4 changes: 2 additions & 2 deletions tests/loggers/test_wandb.py
@@ -94,5 +94,5 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)

assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
1 change: 0 additions & 1 deletion tests/models/data/horovod/train_default_model.py
@@ -63,7 +63,6 @@ def run_test_from_config(trainer_options):
if trainer.global_rank > 0:
# on higher ranks the checkpoint location is unknown
# we want to test checkpointing on rank 0 only
assert not hasattr(trainer, 'ckpt_path')
assert not trainer.checkpoint_callback.best_model_path
return

4 changes: 4 additions & 0 deletions tests/test_deprecated.py
@@ -50,6 +50,10 @@ def test_tbd_remove_in_v0_10_0_trainer():
with pytest.deprecated_call(match='will be removed in v0.10.0'):
assert trainer.proc_rank == trainer.global_rank

with pytest.deprecated_call(match='will be removed in v0.10.0'):
trainer.ckpt_path = 'foo'
assert trainer.ckpt_path == trainer.weights_save_path == 'foo'


def test_tbd_remove_in_v0_9_0_trainer():
# test show_progress_bar set by progress_bar_refresh_rate
