From 25ee51bc570503f331dceecc610d0eb355e22327 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 29 Jun 2020 03:36:46 +0200 Subject: [PATCH] Continue Jeremy's early stopping PR #1504 (#2391) * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * cannot pass an int as default_save_path * refactor log message * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec * fix formatting * remove enable_early_stop attribute * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec * fix formatting * remove enable_early_stop attribute * fix test with new epoch indexing * fix progress bar totals * fix off by one error (see #2289) epoch starts at 0 now * added missing imports * fix hpc_save folderpath * fix formatting * fix tests * small fixes from a rebase * fix * tmpdir * tmpdir * tmpdir * wandb * fix merge conflict * add back evaluation after training * test_resume_early_stopping_from_checkpoint TODO * undo the horovod check * update changelog * remove a duplicate test from merge error * try fix dp_resume test * add the logger fix from master * try remove default_root_dir * try mocking numpy * try import numpy in docs test * fix wandb test * pep 8 fix * skip if no amp * dont mock when doctesting * install extra * fix the resume ES test * undo conf.py changes * revert remove comet pickle from test * Update CHANGELOG.md Co-authored-by: Jirka Borovec * Update weights_loading.rst * Update weights_loading.rst * Update 
weights_loading.rst * renamed flag * renamed flag * revert the None check in logger experiment name/version * add the old comments * _experiment * test chckpointing on DDP * skip the ddp test on windows * cloudpickle * renamed flag * renamed flag * parentheses for clarity * apply suggestion max epochs Co-authored-by: Jirka Borovec Co-authored-by: Jeremy Jordan Co-authored-by: Jirka Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Jirka Borovec Co-authored-by: William Falcon --- CHANGELOG.md | 2 + docs/source/callbacks.rst | 13 ++ docs/source/experiment_logging.rst | 3 +- docs/source/weights_loading.rst | 2 +- pytorch_lightning/callbacks/early_stopping.py | 57 +++++--- .../callbacks/model_checkpoint.py | 35 +++++ pytorch_lightning/loggers/wandb.py | 4 +- pytorch_lightning/trainer/callback_config.py | 64 +++------ .../trainer/distrib_data_parallel.py | 1 - pytorch_lightning/trainer/lr_finder.py | 3 - pytorch_lightning/trainer/trainer.py | 102 +++++++------- pytorch_lightning/trainer/training_io.py | 37 +++-- pytorch_lightning/trainer/training_loop.py | 83 +++++------ pytorch_lightning/trainer/training_tricks.py | 3 - tests/callbacks/test_callbacks.py | 1 + tests/callbacks/test_early_stopping.py | 133 ++++++++++++++++++ .../{test_lr.py => test_lr_logger.py} | 0 tests/callbacks/test_model_checkpoint.py | 107 ++++++++++++++ tests/callbacks/test_progress_bar.py | 15 +- tests/loggers/test_base.py | 6 +- tests/loggers/test_wandb.py | 17 ++- tests/models/test_amp.py | 3 +- tests/models/test_cpu.py | 5 +- tests/models/test_grad_norm.py | 1 + tests/models/test_hooks.py | 4 +- tests/models/test_restore.py | 1 + tests/trainer/test_dataloaders.py | 9 +- tests/trainer/test_lr_finder.py | 2 +- tests/trainer/test_trainer.py | 22 +-- tests/trainer/test_trainer_cli.py | 6 +- tests/trainer/test_trainer_steps.py | 14 +- tests/trainer/test_trainer_tricks.py | 7 +- 32 files changed, 532 insertions(+), 230 deletions(-) create mode 100644 tests/callbacks/test_early_stopping.py rename tests/callbacks/{test_lr.py => test_lr_logger.py} (100%) create mode 100644 tests/callbacks/test_model_checkpoint.py diff --git a/CHANGELOG.md b/CHANGELOG.md index e1f460940cd8e..e8e9a09599ab2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed loading model with kwargs ([#2387](https://github.com/PyTorchLightning/pytorch-lightning/pull/2387)) +- Fixed several issues with early stopping and checkpoint callbacks ([#1504](https://github.com/PyTorchLightning/pytorch-lightning/pull/1504), [#2391](https://github.com/PyTorchLightning/pytorch-lightning/pull/2391)) + - Fixed loading past checkpoints from v0.7.x ([#2405](https://github.com/PyTorchLightning/pytorch-lightning/pull/2405)) - Fixed loading model without arguments ([#2403](https://github.com/PyTorchLightning/pytorch-lightning/pull/2403)) diff --git a/docs/source/callbacks.rst b/docs/source/callbacks.rst index 39612cfce6da5..f9fcecf880384 100644 --- a/docs/source/callbacks.rst +++ b/docs/source/callbacks.rst @@ -48,6 +48,19 @@ We successfully extended functionality without polluting our super clean ---------------- +Best Practices +============== + +1. Callbacks should be isolated in their functionality. Your callback should not rely on the +behavior of other callbacks in order to work properly. +2. Do not manually call methods from the callback. The callbacks are designed to be +invoked at specific times during training. 
Directly calling methods (e.g. `on_validation_end`) +is strongly discouraged. +3. Whenever possible, your callbacks should not depend on the order in which they are executed. + + +--------- + .. automodule:: pytorch_lightning.callbacks.base :noindex: :exclude-members: diff --git a/docs/source/experiment_logging.rst b/docs/source/experiment_logging.rst index 4ca96a2eee495..6dab549521624 100644 --- a/docs/source/experiment_logging.rst +++ b/docs/source/experiment_logging.rst @@ -92,6 +92,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. .. testcode:: from pytorch_lightning.loggers import NeptuneLogger + neptune_logger = NeptuneLogger( api_key='ANONYMOUS', # replace with your own project_name='shared/pytorch-lightning-integration', @@ -193,7 +194,7 @@ Then configure the logger and pass it to the :class:`~pytorch_lightning.trainer. .. testcode:: from pytorch_lightning.loggers import WandbLogger - wandb_logger = WandbLogger() + wandb_logger = WandbLogger(offline=True) trainer = Trainer(logger=wandb_logger) The :class:`~pytorch_lightning.loggers.WandbLogger` is available anywhere except ``__init__`` in your diff --git a/docs/source/weights_loading.rst b/docs/source/weights_loading.rst index 4df6ee44bd8f9..067cb380b82da 100644 --- a/docs/source/weights_loading.rst +++ b/docs/source/weights_loading.rst @@ -29,7 +29,7 @@ Automatic saving Checkpointing is enabled by default to the current working directory. To change the checkpoint path pass in: -.. testcode:: +.. code-block:: python trainer = Trainer(default_root_dir='/your/path/to/save/checkpoints') diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 99a9bb073d787..b4113ebeccfca 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -5,6 +5,7 @@ Monitor a validation metric and stop training when it stops improving. 
""" +from copy import deepcopy import numpy as np import torch @@ -58,7 +59,7 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: self.verbose = verbose self.strict = strict self.min_delta = min_delta - self.wait = 0 + self.wait_count = 0 self.stopped_epoch = 0 self.mode = mode @@ -76,12 +77,17 @@ def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.') self.min_delta *= 1 if self.monitor_op == torch.gt else -1 + self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf def _validate_condition_metric(self, logs): """ Checks that the condition metric for early stopping is good - :param logs: - :return: + + Args: + logs: callback metrics from validation output + + Return: + True if specified metric is available """ monitor_val = logs.get(self.monitor) error_msg = (f'Early stopping conditioned on metric `{self.monitor}`' @@ -103,39 +109,48 @@ def _validate_condition_metric(self, logs): def monitor_op(self): return self.mode_dict[self.mode] - def on_train_start(self, trainer, pl_module): - # Allow instances to be re-used - self.wait = 0 - self.stopped_epoch = 0 - self.best = torch_inf if self.monitor_op == torch.lt else -torch_inf + def state_dict(self): + return { + 'wait_count': self.wait_count, + 'stopped_epoch': self.stopped_epoch, + 'best_score': self.best_score, + 'patience': self.patience + } + + def load_state_dict(self, state_dict): + state_dict = deepcopy(state_dict) + self.wait_count = state_dict['wait_count'] + self.stopped_epoch = state_dict['stopped_epoch'] + self.best_score = state_dict['best_score'] + self.patience = state_dict['patience'] + + def on_sanity_check_end(self, trainer, pl_module): + logs = trainer.callback_metrics + self._validate_condition_metric(logs) def on_validation_end(self, trainer, pl_module): - return self._run_early_stopping_check(trainer, pl_module) + self._run_early_stopping_check(trainer, pl_module) def _run_early_stopping_check(self, trainer, pl_module): logs = trainer.callback_metrics - stop_training = False if not self._validate_condition_metric(logs): - return stop_training + return # short circuit if metric not present current = logs.get(self.monitor) if not isinstance(current, torch.Tensor): current = torch.tensor(current) - if self.monitor_op(current - self.min_delta, self.best): - self.best = current - self.wait = 0 + if self.monitor_op(current - self.min_delta, self.best_score): + self.best_score = current + self.wait_count = 0 else: - self.wait += 1 - if self.wait >= self.patience: + self.wait_count += 1 + if self.wait_count >= self.patience: self.stopped_epoch = trainer.current_epoch - stop_training = True - self.on_train_end(trainer, pl_module) - - return stop_training + trainer.should_stop = True def on_train_end(self, trainer, pl_module): if self.stopped_epoch > 0 and self.verbose > 0: rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,' ' but will start from "0" in v0.8.0.', DeprecationWarning) - log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping') + log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.') diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6a47e6f58c88a..45e5560e9f288 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -226,6 +226,41 @@ def format_checkpoint_name(self, epoch, 
metrics, ver=None): filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt') return filepath + def on_train_start(self, trainer, pl_module): + """ + Determine model checkpoint save directory at runtime. References attributes from the + Trainer's logger to determine where to save checkpoints. + """ + if self.dirpath is not None: + return # short circuit + + self.filename = '{epoch}' + + if trainer.logger is not None and trainer.logger.experiment is not None: + # weights_save_path overrides anything + if getattr(trainer, 'weights_save_path', None) is not None: + save_dir = trainer.weights_save_path + else: + save_dir = (getattr(trainer.logger, 'save_dir', None) + or getattr(trainer.logger, '_save_dir', None) + or trainer.default_root_dir) + + version = trainer.logger.version if isinstance( + trainer.logger.version, str) else f'version_{trainer.logger.version}' + ckpt_path = os.path.join( + save_dir, + trainer.logger.name, + version, + "checkpoints" + ) + else: + ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") + + self.dirpath = ckpt_path + os.makedirs(self.dirpath, exist_ok=True) + trainer.ckpt_path = ckpt_path + trainer.weights_save_path = ckpt_path + @rank_zero_only def on_validation_end(self, trainer, pl_module): # only run on main process diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index 1b91afaec1e0b..3c8dd5457a22c 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -131,12 +131,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> self.experiment.log({'global_step': step, **metrics} if step is not None else metrics) @property - def name(self) -> str: + def name(self) -> Optional[str]: # don't create an experiment if we don't have one name = self._experiment.project_name() if self._experiment else None return name @property - def version(self) -> str: + def version(self) -> Optional[str]: # don't create an experiment if we don't have one return self._experiment.id if self._experiment else None diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 80088db362de4..b65dc37ef8b1b 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -32,79 +32,47 @@ def save_checkpoint(self, *args): def is_overridden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" - def configure_checkpoint_callback(self): + def configure_checkpoint_callback(self, checkpoint_callback): """ Weight path set in this priority: Checkpoint_callback's path (if passed in). 
User provided weights_saved_path Otherwise use os.getcwd() """ - ckpt_path = self.default_root_dir - if self.checkpoint_callback: - # init a default one - if self.logger is not None and self.logger.experiment is not None: - save_dir = (getattr(self.logger, 'save_dir', None) or - getattr(self.logger, '_save_dir', None) or - self.default_root_dir) - - # weights_save_path overrides anything - if self.weights_save_path is not None: - save_dir = self.weights_save_path - - version = self.logger.version if isinstance( - self.logger.version, str) else f'version_{self.logger.version}' - ckpt_path = os.path.join(save_dir, self.logger.name, version, "checkpoints") - else: - ckpt_path = os.path.join(self.default_root_dir, "checkpoints") - + if checkpoint_callback is True: # when no val step is defined, use 'loss' otherwise 'val_loss' train_step_only = not self.is_overridden('validation_step') monitor_key = 'loss' if train_step_only else 'val_loss' + checkpoint_callback = ModelCheckpoint( + filepath=None, + monitor=monitor_key + ) + elif checkpoint_callback is False: + checkpoint_callback = None - if self.checkpoint_callback is True: - os.makedirs(ckpt_path, exist_ok=True) - self.checkpoint_callback = ModelCheckpoint( - filepath=ckpt_path, - monitor=monitor_key - ) - # If user specified None in filepath, override with runtime default - elif isinstance(self.checkpoint_callback, ModelCheckpoint) \ - and self.checkpoint_callback.dirpath is None: - self.checkpoint_callback.dirpath = ckpt_path - self.checkpoint_callback.filename = '{epoch}' - os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True) - elif self.checkpoint_callback is False: - self.checkpoint_callback = None - - self.ckpt_path = ckpt_path - - if self.checkpoint_callback: - # set the path for the callbacks - self.checkpoint_callback.save_function = self.save_checkpoint - - # if checkpoint callback used, then override the weights path - self.weights_save_path = self.checkpoint_callback.dirpath + if checkpoint_callback: + checkpoint_callback.save_function = self.save_checkpoint # if weights_save_path is still none here, set to current working dir if self.weights_save_path is None: self.weights_save_path = self.default_root_dir + return checkpoint_callback + def configure_early_stopping(self, early_stop_callback): if early_stop_callback is True or None: - self.early_stop_callback = EarlyStopping( + early_stop_callback = EarlyStopping( monitor='val_loss', patience=3, strict=True, verbose=True, mode='min' ) - self.enable_early_stop = True elif not early_stop_callback: - self.early_stop_callback = None - self.enable_early_stop = False + early_stop_callback = None else: - self.early_stop_callback = early_stop_callback - self.enable_early_stop = True + early_stop_callback = early_stop_callback + return early_stop_callback def configure_progress_bar(self, refresh_rate=1, process_position=0): progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)] diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 69aa961c9ed0a..ebf93c87e69df 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -172,7 +172,6 @@ class TrainerDDPMixin(ABC): num_gpu_nodes: int gpus: List[int] logger: Union[LightningLoggerBase, bool] - checkpoint_callback: Union[ModelCheckpoint, bool] data_parallel_device_ids: ... 
distributed_backend: Optional[str] amp_level: str diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 96f38c86cb939..72228c81394ba 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -163,7 +163,6 @@ def lr_find(self, # Disable standard checkpoint & early stopping self.checkpoint_callback = False self.early_stop_callback = None - self.enable_early_stop = False # Required for saving the model self.optimizers, self.schedulers = [], [], @@ -215,7 +214,6 @@ def __lr_finder_dump_params(self, model): 'max_steps': self.max_steps, 'checkpoint_callback': self.checkpoint_callback, 'early_stop_callback': self.early_stop_callback, - 'enable_early_stop': self.enable_early_stop, 'configure_optimizers': model.configure_optimizers, } @@ -226,7 +224,6 @@ def __lr_finder_restore_params(self, model): self.max_steps = self.__dumped_params['max_steps'] self.checkpoint_callback = self.__dumped_params['checkpoint_callback'] self.early_stop_callback = self.__dumped_params['early_stop_callback'] - self.enable_early_stop = self.__dumped_params['enable_early_stop'] model.configure_optimizers = self.__dumped_params['configure_optimizers'] del self.__dumped_params diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 234faf15923fc..7c8a89bffb87a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -328,9 +328,60 @@ def __init__( if 'LOCAL_RANK' in os.environ: rank_zero_only.rank = os.environ['LOCAL_RANK'] - # Init callbacks + # training bookeeping + self.total_batch_idx = 0 + self.running_loss = TensorRunningAccum(window_length=20) + self.batch_idx = 0 + self.progress_bar_metrics = {} + self.callback_metrics = {} + self.num_training_batches = 0 + self.num_val_batches = [] + self.num_test_batches = [] + self.train_dataloader = None + self.test_dataloaders = None + self.val_dataloaders = None + + # training state + self.model = None + self.testing = False + self.disable_validation = False self.prepare_data_per_node = prepare_data_per_node + self.lr_schedulers = [] + self.optimizers = None + self.optimizer_frequencies = [] + self.global_step = 0 + self.current_epoch = 0 + self.interrupted = False + self.should_stop = False + + # set default save path if user didn't provide one + if default_root_dir is None: + default_root_dir = os.getcwd() + self.default_root_dir = default_root_dir + + self.configure_logger(logger) + + # init callbacks self.callbacks = callbacks or [] + + # configure early stop callback + # creates a default one if none passed in + early_stop_callback = self.configure_early_stopping(early_stop_callback) + if early_stop_callback: + self.callbacks.append(early_stop_callback) + + # configure checkpoint callback + # it is important that this is the last callback to run + # pass through the required args to figure out defaults + self.weights_save_path = weights_save_path + checkpoint_callback = self.configure_checkpoint_callback(checkpoint_callback) + if checkpoint_callback: + self.callbacks.append(checkpoint_callback) + + # TODO refactor codebase (tests) to not directly reach into these callbacks + self.checkpoint_callback = checkpoint_callback + self.early_stop_callback = early_stop_callback + self.on_init_start() # benchmarking @@ -399,52 +450,11 @@ def __init__( rank_zero_info('Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch') - # set default save path if user didn't provide one - 
self.default_root_dir = default_root_dir - - if self.default_root_dir is None: - self.default_root_dir = os.getcwd() - - # training bookeeping - self.total_batch_idx = 0 - self.running_loss = TensorRunningAccum(window_length=20) - self.batch_idx = 0 - self.progress_bar_metrics = {} - self.callback_metrics = {} - self.num_val_batches = [0] - self.num_training_batches = 0 - self.num_test_batches = [0] - self.train_dataloader = None - self.test_dataloaders = None - self.val_dataloaders = None - - # training state - self.model = None - self.testing = False - self.disable_validation = False - self.lr_schedulers = [] - self.optimizers = None - self.optimizer_frequencies = [] - self.global_step = 0 - self.current_epoch = 0 - self.interrupted = False - - # configure logger - self.configure_logger(logger) - # configure profiler if profiler is True: profiler = SimpleProfiler() self.profiler = profiler or PassThroughProfiler() - # configure early stop callback - # creates a default one if none passed in - self.configure_early_stopping(early_stop_callback) - - # configure checkpoint callback - self.checkpoint_callback = checkpoint_callback - self.weights_save_path = weights_save_path - # accumulated grads self.accumulate_grad_batches = accumulate_grad_batches self.configure_accumulated_gradients(accumulate_grad_batches) @@ -1045,9 +1055,6 @@ def run_pretrain_routine(self, model: LightningModule): # if cluster resets state, the model will update with the saved weights self.model = model - # set up checkpoint callback - self.configure_checkpoint_callback() - # restore training and model before hpc call self.restore_weights(model) @@ -1078,13 +1085,10 @@ def run_pretrain_routine(self, model: LightningModule): max_batches, False) _, _, _, callback_metrics, _ = self.process_output(eval_results) + self.callback_metrics = callback_metrics self.on_sanity_check_end() - # verify that early stop has conditioned on a metric that exists - if self.enable_early_stop: - self.early_stop_callback._validate_condition_metric(callback_metrics) - # clear cache before training if self.on_gpu and self.root_gpu is not None: # use context because of: diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 4c672743dc4c2..bf421bd618d8a 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -95,6 +95,7 @@ import pytorch_lightning from pytorch_lightning import _logger as log from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, @@ -328,26 +329,32 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: } if not weights_only: - if self.checkpoint_callback: + + # TODO support more generic way for callbacks to persist a state_dict in a checkpoint + checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)] + + if checkpoint_callbacks: + # we add the official checkpoint callback to the end of the list + # extra user provided callbacks will not be persisted yet checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path - if self.early_stop_callback: - 
checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait - checkpoint['early_stop_callback_patience'] = self.early_stop_callback.patience + if early_stopping_callbacks and checkpoint_callbacks: + # we add the official early stopping callback to the end of the list + # extra user provided callbacks will not be persisted yet + checkpoint['early_stop_callback_state_dict'] = early_stopping_callbacks[-1].state_dict() # save optimizers optimizer_states = [] for i, optimizer in enumerate(self.optimizers): optimizer_states.append(optimizer.state_dict()) - checkpoint['optimizer_states'] = optimizer_states # save lr schedulers lr_schedulers = [] for scheduler in self.lr_schedulers: lr_schedulers.append(scheduler['scheduler'].state_dict()) - checkpoint['lr_schedulers'] = lr_schedulers # save native amp scaling @@ -405,21 +412,25 @@ def restore_training_state(self, checkpoint): ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.' ) - if self.checkpoint_callback: + # TODO support more generic way for callbacks to load callback state_dicts + checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + early_stopping_callbacks = [c for c in self.callbacks if isinstance(c, EarlyStopping)] + + if checkpoint_callbacks: if 'checkpoint_callback_best_model_score' in checkpoint: - self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score'] + checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best_model_score'] else: # Old naming until version 0.7.6 rank_zero_warn( 'Loading a checkpoint created with an old version of Lightning; ' 'this will not be supported in the future.' ) - self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best'] - self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path'] + checkpoint_callbacks[-1].best_model_score = checkpoint['checkpoint_callback_best'] + checkpoint_callbacks[-1].best_model_path = checkpoint['checkpoint_callback_best_model_path'] - if self.early_stop_callback: - self.early_stop_callback.wait = checkpoint['early_stop_callback_wait'] - self.early_stop_callback.patience = checkpoint['early_stop_callback_patience'] + if early_stopping_callbacks: + state = checkpoint['early_stop_callback_state_dict'] + early_stopping_callbacks[-1].load_state_dict(state) self.global_step = checkpoint['global_step'] self.current_epoch = checkpoint['epoch'] diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4cf5e61b5c7eb..be0735701850a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -144,8 +144,7 @@ def training_step(self, batch, batch_idx): """ -import atexit -import signal +import subprocess from abc import ABC, abstractmethod from typing import Callable from typing import Union, List @@ -157,6 +156,7 @@ def training_step(self, batch, batch_idx): from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.trainer.supporters import TensorRunningAccum @@ -164,7 +164,6 @@ def training_step(self, batch, batch_idx): from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict from 
pytorch_lightning.utilities.memory import recursive_detach -import subprocess try: from apex import amp @@ -212,7 +211,6 @@ class TrainerTrainLoopMixin(ABC): fast_dev_run: ... accumulation_scheduler: ... lr_schedulers: ... - enable_early_stop: ... early_stop_callback: ... callback_metrics: ... logger: Union[LightningLoggerBase, bool] @@ -239,7 +237,6 @@ class TrainerTrainLoopMixin(ABC): max_steps: int min_steps: int total_batch_idx: int - checkpoint_callback: ... terminate_on_nan: bool tpu_id: int interactive_ddp_procs: ... @@ -264,7 +261,7 @@ def is_function_implemented(self, *args, **kwargs): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod - def run_evaluation(self, *args): + def run_evaluation(self, *args, **kwargs): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod @@ -340,9 +337,6 @@ def train(self): with self.profiler.profile('on_train_start'): # callbacks self.on_train_start() - # initialize early stop callback - if self.early_stop_callback is not None: - self.early_stop_callback.on_train_start(self, self.get_model()) # model hooks model.on_train_start() @@ -375,7 +369,7 @@ def train(self): # ----------------- self.run_training_epoch() - if self.max_steps and self.max_steps == self.global_step: + if self.max_steps and self.max_steps <= self.global_step: self.run_training_teardown() return @@ -386,19 +380,14 @@ def train(self): met_min_epochs = epoch >= self.min_epochs - 1 met_min_steps = self.global_step >= self.min_steps if self.min_steps else True - # TODO wrap this logic into the callback - # DO NOT DELETE - # early stopping as a (new Callback) class doesn't yet work because we have to know these - # trainer flags including the current epoch stuff - # all of this needs to go into the early stopping to clean up better - if self.enable_early_stop: + if self.should_stop: if (met_min_epochs and met_min_steps) or self.fast_dev_run: - should_stop = self.early_stop_callback.on_validation_end(self, self.get_model()) - # stop training - stop = should_stop and met_min_epochs - if stop: - self.run_training_teardown() - return + self.run_training_teardown() + return + else: + log.info('Trainer was signaled to stop but required minimum epochs' + f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' + ' not been met. Training will continue...') self.run_training_teardown() @@ -444,6 +433,7 @@ def run_training_epoch(self): # bookkeeping epoch_output = [] + should_check_val = False # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( @@ -470,22 +460,24 @@ def run_training_epoch(self): self.update_train_loop_lr_schedulers() # when returning -1 from train_step, we end epoch early - early_stop_epoch = batch_output.signal == -1 + self.should_stop = batch_output.signal == -1 # ----------------------------------------- # VALIDATE IF NEEDED + CHECKPOINT CALLBACK # ----------------------------------------- - should_check_val = self.check_validation_in_train_loop(batch_idx, early_stop_epoch, is_last_batch) + should_check_val = self.should_check_val(batch_idx, is_last_batch) + if self.fast_dev_run or should_check_val: + self.run_evaluation(test_mode=False) # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
# ----------------------------------------- - self.save_loggers_in_training_loop(batch_idx, early_stop_epoch) + self.save_loggers_in_training_loop(batch_idx) # ----------------------------------------- # SAVE METRICS TO LOGGERS # ----------------------------------------- - self.save_train_loop_metrics_to_loggers(batch_idx, early_stop_epoch, batch_output) + self.save_train_loop_metrics_to_loggers(batch_idx, batch_output) # progress global step according to grads progress self.increment_accumulated_grad_global_step() @@ -497,7 +489,7 @@ def run_training_epoch(self): # end epoch early # stop when the flag is changed or we've gone past the amount # requested in the batches - if early_stop_epoch or self.fast_dev_run: + if self.fast_dev_run or self.should_stop: break # let ddp devices catch up when using horovod @@ -506,13 +498,19 @@ def run_training_epoch(self): # process epoch outputs self.run_training_epoch_end(epoch_output) - # when no val loop is present or fast-dev-run still need to call checkpoints - if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val): - self.call_checkpoint_callback() + # checkpoint callback + self.check_checkpoint_callback(should_check_val) # epoch end hook self.run_on_epoch_end_hook(model) + def check_checkpoint_callback(self, should_check_val): + # when no val loop is present or fast-dev-run still need to call checkpoints + # TODO bake this logic into the checkpoint callback + if not self.is_overridden('validation_step') and not (self.fast_dev_run or should_check_val): + checkpoint_callbacks = [c for c in self.callbacks if isinstance(c, ModelCheckpoint)] + [c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks] + def update_train_loop_lr_schedulers(self): if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: # update lr @@ -553,33 +551,28 @@ def increment_accumulated_grad_global_step(self): self.global_step += 1 self.total_batch_idx += 1 - def save_train_loop_metrics_to_loggers(self, batch_idx, early_stop_epoch, batch_output): + def save_train_loop_metrics_to_loggers(self, batch_idx, batch_output): # when metrics should be logged - should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch + should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop if should_log_metrics or self.fast_dev_run: # logs user requested information to logger self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic) - def save_loggers_in_training_loop(self, batch_idx, early_stop_epoch): + def save_loggers_in_training_loop(self, batch_idx): # when loggers should save to disk - should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch + should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or self.should_stop if should_save_log or self.fast_dev_run: if self.is_global_zero and self.logger is not None: self.logger.save() - def check_validation_in_train_loop(self, batch_idx, early_stop_epoch, is_last_batch): + def should_check_val(self, batch_idx, is_last_batch): # decide if we should run validation is_val_check_batch = (batch_idx + 1) % self.val_check_batch == 0 can_check_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 can_check_val = not self.disable_validation and can_check_epoch - should_check_val = is_val_check_batch or early_stop_epoch - should_check_val = should_check_val or (is_last_batch and self.val_check_batch == float('inf')) - should_check_val = can_check_val and should_check_val - - # if we need to 
run validation, then also call the checkpoint callback - if self.fast_dev_run or should_check_val: - self.run_evaluation(test_mode=self.testing) - self.call_checkpoint_callback() + should_check_val = is_val_check_batch or self.should_stop + is_last_batch_for_infinite_dataset = (is_last_batch and self.val_check_batch == float('inf')) + should_check_val = can_check_val and (should_check_val or is_last_batch_for_infinite_dataset) return should_check_val @@ -984,10 +977,6 @@ def update_learning_rates(self, interval: str): else: lr_scheduler['scheduler'].step() - def call_checkpoint_callback(self): - if self.checkpoint_callback is not None: - self.checkpoint_callback.on_validation_end(self, self.get_model()) - def _with_is_last(iterable): """Pass through values from the given iterable with an added boolean indicating if this is the last item. diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index fb4262572bcdd..69c764e06e48f 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -188,7 +188,6 @@ def __scale_batch_dump_params(self): 'callbacks': self.callbacks, 'checkpoint_callback': self.checkpoint_callback, 'early_stop_callback': self.early_stop_callback, - 'enable_early_stop': self.enable_early_stop, 'auto_scale_batch_size': self.auto_scale_batch_size, 'limit_train_batches': self.limit_train_batches, 'model': self.model, @@ -202,7 +201,6 @@ def __scale_batch_reset_params(self, model, steps_per_trial): self.callbacks = [] # not needed before full run self.checkpoint_callback = False # required for saving self.early_stop_callback = None - self.enable_early_stop = False self.limit_train_batches = 1.0 self.optimizers, self.schedulers = [], [] # required for saving self.model = model # required for saving @@ -215,7 +213,6 @@ def __scale_batch_restore_params(self): self.checkpoint_callback = self.__dumped_params['checkpoint_callback'] self.auto_scale_batch_size = self.__dumped_params['auto_scale_batch_size'] self.early_stop_callback = self.__dumped_params['early_stop_callback'] - self.enable_early_stop = self.__dumped_params['enable_early_stop'] self.limit_train_batches = self.__dumped_params['limit_train_batches'] self.model = self.__dumped_params['model'] del self.__dumped_params diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 93794b56598ce..b1034ef7d7f28 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -160,6 +160,7 @@ def on_test_end(self, trainer, pl_module): test_callback = TestCallback() trainer_options = dict( + default_root_dir=tmpdir, callbacks=[test_callback], max_epochs=1, limit_val_batches=0.1, diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py new file mode 100644 index 0000000000000..c3e5fa3914682 --- /dev/null +++ b/tests/callbacks/test_early_stopping.py @@ -0,0 +1,133 @@ +import pickle + +import cloudpickle +import pytest + +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from tests.base import EvalModelTemplate + + +def test_resume_early_stopping_from_checkpoint(tmpdir): + """ + Prevent regressions to bugs: + https://github.com/PyTorchLightning/pytorch-lightning/issues/1464 + https://github.com/PyTorchLightning/pytorch-lightning/issues/1463 + """ + + class EarlyStoppingTestStore(EarlyStopping): + def __init__(self, *args, **kwargs): + super().__init__(*args, 
**kwargs) + # cache the state for each epoch + self.saved_states = [] + + def on_validation_end(self, trainer, pl_module): + super().on_validation_end(trainer, pl_module) + self.saved_states.append(self.state_dict().copy()) + + class EarlyStoppingTestRestore(EarlyStopping): + def __init__(self, expected_state): + super().__init__() + self.expected_state = expected_state + + def on_train_start(self, trainer, pl_module): + assert self.state_dict() == self.expected_state + + model = EvalModelTemplate() + checkpoint_callback = ModelCheckpoint(save_top_k=1) + early_stop_callback = EarlyStoppingTestStore() + trainer = Trainer( + default_root_dir=tmpdir, + checkpoint_callback=checkpoint_callback, + early_stop_callback=early_stop_callback, + max_epochs=4, + ) + trainer.fit(model) + + checkpoint_filepath = checkpoint_callback.kth_best_model + # ensure state is persisted properly + checkpoint = torch.load(checkpoint_filepath) + # the checkpoint saves "epoch + 1" + early_stop_callback_state = early_stop_callback.saved_states[checkpoint['epoch'] - 1] + assert 4 == len(early_stop_callback.saved_states) + assert checkpoint['early_stop_callback_state_dict'] == early_stop_callback_state + + # ensure state is reloaded properly (assertion in the callback) + early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state) + new_trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + resume_from_checkpoint=checkpoint_filepath, + early_stop_callback=early_stop_callback, + ) + new_trainer.fit(model) + + +def test_early_stopping_no_extraneous_invocations(tmpdir): + """Test to ensure that callback methods aren't being invoked outside of the callback handler.""" + class EarlyStoppingTestInvocations(EarlyStopping): + def __init__(self, expected_count): + super().__init__() + self.count = 0 + self.expected_count = expected_count + + def on_validation_end(self, trainer, pl_module): + self.count += 1 + + def on_train_end(self, trainer, pl_module): + assert self.count == self.expected_count + + model = EvalModelTemplate() + expected_count = 4 + early_stop_callback = EarlyStoppingTestInvocations(expected_count) + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=early_stop_callback, + val_check_interval=1.0, + max_epochs=expected_count, + ) + trainer.fit(model) + + +@pytest.mark.parametrize('loss_values, patience, expected_stop_epoch', [ + ([6, 5, 5, 5, 5, 5], 3, 4), + ([6, 5, 4, 4, 3, 3], 1, 3), + ([6, 5, 6, 5, 5, 5], 3, 4), +]) +def test_early_stopping_patience(tmpdir, loss_values, patience, expected_stop_epoch): + """Test to ensure that early stopping is not triggered before patience is exhausted.""" + + class ModelOverrideValidationReturn(EvalModelTemplate): + validation_return_values = torch.Tensor(loss_values) + count = 0 + + def validation_epoch_end(self, outputs): + loss = self.validation_return_values[self.count] + self.count += 1 + return {"test_val_loss": loss} + + model = ModelOverrideValidationReturn() + early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=early_stop_callback, + val_check_interval=1.0, + num_sanity_val_steps=0, + max_epochs=10, + ) + trainer.fit(model) + assert trainer.current_epoch == expected_stop_epoch + + +def test_pickling(tmpdir): + early_stopping = EarlyStopping() + + early_stopping_pickled = pickle.dumps(early_stopping) + early_stopping_loaded = pickle.loads(early_stopping_pickled) + assert vars(early_stopping) == 
vars(early_stopping_loaded) + + early_stopping_pickled = cloudpickle.dumps(early_stopping) + early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) + assert vars(early_stopping) == vars(early_stopping_loaded) + diff --git a/tests/callbacks/test_lr.py b/tests/callbacks/test_lr_logger.py similarity index 100% rename from tests/callbacks/test_lr.py rename to tests/callbacks/test_lr_logger.py diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py new file mode 100644 index 0000000000000..b5cb7ca0c756e --- /dev/null +++ b/tests/callbacks/test_model_checkpoint.py @@ -0,0 +1,107 @@ +import os +import pickle +import platform +from pathlib import Path + +import cloudpickle +import pytest + +import tests.base.develop_utils as tutils +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from tests.base import EvalModelTemplate + + +@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2]) +def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k): + """ Test that None in checkpoint callback is valid and that chkp_path is set correctly """ + tutils.reset_seed() + model = EvalModelTemplate() + + checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k) + + trainer = Trainer( + default_root_dir=tmpdir, + checkpoint_callback=checkpoint, + overfit_pct=0.20, + max_epochs=(save_top_k + 2), + ) + trainer.fit(model) + + # These should be different if the dirpath has be overridden + assert trainer.ckpt_path != trainer.default_root_dir + + +@pytest.mark.parametrize( + 'logger_version,expected', + [(None, 'version_0'), (1, 'version_1'), ('awesome', 'awesome')], +) +def test_model_checkpoint_path(tmpdir, logger_version, expected): + """Test that "version_" prefix is only added when logger's version is an integer""" + tutils.reset_seed() + model = EvalModelTemplate() + logger = TensorBoardLogger(str(tmpdir), version=logger_version) + + trainer = Trainer( + default_root_dir=tmpdir, + overfit_pct=0.2, + max_epochs=5, + logger=logger, + ) + trainer.fit(model) + + ckpt_version = Path(trainer.ckpt_path).parent.name + assert ckpt_version == expected + + +def test_pickling(tmpdir): + ckpt = ModelCheckpoint(tmpdir) + + ckpt_pickled = pickle.dumps(ckpt) + ckpt_loaded = pickle.loads(ckpt_pickled) + assert vars(ckpt) == vars(ckpt_loaded) + + ckpt_pickled = cloudpickle.dumps(ckpt) + ckpt_loaded = cloudpickle.loads(ckpt_pickled) + assert vars(ckpt) == vars(ckpt_loaded) + + +class ModelCheckpointTestInvocations(ModelCheckpoint): + # this class has to be defined outside the test function, otherwise we get pickle error + # due to the way ddp process is launched + + def __init__(self, expected_count, *args, **kwargs): + super().__init__(*args, **kwargs) + self.count = 0 + self.expected_count = expected_count + + def _save_model(self, filepath): + # make sure we don't save twice + assert not os.path.isfile(filepath) + self.count += 1 + super()._save_model(filepath) + + def on_train_end(self, trainer, pl_module): + super().on_train_end(trainer, pl_module) + # on rank 0 we expect the saved files and on all others no saves + assert (trainer.global_rank == 0 and self.count == self.expected_count) \ + or (trainer.global_rank > 0 and self.count == 0) + + +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +def test_model_checkpoint_no_extraneous_invocations(tmpdir): + """Test to ensure that the model callback 
saves the checkpoints only once in distributed mode.""" + model = EvalModelTemplate() + num_epochs = 4 + model_checkpoint = ModelCheckpointTestInvocations(expected_count=num_epochs, save_top_k=-1) + trainer = Trainer( + distributed_backend='ddp_cpu', + num_processes=2, + default_root_dir=tmpdir, + early_stop_callback=False, + checkpoint_callback=model_checkpoint, + max_epochs=num_epochs, + ) + result = trainer.fit(model) + assert 1 == result diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index a63fc62585c45..f621e70228012 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -13,10 +13,11 @@ ([ProgressBar(refresh_rate=2)], 0), ([ProgressBar(refresh_rate=2)], 1), ]) -def test_progress_bar_on(callbacks, refresh_rate): +def test_progress_bar_on(tmpdir, callbacks, refresh_rate): """Test different ways the progress bar can be turned on.""" trainer = Trainer( + default_root_dir=tmpdir, callbacks=callbacks, progress_bar_refresh_rate=refresh_rate, max_epochs=1, @@ -34,10 +35,11 @@ def test_progress_bar_on(callbacks, refresh_rate): ([], False), ([ModelCheckpoint('../trainer')], 0), ]) -def test_progress_bar_off(callbacks, refresh_rate): +def test_progress_bar_off(tmpdir, callbacks, refresh_rate): """Test different ways the progress bar can be turned off.""" trainer = Trainer( + default_root_dir=tmpdir, callbacks=callbacks, progress_bar_refresh_rate=refresh_rate, ) @@ -54,12 +56,13 @@ def test_progress_bar_misconfiguration(): Trainer(callbacks=callbacks) -def test_progress_bar_totals(): +def test_progress_bar_totals(tmpdir): """Test that the progress finishes with the correct total steps processed.""" model = EvalModelTemplate() trainer = Trainer( + default_root_dir=tmpdir, progress_bar_refresh_rate=1, limit_val_batches=1.0, max_epochs=1, @@ -105,10 +108,11 @@ def test_progress_bar_totals(): assert bar.test_batch_idx == k -def test_progress_bar_fast_dev_run(): +def test_progress_bar_fast_dev_run(tmpdir): model = EvalModelTemplate() trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, ) @@ -136,7 +140,7 @@ def test_progress_bar_fast_dev_run(): @pytest.mark.parametrize('refresh_rate', [0, 1, 50]) -def test_progress_bar_progress_refresh(refresh_rate): +def test_progress_bar_progress_refresh(tmpdir, refresh_rate): """Test that the three progress bars get correctly updated when using different refresh rates.""" model = EvalModelTemplate() @@ -172,6 +176,7 @@ def on_test_batch_end(self, trainer, pl_module): progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) trainer = Trainer( + default_root_dir=tmpdir, callbacks=[progress_bar], progress_bar_refresh_rate=101, # should not matter if custom callback provided limit_train_batches=1.0, diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index e330e50e581fe..dfe9ffc6437fe 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -108,7 +108,11 @@ def test_multiple_loggers_pickle(tmpdir): logger1 = CustomLogger() logger2 = CustomLogger() - trainer = Trainer(max_epochs=1, logger=[logger1, logger2]) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=[logger1, logger2], + ) pkl_bytes = pickle.dumps(trainer) trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}, 0) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 0eb22331f690c..aa8b616bcf475 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -1,12 +1,12 @@ import os import pickle 
-from unittest.mock import patch +from unittest import mock from pytorch_lightning import Trainer from pytorch_lightning.loggers import WandbLogger -@patch('pytorch_lightning.loggers.wandb.wandb') +@mock.patch('pytorch_lightning.loggers.wandb.wandb') def test_wandb_logger(wandb): """Verify that basic functionality of wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here.""" @@ -29,8 +29,8 @@ def test_wandb_logger(wandb): assert logger.version == wandb.init().id -@patch('pytorch_lightning.loggers.wandb.wandb') -def test_wandb_pickle(wandb): +@mock.patch('pytorch_lightning.loggers.wandb.wandb') +def test_wandb_pickle(wandb, tmpdir): """Verify that pickling trainer with wandb logger works. Wandb doesn't work well with pytest so we have to mock it out here. @@ -38,11 +38,18 @@ def test_wandb_pickle(wandb): class Experiment: id = 'the_id' + def project_name(self): + return 'the_project_name' + wandb.init.return_value = Experiment() logger = WandbLogger(id='the_id', offline=True) - trainer = Trainer(max_epochs=1, logger=logger) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + logger=logger, + ) # Access the experiment to ensure it's created assert trainer.logger.experiment, 'missing experiment' pkl_bytes = pickle.dumps(trainer) diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index d8299b89df322..1c187a8188332 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -21,7 +21,7 @@ def test_amp_single_gpu(tmpdir, backend): max_epochs=1, gpus=1, distributed_backend=backend, - precision=16 + precision=16, ) model = EvalModelTemplate() @@ -100,6 +100,7 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): # fit model trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, gpus=[0], distributed_backend='ddp', diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py index 808163014adb3..8160bf8c72b44 100644 --- a/tests/models/test_cpu.py +++ b/tests/models/test_cpu.py @@ -24,6 +24,7 @@ def test_cpu_slurm_save_load(tmpdir): # fit model trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, logger=logger, limit_train_batches=0.2, @@ -54,13 +55,14 @@ def test_cpu_slurm_save_load(tmpdir): # test HPC saving # simulate snapshot on slurm - saved_filepath = trainer.hpc_save(tmpdir, logger) + saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger) assert os.path.exists(saved_filepath) # new logger file to get meta logger = tutils.get_default_logger(tmpdir, version=version) trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, logger=logger, checkpoint_callback=ModelCheckpoint(tmpdir), @@ -212,6 +214,7 @@ def test_running_test_no_val(tmpdir): # fit model trainer = Trainer( + default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=1, limit_train_batches=0.4, diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index 3c0cd9d6c5b97..ff627c5088987 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -84,6 +84,7 @@ def test_grad_tracking(tmpdir, norm_type, rtol=5e-3): logger = OnlyMetricsListLogger() trainer = Trainer( + default_root_dir=tmpdir, max_epochs=3, logger=logger, track_grad_norm=norm_type, diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 78efbd35ff4da..7d5a8849948d6 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -8,7 +8,7 @@ @pytest.mark.parametrize('max_steps', [1, 2, 3]) -def test_on_before_zero_grad_called(max_steps): +def test_on_before_zero_grad_called(tmpdir, max_steps): 
class CurrentTestModel(EvalModelTemplate): on_before_zero_grad_called = 0 @@ -19,7 +19,9 @@ def on_before_zero_grad(self, optimizer): model = CurrentTestModel() trainer = Trainer( + default_root_dir=tmpdir, max_steps=max_steps, + max_epochs=2, num_sanity_val_steps=5, ) assert 0 == model.on_before_zero_grad_called diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index c77f7a841f3c3..9eb1067322127 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -154,6 +154,7 @@ def test_dp_resume(tmpdir): max_epochs=1, gpus=2, distributed_backend='dp', + default_root_dir=tmpdir, ) # get logger diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 3b44aa69c7164..b36eca8a2e429 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -199,7 +199,7 @@ def test_all_dataloaders_passed_to_fit(tmpdir, ckpt_path): default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, - limit_train_batches=0.2 + limit_train_batches=0.2, ) fit_options = dict(train_dataloader=model.dataloader(train=True), val_dataloaders=model.dataloader(train=False)) @@ -401,7 +401,7 @@ def test_inf_train_dataloader(tmpdir, check_interval): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - val_check_interval=check_interval + val_check_interval=check_interval, ) result = trainer.fit(model) # verify training completed @@ -440,7 +440,7 @@ def test_error_on_zero_len_dataloader(tmpdir): max_epochs=1, limit_train_batches=0.1, limit_val_batches=0.1, - limit_test_batches=0.1 + limit_test_batches=0.1, ) trainer.fit(model) @@ -534,7 +534,7 @@ class CustomSampler(torch.utils.data.Sampler): @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs') -def test_batch_size_smaller_than_num_gpus(): +def test_batch_size_smaller_than_num_gpus(tmpdir): # we need at least 3 gpus for this test num_gpus = 3 batch_size = 3 @@ -572,6 +572,7 @@ def train_dataloader(self): model = CurrentTestModel(**hparams) trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, limit_train_batches=0.1, limit_val_batches=0, diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 08f97cf6b101c..b9f955ed22331 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -57,7 +57,7 @@ def test_trainer_reset_correctly(tmpdir): changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find', 'early_stop_callback', 'accumulate_grad_batches', - 'enable_early_stop', 'checkpoint_callback'] + 'checkpoint_callback'] attributes_before = {} for ca in changed_attributes: attributes_before[ca] = getattr(trainer, ca) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 1a2b3294894f6..68b41d65471b0 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -36,9 +36,10 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): logger = tutils.get_default_logger(tmpdir) trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + checkpoint_callback=ModelCheckpoint(tmpdir), ) # fit model result = trainer.fit(model) @@ -77,9 +78,10 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): # fit model trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, logger=logger, - checkpoint_callback=ModelCheckpoint(tmpdir) + checkpoint_callback=ModelCheckpoint(tmpdir), ) result = trainer.fit(model) @@ -297,8 +299,9 @@ def 
test_model_checkpoint_only_weights(tmpdir): model = EvalModelTemplate() trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, - checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True) + checkpoint_callback=ModelCheckpoint(tmpdir, save_weights_only=True), ) # fit model result = trainer.fit(model) @@ -469,7 +472,7 @@ def test_trainer_min_steps_and_epochs(tmpdir): early_stop_callback=EarlyStopping(monitor='val_loss', min_delta=1.0), val_check_interval=2, min_epochs=1, - max_epochs=2 + max_epochs=7 ) # define less min steps than 1 epoch @@ -592,7 +595,7 @@ def load_from_checkpoint(cls, checkpoint_path, *args, **kwargs): assert loaded_checkpoint_path == ckpt_path -def test_disabled_validation(): +def test_disabled_validation(tmpdir): """Verify that `limit_val_batches=0` disables the validation loop unless `fast_dev_run=True`.""" class CurrentModel(EvalModelTemplate): @@ -612,6 +615,7 @@ def validation_epoch_end(self, *args, **kwargs): model = CurrentModel(**hparams) trainer_options = dict( + default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=2, limit_train_batches=0.4, @@ -664,7 +668,7 @@ def training_step(self, batch, batch_idx, optimizer_idx=None): trainer = Trainer( default_root_dir=tmpdir, max_steps=(model.test_batch_inf_loss + 1), - terminate_on_nan=True + terminate_on_nan=True, ) with pytest.raises(ValueError, match=r'.*The loss returned in `training_step` is nan or inf.*'): @@ -689,7 +693,7 @@ def on_after_backward(self): trainer = Trainer( default_root_dir=tmpdir, max_steps=(model.test_batch_nan + 1), - terminate_on_nan=True + terminate_on_nan=True, ) with pytest.raises(ValueError, match=r'.*Detected nan and/or inf values in `c_d1.bias`.*'): @@ -757,7 +761,7 @@ def _optimizer_step(*args, **kwargs): max_steps=1, max_epochs=1, gradient_clip_val=1.0, - default_root_dir=tmpdir + default_root_dir=tmpdir, ) # for the test @@ -944,7 +948,7 @@ def test_trainer_omegaconf(trainer_params): def test_trainer_pickle(tmpdir): trainer = Trainer( max_epochs=1, - default_root_dir=tmpdir + default_root_dir=tmpdir, ) pickle.dumps(trainer) cloudpickle.dumps(trainer) diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py index c6f2406f1c9a6..51bbc96bd4f9b 100644 --- a/tests/trainer/test_trainer_cli.py +++ b/tests/trainer/test_trainer_cli.py @@ -11,10 +11,10 @@ from pytorch_lightning import Trainer -@mock.patch('argparse.ArgumentParser.parse_args', - return_value=Namespace(**Trainer.default_attributes())) -def test_default_args(tmpdir): +@mock.patch('argparse.ArgumentParser.parse_args') +def test_default_args(mock_argparse, tmpdir): """Tests default argument parser for Trainer""" + mock_argparse.return_value = Namespace(**Trainer.default_attributes()) # logger file to get meta logger = tutils.get_default_logger(tmpdir) diff --git a/tests/trainer/test_trainer_steps.py b/tests/trainer/test_trainer_steps.py index 70302474bedf4..a5ca3c7ab916a 100644 --- a/tests/trainer/test_trainer_steps.py +++ b/tests/trainer/test_trainer_steps.py @@ -2,7 +2,7 @@ from tests.base.deterministic_model import DeterministicModel -def test_trainingstep_dict(tmpdir): +def test_training_step_dict(tmpdir): """ Tests that only training_step can be used """ @@ -10,7 +10,11 @@ def test_trainingstep_dict(tmpdir): model.training_step = model.training_step_dict_return model.val_dataloader = None - trainer = Trainer(fast_dev_run=True, weights_summary=None) + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + weights_summary=None, + ) trainer.fit(model) # make sure 
correct steps were called @@ -74,6 +78,7 @@ def test_full_training_loop_dict(tmpdir): model.val_dataloader = None trainer = Trainer( + default_root_dir=tmpdir, max_epochs=1, weights_summary=None, ) @@ -112,10 +117,7 @@ def test_train_step_epoch_end(tmpdir): model.training_epoch_end = model.training_epoch_end_dict model.val_dataloader = None - trainer = Trainer( - max_epochs=1, - weights_summary=None, - ) + trainer = Trainer(max_epochs=1, weights_summary=None) trainer.fit(model) # make sure correct steps were called diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index 32415ffe2e9ac..48a8f9011811c 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -118,7 +118,7 @@ def test_model_reset_correctly(tmpdir): # logger file to get meta trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1 + max_epochs=1, ) before_state_dict = model.state_dict() @@ -141,7 +141,7 @@ def test_trainer_reset_correctly(tmpdir): # logger file to get meta trainer = Trainer( default_root_dir=tmpdir, - max_epochs=1 + max_epochs=1, ) changed_attributes = ['max_steps', @@ -150,7 +150,6 @@ def test_trainer_reset_correctly(tmpdir): 'callbacks', 'checkpoint_callback', 'early_stop_callback', - 'enable_early_stop', 'limit_train_batches'] attributes_before = {} @@ -224,7 +223,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir): max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, - auto_scale_batch_size='power' + auto_scale_batch_size='power', ) fit_options = dict(train_dataloader=model.dataloader(train=True))
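
Usage note (appended here for illustration, not part of the patch): a minimal sketch of how the EarlyStopping persistence introduced above round-trips through a plain dict. It exercises only the state_dict()/load_state_dict() pair added in pytorch_lightning/callbacks/early_stopping.py; in practice the Trainer drives this from dump_checkpoint() and restore_training_state(), and the constructor arguments shown are ordinary EarlyStopping options, not new APIs. Assumes a pytorch_lightning build with this patch applied.

# Minimal sketch, assuming the patched pytorch_lightning from this PR is installed.
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', patience=3, mode='min')

# Trainer.dump_checkpoint() stores this dict under 'early_stop_callback_state_dict'
state = early_stop.state_dict()
# e.g. {'wait_count': 0, 'stopped_epoch': 0, 'best_score': tensor(inf), 'patience': 3}

# Trainer.restore_training_state() feeds the saved dict into a fresh callback on resume
restored = EarlyStopping(monitor='val_loss', patience=3, mode='min')
restored.load_state_dict(state)
assert restored.wait_count == early_stop.wait_count
assert restored.patience == early_stop.patience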