Continue Jeremy's early stopping PR #1504 (#2391)
* add state_dict for early stopping

* move best attr after monitor_op defined

* improve early stopping and model checkpoint callbacks

* fix formatting

* fix attr init order

* clean up setting of default_root_dir attr

* logger needs default root dir set first

* reorg trainer init

* remove direct references to checkpoint callback

* more fixes

* more bugfixes

* run callbacks at epoch end

* update tests to use on epoch end

* PR cleanup

* address failing tests

* refactor for homogeneity

* fix merge conflict

* separate tests

* tests for early stopping bug regressions

* small fixes

* revert model checkpoint change

* typo fix

* fix tests

* update train loop

* cannot pass an int as default_save_path

* refactor log message

* fix test case

* appease the linter

* fix some doctests

* move config to callback

* fixes from rebase

* fixes from rebase

* chlog

* docs

* reformat

* formatting

* fix

* fix

* fixes from rebase

* add new test for patience

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/callbacks/test_early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* fix formatting

* remove enable_early_stop attribute

* fix test with new epoch indexing

* fix progress bar totals

* fix off by one error (see #2289) epoch starts at 0 now

* added missing imports

* fix hpc_save folderpath

* fix formatting

* fix tests

* small fixes from a rebase

* fix

* tmpdir

* tmpdir

* tmpdir

* wandb

* fix merge conflict

* add back evaluation after training

* test_resume_early_stopping_from_checkpoint TODO

* undo the horovod check

* update changelog

* remove a duplicate test from merge error

* try fix dp_resume test

* add the logger fix from master

* try remove default_root_dir

* try mocking numpy

* try import numpy in docs test

* fix wandb test

* pep 8 fix

* skip if no amp

* dont mock when doctesting

* install extra

* fix the resume ES test

* undo conf.py changes

* revert remove comet pickle from test

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update weights_loading.rst

* Update weights_loading.rst

* Update weights_loading.rst

* renamed flag

* renamed flag

* revert the None check in logger experiment name/version

* add the old comments

* _experiment

* test checkpointing on DDP

* skip the ddp test on windows

* cloudpickle

* renamed flag

* renamed flag

* parentheses for clarity

* apply suggestion max epochs

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jeremy Jordan <jtjordan@ncsu.edu>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
6 people committed Jun 29, 2020
1 parent 1e16681 commit 25ee51b
Showing 32 changed files with 532 additions and 230 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387))

- Fixed several issues with early stopping and checkpoint callbacks ([#1504](https://github.com/PyTorchLightning/pytorch-lightning/pull/1504), [#2391](https://github.com/PyTorchLightning/pytorch-lightning/pull/2391))

- Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405))

- Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403))
13 changes: 13 additions & 0 deletions docs/source/callbacks.rst
@@ -48,6 +48,19 @@ We successfully extended functionality without polluting our super clean

----------------

Best Practices
==============

1. Callbacks should be isolated in their functionality. Your callback should not rely on the
behavior of other callbacks in order to work properly.
2. Do not manually call callback methods yourself. Callbacks are designed to be
   invoked at specific times during training; calling their methods directly
   (e.g. `on_validation_end`) is strongly discouraged.
3. Whenever possible, your callbacks should not depend on the order in which they are
   executed (a minimal self-contained example is sketched below).
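
The following sketch (added for illustration; it is not part of this commit) shows a callback
that follows these guidelines: it keeps all of its state on itself, is driven only by the hooks
the Trainer invokes, and works regardless of which other callbacks are present. The monitored
key `val_loss` is an assumption.

.. code-block:: python

    from pytorch_lightning.callbacks import Callback

    class ValLossHistory(Callback):
        """Record the monitored metric after every validation run."""

        def __init__(self, monitor='val_loss'):
            self.monitor = monitor
            self.history = []

        def on_validation_end(self, trainer, pl_module):
            # Read from trainer.callback_metrics only; never call into other callbacks.
            value = trainer.callback_metrics.get(self.monitor)
            if value is not None:
                self.history.append(float(value))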


---------

.. automodule:: pytorch_lightning.callbacks.base
:noindex:
:exclude-members:
3 changes: 2 additions & 1 deletion docs/source/experiment_logging.rst
@@ -92,6 +92,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. testcode::

from pytorch_lightning.loggers import NeptuneLogger

neptune_logger = NeptuneLogger(
api_key='ANONYMOUS', # replace with your own
project_name='shared/pytorch-lightning-integration',
@@ -193,7 +194,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer.
.. testcode::

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger()
wandb_logger = WandbLogger(offline=True)
trainer = Trainer(logger=wandb_logger)

The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your
2 changes: 1 addition & 1 deletion docs/source/weights_loading.rst
@@ -29,7 +29,7 @@ Automatic saving
Checkpointing is enabled by default to the current working directory.
To change the checkpoint path pass in:

.. testcode::
.. code-block:: python

    trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints')
57 changes: 36 additions & 21 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -5,6 +5,7 @@
Monitor a validation metric and stop training when it stops improving.
"""
from copy import deepcopy

import numpy as np
import torch
@@ -58,7 +59,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
self.verbose = verbose
self.strict = strict
self.min_delta = min_delta
self.wait = 0
self.wait_count = 0
self.stopped_epoch = 0
self.mode = mode

@@ -76,12 +77,17 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience:
log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')

self.min_delta *= 1 if self.monitor_op == torch.gt else -1
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

def _validate_condition_metric(self, logs):
"""
Checks that the condition metric for early stopping is good
:param logs:
:return:
Args:
logs: callback metrics from validation output
Return:
True if specified metric is available
"""
monitor_val = logs.get(self.monitor)
error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
@@ -103,39 +109,48 @@ def _validate_condition_metric(self, logs):
def monitor_op(self):
return self.mode_dict[self.mode]

def on_train_start(self, trainer, pl_module):
# Allow instances to be re-used
self.wait = 0
self.stopped_epoch = 0
self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf
def state_dict(self):
return {
'wait_count': self.wait_count,
'stopped_epoch': self.stopped_epoch,
'best_score': self.best_score,
'patience': self.patience
}

def load_state_dict(self, state_dict):
state_dict = deepcopy(state_dict)
self.wait_count = state_dict['wait_count']
self.stopped_epoch = state_dict['stopped_epoch']
self.best_score = state_dict['best_score']
self.patience = state_dict['patience']

def on_sanity_check_end(self, trainer, pl_module):
logs = trainer.callback_metrics
self._validate_condition_metric(logs)

def on_validation_end(self, trainer, pl_module):
return self._run_early_stopping_check(trainer, pl_module)
self._run_early_stopping_check(trainer, pl_module)

def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics
stop_training = False
if not self._validate_condition_metric(logs):
return stop_training
return # short circuit if metric not present

current = logs.get(self.monitor)
if not isinstance(current, torch.Tensor):
current = torch.tensor(current)

if self.monitor_op(current - self.min_delta, self.best):
self.best = current
self.wait = 0
if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.wait_count += 1
if self.wait_count >= self.patience:
self.stopped_epoch = trainer.current_epoch
stop_training = True
self.on_train_end(trainer, pl_module)

return stop_training
trainer.should_stop = True

def on_train_end(self, trainer, pl_module):
if self.stopped_epoch > 0 and self.verbose > 0:
rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
' but will start from "0" in v0.8.0.', DeprecationWarning)
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')
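
The new `state_dict` / `load_state_dict` pair makes the callback's progress resumable. A small
usage sketch (not part of the diff; the metric values are invented) of round-tripping that state:

    import torch
    from pytorch_lightning.callbacks import EarlyStopping

    # Simulate a callback that has already seen one non-improving validation epoch.
    early_stopping = EarlyStopping(monitor='val_loss', patience=3)
    early_stopping.best_score = torch.tensor(0.42)  # invented value
    early_stopping.wait_count = 1

    # Persist the callback state, e.g. alongside a model checkpoint ...
    state = early_stopping.state_dict()

    # ... and restore it into a fresh instance when resuming training.
    restored = EarlyStopping(monitor='val_loss')
    restored.load_state_dict(state)
    assert restored.wait_count == 1 and restored.patience == 3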
35 changes: 35 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -226,6 +226,41 @@ def format_checkpoint_name(self, epoch, metrics, ver=None):
filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
return filepath

def on_train_start(self, trainer, pl_module):
"""
Determine model checkpoint save directory at runtime. References attributes from the
Trainer's logger to determine where to save checkpoints.
"""
if self.dirpath is not None:
return # short circuit

self.filename = '{epoch}'

if trainer.logger is not None and trainer.logger.experiment is not None:
# weights_save_path overrides anything
if getattr(trainer, 'weights_save_path', None) is not None:
save_dir = trainer.weights_save_path
else:
save_dir = (getattr(trainer.logger, 'save_dir', None)
or getattr(trainer.logger, '_save_dir', None)
or trainer.default_root_dir)

version = trainer.logger.version if isinstance(
trainer.logger.version, str) else f'version_{trainer.logger.version}'
ckpt_path = os.path.join(
save_dir,
trainer.logger.name,
version,
"checkpoints"
)
else:
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

self.dirpath = ckpt_path
os.makedirs(self.dirpath, exist_ok=True)
trainer.ckpt_path = ckpt_path
trainer.weights_save_path = ckpt_path

@rank_zero_only
def on_validation_end(self, trainer, pl_module):
# only run on main process
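
For orientation only, a sketch of how the new `on_train_start` resolution behaves when the
checkpoint directory is left unset; the directory layout in the comments is an assumption
derived from the code above, not output captured from this commit:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    # Leaving filepath unset defers the directory choice to on_train_start,
    # which derives it from the Trainer's logger at runtime.
    checkpoint_callback = ModelCheckpoint(filepath=None, monitor='val_loss')
    trainer = Trainer(checkpoint_callback=checkpoint_callback, max_epochs=1)

    # Once training starts, the callback would resolve something like
    #   <logger save_dir>/<logger name>/version_<n>/checkpoints
    # create that folder, and mirror it onto trainer.ckpt_path and trainer.weights_save_path.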
4 changes: 2 additions & 2 deletions pytorch_lightning/loggers/wandb.py
@@ -131,12 +131,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->
self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)

@property
def name(self) -> str:
def name(self) -> Optional[str]:
# don't create an experiment if we don't have one
name = self._experiment.project_name() if self._experiment else None
return name

@property
def version(self) -> str:
def version(self) -> Optional[str]:
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else None
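
A short sketch of the behavior behind this type change: the logger no longer needs a live wandb
run just to answer `name` / `version`. This example is not part of the diff, requires the `wandb`
package, and uses `offline=True` only to avoid network access:

    from pytorch_lightning.loggers import WandbLogger

    logger = WandbLogger(offline=True)

    # The wandb run is created lazily, so until `.experiment` is first accessed
    # both properties simply return None instead of forcing a run into existence.
    assert logger.name is None
    assert logger.version is None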
64 changes: 16 additions & 48 deletions pytorch_lightning/trainer/callback_config.py
@@ -32,79 +32,47 @@ def save_checkpoint(self, *args):
def is_overridden(self, *args):
"""Warning: this is just empty shell for code implemented in other class."""

def configure_checkpoint_callback(self):
def configure_checkpoint_callback(self, checkpoint_callback):
"""
Weight path is set in this priority:
1. The checkpoint callback's path (if passed in)
2. The user-provided weights_save_path
3. Otherwise, os.getcwd()
"""
ckpt_path = self.default_root_dir
if self.checkpoint_callback:
# init a default one
if self.logger is not None and self.logger.experiment is not None:
save_dir = (getattr(self.logger, 'save_dir', None) or
getattr(self.logger, '_save_dir', None) or
self.default_root_dir)

# weights_save_path overrides anything
if self.weights_save_path is not None:
save_dir = self.weights_save_path

version = self.logger.version if isinstance(
self.logger.version, str) else f'version_{self.logger.version}'
ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints")
else:
ckpt_path = os.path.join(self.default_root_dir, "checkpoints")

if checkpoint_callback is True:
# when no val step is defined, use 'loss' otherwise 'val_loss'
train_step_only = not self.is_overridden('validation_step')
monitor_key = 'loss' if train_step_only else 'val_loss'
checkpoint_callback = ModelCheckpoint(
filepath=None,
monitor=monitor_key
)
elif checkpoint_callback is False:
checkpoint_callback = None

if self.checkpoint_callback is True:
os.makedirs(ckpt_path, exist_ok=True)
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path,
monitor=monitor_key
)
# If user specified None in filepath, override with runtime default
elif isinstance(self.checkpoint_callback, ModelCheckpoint) \
and self.checkpoint_callback.dirpath is None:
self.checkpoint_callback.dirpath = ckpt_path
self.checkpoint_callback.filename = '{epoch}'
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None

self.ckpt_path = ckpt_path

if self.checkpoint_callback:
# set the path for the callbacks
self.checkpoint_callback.save_function = self.save_checkpoint

# if checkpoint callback used, then override the weights path
self.weights_save_path = self.checkpoint_callback.dirpath
if checkpoint_callback:
checkpoint_callback.save_function = self.save_checkpoint

# if weights_save_path is still none here, set to current working dir
if self.weights_save_path is None:
self.weights_save_path = self.default_root_dir

return checkpoint_callback

def configure_early_stopping(self, early_stop_callback):
if early_stop_callback is True or None:
self.early_stop_callback = EarlyStopping(
early_stop_callback = EarlyStopping(
monitor='val_loss',
patience=3,
strict=True,
verbose=True,
mode='min'
)
self.enable_early_stop = True
elif not early_stop_callback:
self.early_stop_callback = None
self.enable_early_stop = False
early_stop_callback = None
else:
self.early_stop_callback = early_stop_callback
self.enable_early_stop = True
early_stop_callback = early_stop_callback
return early_stop_callback

def configure_progress_bar(self, refresh_rate=1, process_position=0):
progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
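
A hedged usage sketch of the three configuration paths these helpers now normalize at the
Trainer level (the metric names, mode, and patience values below are assumptions):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    # 1. Booleans: True asks the Trainer to build the default callback, False disables it.
    trainer_defaults = Trainer(checkpoint_callback=True, early_stop_callback=True)
    trainer_disabled = Trainer(checkpoint_callback=False, early_stop_callback=False)

    # 2. Instances: user-configured callbacks pass through untouched, except that the
    #    checkpoint callback has its save_function wired to trainer.save_checkpoint.
    trainer_custom = Trainer(
        checkpoint_callback=ModelCheckpoint(filepath=None, monitor='val_acc', mode='max'),
        early_stop_callback=EarlyStopping(monitor='val_acc', patience=5, mode='max'),
    )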
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -172,7 +172,6 @@ class TrainerDDPMixin(ABC):
num_gpu_nodes: int
gpus: List[int]
logger: Union[LightningLoggerBase, bool]
checkpoint_callback: Union[ModelCheckpoint, bool]
data_parallel_device_ids: ...
distributed_backend: Optional[str]
amp_level: str
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/lr_finder.py
@@ -163,7 +163,6 @@ def lr_find(self,
# Disable standard checkpoint & early stopping
self.checkpoint_callback = False
self.early_stop_callback = None
self.enable_early_stop = False

# Required for saving the model
self.optimizers, self.schedulers = [], [],
@@ -215,7 +214,6 @@ def __lr_finder_dump_params(self, model):
'max_steps': self.max_steps,
'checkpoint_callback': self.checkpoint_callback,
'early_stop_callback': self.early_stop_callback,
'enable_early_stop': self.enable_early_stop,
'configure_optimizers': model.configure_optimizers,
}

@@ -226,7 +224,6 @@ def __lr_finder_restore_params(self, model):
self.max_steps = self.__dumped_params['max_steps']
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
model.configure_optimizers = self.__dumped_params['configure_optimizers']
del self.__dumped_params
