Fix weights_save_path when logger is used + simplify path handling + better docs #2681

Merged: 24 commits, Jul 27, 2020
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated Trainer attribute `ckpt_path`, which will now be set by `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

### Removed

@@ -29,6 +30,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

- Fixed `save_dir` in loggers getting ignored by default value of `weights_save_path` when user did not specify `weights_save_path` ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))

- Fixed `weights_save_path` getting ignored when `logger=False` is passed to Trainer ([#2681](https://github.com/PyTorchLightning/pytorch-lightning/pull/2681))
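
A rough usage sketch of the behavior these two fixes target (the paths and logger choice are placeholders for illustration, not taken from the PR):

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    # with no weights_save_path given, checkpoints now follow the logger's save_dir,
    # e.g. logs/<logger name>/<version>/checkpoints
    trainer = Trainer(logger=TensorBoardLogger(save_dir='logs'))

    # weights_save_path is now honored even when no logger is used,
    # e.g. weights/checkpoints
    trainer = Trainer(logger=False, weights_save_path='weights')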

## [0.8.5] - 2020-07-09

### Added
34 changes: 22 additions & 12 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -41,8 +41,11 @@ class ModelCheckpoint(Callback):
... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )

Can also be set to `None`, then it will be set to default location
during trainer construction.
By default, filepath is `None` and will be set at runtime to the location
specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
and if the Trainer uses a logger, the path will also contain logger name and version.

monitor: quantity to monitor.
verbose: verbosity mode. Default: ``False``.
@@ -233,19 +236,29 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
@rank_zero_only
def on_train_start(self, trainer, pl_module):
"""
Determine model checkpoint save directory at runtime. References attributes from the
Trainer's logger to determine where to save checkpoints.
Determines model checkpoint save directory at runtime. References attributes from the
trainer's logger to determine where to save checkpoints.
The base path for saving weights is set in this priority:

1. The checkpoint callback's own path (if passed in)
2. The trainer's `weights_save_path`, if the user provided it
3. The logger's `save_dir`, if a logger is used
4. The trainer's `default_root_dir`

The base path gets extended with the logger name and version (if these are available)
and the subfolder "checkpoints".
"""
if self.dirpath is not None:
return # short circuit

self.filename = '{epoch}'

if trainer.logger is not None:
# weights_save_path overrides anything
save_dir = (getattr(trainer, 'weights_save_path', None)
or getattr(trainer.logger, 'save_dir', None)
or trainer.default_root_dir)
if trainer.weights_save_path != trainer.default_root_dir:
# the user has changed weights_save_path, it overrides anything
save_dir = trainer.weights_save_path
else:
save_dir = trainer.logger.save_dir or trainer.default_root_dir

version = trainer.logger.version if isinstance(
trainer.logger.version, str) else f'version_{trainer.logger.version}'
@@ -256,15 +269,12 @@ def on_train_start(self, trainer, pl_module):
"checkpoints"
)
else:
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")

self.dirpath = ckpt_path

assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'

os.makedirs(self.dirpath, exist_ok=True)
trainer.ckpt_path = ckpt_path
trainer.weights_save_path = ckpt_path

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
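For reference, the priority described in the docstring can be written out as a small standalone sketch. The helper name `resolve_ckpt_dir` is made up for illustration and is not part of this diff:

    import os

    def resolve_ckpt_dir(dirpath, weights_save_path, default_root_dir, logger):
        # an explicit path passed to the callback short-circuits everything else
        if dirpath is not None:
            return dirpath
        if logger is not None:
            if weights_save_path != default_root_dir:
                # the user changed weights_save_path, so it overrides the logger's save_dir
                save_dir = weights_save_path
            else:
                save_dir = logger.save_dir or default_root_dir
            version = logger.version if isinstance(logger.version, str) else f'version_{logger.version}'
            return os.path.join(save_dir, logger.name, version, 'checkpoints')
        # no logger: save directly under weights_save_path (which defaults to default_root_dir)
        return os.path.join(weights_save_path, 'checkpoints')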
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/__init__.py
@@ -291,12 +291,12 @@ def on_train_end(self, trainer, pl_module):
)

default_root_dir
^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^

Default path for logs and weights when no logger
or :class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed.
On certain clusters you might want to separate where logs and checkpoints
are stored. If you don't then use this method for convenience.
are stored. If you don't then use this argument for convenience.

Example::

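The Example block itself is collapsed in this diff; an illustrative sketch of the argument (paths are placeholders) could look like:

    import os
    from pytorch_lightning import Trainer

    # default: logs and weights are kept under the current working directory
    trainer = Trainer(default_root_dir=os.getcwd())

    # keep everything for this run under one custom location instead
    trainer = Trainer(default_root_dir='/path/to/experiments')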
10 changes: 0 additions & 10 deletions pytorch_lightning/trainer/callback_config.py
@@ -33,12 +33,6 @@ def is_overridden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def configure_checkpoint_callback(self, checkpoint_callback):
"""
Weight path set in this priority:
Checkpoint_callback's path (if passed in).
User provided weights_saved_path
Otherwise use os.getcwd()
"""
if checkpoint_callback is True:
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overridden('validation_step')
@@ -53,10 +47,6 @@ def configure_checkpoint_callback(self, checkpoint_callback):
if checkpoint_callback:
checkpoint_callback.save_function = self.save_checkpoint

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
Contributor: without this what will be the path? just none?

Member Author: no, it will still be default_root_dir, it's just getting set in the Trainer init now.

self.weights_save_path = self.default_root_dir

return checkpoint_callback

def configure_early_stopping(self, early_stop_callback):
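As the review exchange above notes, removing this assignment does not leave the path unset; a minimal check of the intended fallback (illustrative, not taken from the PR's tests):

    from pytorch_lightning import Trainer

    trainer = Trainer()  # no weights_save_path given
    # the path still falls back to default_root_dir, now lazily via the new Trainer property
    assert trainer.weights_save_path == trainer.default_root_dir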
16 changes: 16 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -45,6 +45,8 @@ class TrainerDeprecatedAPITillVer0_10(ABC):
limit_test_batches: Union[int, float]
limit_train_batches: Union[int, float]
overfit_batches: Union[int, float]
weights_save_path: str
is_global_zero: bool

def __init__(self):
super().__init__() # mixin calls super too
@@ -118,3 +120,17 @@ def proc_rank(self, rank):
rank_zero_warn("Attribute `proc_rank` is now set by `global_rank` since v0.8.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
self.global_rank = rank

@property
def ckpt_path(self) -> str:
"""Back compatibility, will be removed in v0.10.0"""
rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
return self.weights_save_path if self.is_global_zero else None

@ckpt_path.setter
def ckpt_path(self, path: str):
"""Back compatibility, will be removed in v0.10.0"""
rank_zero_warn("Attribute `ckpt_path` is now set by `weights_save_path` since v0.9.0"
" and this method will be removed in v0.10.0", DeprecationWarning)
self.weights_save_path = path
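
An illustrative sketch of the backward-compatibility shim, assuming a single-process run (global rank 0) and a placeholder path:

    import warnings
    from pytorch_lightning import Trainer

    trainer = Trainer(weights_save_path='weights')
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter('always')
        # the old attribute still reads, but now proxies to weights_save_path
        assert trainer.ckpt_path == trainer.weights_save_path
    assert any(issubclass(w.category, DeprecationWarning) for w in record)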
29 changes: 24 additions & 5 deletions pytorch_lightning/trainer/trainer.py
@@ -419,10 +419,7 @@ def __init__(
self.should_stop = False
self.running_sanity_check = False

# set default save path if user didn't provide one
if default_root_dir is None:
default_root_dir = os.getcwd()
self.default_root_dir = default_root_dir
self._default_root_dir = default_root_dir

# init callbacks
self.callbacks = callbacks or []
@@ -436,7 +433,7 @@
# configure checkpoint callback
# it is important that this is the last callback to run
# pass through the required args to figure out defaults
self.weights_save_path = weights_save_path
self._weights_save_path = weights_save_path
checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback)
if checkpoint_callback:
self.callbacks.append(checkpoint_callback)
@@ -894,6 +891,28 @@ def enable_validation(self) -> bool:
val_loop_enabled = (self.is_overridden('validation_step') and self.limit_val_batches > 0)
return val_loop_enabled or self.fast_dev_run

@property
def default_root_dir(self) -> str:
"""
The default location to save artifacts of loggers, checkpoints etc.
It is used as a fallback if logger or checkpoint callback do not define specific save paths.
Defaults to ``os.getcwd()``.
"""
path = self._default_root_dir or os.getcwd()
path = os.path.normpath(path)
return path

@property
def weights_save_path(self) -> str:
"""
The default root location to save weights (checkpoints), e.g., when the
:class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
It defaults to :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir`.
"""
path = self._weights_save_path or self.default_root_dir
path = os.path.normpath(path)
return path

# -----------------------------
# MODEL TRAINING
# -----------------------------
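A rough illustration of what the two new properties return (paths are placeholders):

    import os
    from pytorch_lightning import Trainer

    trainer = Trainer()  # neither path given
    assert trainer.default_root_dir == os.getcwd()  # falls back to the current working directory

    trainer = Trainer(default_root_dir='runs/', weights_save_path='weights/')
    assert trainer.default_root_dir == 'runs'       # trailing separator stripped by os.path.normpath
    assert trainer.weights_save_path == 'weights'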
6 changes: 2 additions & 4 deletions tests/callbacks/test_model_checkpoint.py
@@ -28,9 +28,7 @@ def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
max_epochs=2,
)
trainer.fit(model)

# These should be different if the dirpath has be overridden
assert trainer.ckpt_path != trainer.default_root_dir
assert checkpoint.dirpath == tmpdir / trainer.logger.name / f'version_0' / 'checkpoints'


@pytest.mark.parametrize(
@@ -51,7 +49,7 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
)
trainer.fit(model)

ckpt_version = Path(trainer.ckpt_path).parent.name
ckpt_version = Path(trainer.checkpoint_callback.dirpath).parent.name
assert ckpt_version == expected


63 changes: 63 additions & 0 deletions tests/loggers/test_all.py
@@ -1,5 +1,6 @@
import atexit
import inspect
import os
import pickle
import platform
from unittest import mock
@@ -82,6 +83,68 @@ def log_metrics(self, metrics, step):
(1, ['epoch', 'test_acc', 'test_loss'])]


@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
MLFlowLogger,
TestTubeLogger,
WandbLogger,
])
@mock.patch('pytorch_lightning.loggers.wandb.wandb')
def test_loggers_save_dir_and_weights_save_path(wandb, tmpdir, monkeypatch, logger_class):
""" Test the combinations of save_dir, weights_save_path and default_root_dir. """
if logger_class == CometLogger:
# prevent comet logger from trying to print at exit, since
# pytest's stdout/stderr redirection breaks it
monkeypatch.setattr(atexit, 'register', lambda _: None)

class TestLogger(logger_class):
# for this test it does not matter what these attributes are
# so we standardize them to make testing easier
@property
def version(self):
return 'version'

@property
def name(self):
return 'name'

model = EvalModelTemplate()
trainer_args = dict(
default_root_dir=tmpdir,
max_steps=1,
)

# no weights_save_path given
save_dir = tmpdir / 'logs'
weights_save_path = None
logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
trainer.fit(model)
assert trainer.weights_save_path == trainer.default_root_dir
assert trainer.checkpoint_callback.dirpath == os.path.join(logger.save_dir, 'name', 'version', 'checkpoints')
assert trainer.default_root_dir == tmpdir

# with weights_save_path given, the logger path and checkpoint path should be different
save_dir = tmpdir / 'logs'
weights_save_path = tmpdir / 'weights'
logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
trainer.fit(model)
assert trainer.weights_save_path == weights_save_path
assert trainer.logger.save_dir == save_dir
assert trainer.checkpoint_callback.dirpath == weights_save_path / 'name' / 'version' / 'checkpoints'
assert trainer.default_root_dir == tmpdir

# no logger given
weights_save_path = tmpdir / 'weights'
trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
trainer.fit(model)
assert trainer.weights_save_path == weights_save_path
assert trainer.checkpoint_callback.dirpath == weights_save_path / 'checkpoints'
assert trainer.default_root_dir == tmpdir
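
In other words, the three combinations exercised above should produce roughly this checkpoint layout under tmpdir (name and version standardized by TestLogger):

    logs/name/version/checkpoints       # logger only: follows the logger's save_dir
    weights/name/version/checkpoints    # logger + weights_save_path: weights_save_path wins
    weights/checkpoints                 # weights_save_path with logger=False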


@pytest.mark.parametrize("logger_class", [
TensorBoardLogger,
CometLogger,
4 changes: 2 additions & 2 deletions tests/loggers/test_comet.py
@@ -103,5 +103,5 @@ def test_comet_logger_dirs_creation(tmpdir, monkeypatch):
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)

assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / 'test' / version / 'checkpoints')
assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
4 changes: 2 additions & 2 deletions tests/loggers/test_mlflow.py
@@ -37,5 +37,5 @@ def test_mlflow_logger_dirs_creation(tmpdir):
assert set(os.listdir(tmpdir / exp_id)) == {run_id, 'meta.yaml'}
assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics')
assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys()
assert trainer.ckpt_path == trainer.weights_save_path == (tmpdir / exp_id / run_id / 'checkpoints')
assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
4 changes: 2 additions & 2 deletions tests/loggers/test_wandb.py
@@ -94,5 +94,5 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3)
trainer.fit(model)

assert trainer.ckpt_path == trainer.weights_save_path == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.ckpt_path)) == {'epoch=0.ckpt'}
assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
1 change: 0 additions & 1 deletion tests/models/data/horovod/train_default_model.py
@@ -63,7 +63,6 @@ def run_test_from_config(trainer_options):
if trainer.global_rank > 0:
# on higher ranks the checkpoint location is unknown
# we want to test checkpointing on rank 0 only
assert not hasattr(trainer, 'ckpt_path')
assert not trainer.checkpoint_callback.best_model_path
return
