diff --git a/docs/Trainer/Checkpointing.md b/docs/Trainer/Checkpointing.md
index 791e72ebd5955..b6f57e0805c4b 100644
--- a/docs/Trainer/Checkpointing.md
+++ b/docs/Trainer/Checkpointing.md
@@ -1,6 +1,7 @@
 Lightning can automate saving and loading checkpoints.
 ---
+
 ### Model saving
 Checkpointing is enabled by default to the current working directory.
 To change the checkpoint path pass in :
@@ -10,13 +11,13 @@ Trainer(default_save_path='/your/path/to/save/checkpoints')
 To modify the behavior of checkpointing pass in your own callback.
-``` {.python}
+```{.python}
 from pytorch_lightning.callbacks import ModelCheckpoint
 # DEFAULTS used by the Trainer
 checkpoint_callback = ModelCheckpoint(
     filepath=os.getcwd(),
-    save_best_only=True,
+    save_top_k=1,
     verbose=True,
     monitor='val_loss',
     mode='min',
@@ -26,11 +27,26 @@ checkpoint_callback = ModelCheckpoint(
 trainer = Trainer(checkpoint_callback=checkpoint_callback)
 ```
+The `save_top_k` option works in the following ways:
+
+| save_top_k | behavior |
+| -------- | ----- |
+| 0 | no models are saved |
+| -1 | all models are saved |
+| k >= 1 | the best k models are saved |
+
+
+Also, if `save_top_k` >= 2 and the callback is called multiple
+times inside an epoch, the name of the saved file will be
+appended with a version count starting with `v0`.
+
 ---
-### Restoring training session 
+
+### Restoring training session
+
 You might want to not only load a model but also continue training it.
 Use this method to restore the trainer state as well. This will continue from the epoch and global step you last left off.
-However, the dataloaders will start from the first batch again (if you shuffled it shouldn't matter). 
+However, the dataloaders will start from the first batch again (if you shuffled, it shouldn't matter).
 Lightning will restore the session if you pass a logger with the same version and there's a saved checkpoint.
 ``` {.python}
@@ -52,18 +68,19 @@ trainer = Trainer(
 trainer.fit(model)
 ```
-The trainer restores: 
+The trainer restores:
 
-- global_step 
-- current_epoch 
-- All optimizers 
-- All lr_schedulers 
+- global_step
+- current_epoch
+- All optimizers
+- All lr_schedulers
 - Model weights
-You can even change the logic of your model as long as the weights and "architecture" of 
-the system isn't different. If you add a layer, for instance, it might not work. 
+You can even change the logic of your model as long as the weights and "architecture" of
+the system aren't different. If you add a layer, for instance, it might not work.
+ +At a rough level, here's [what happens inside Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/model_saving.py#L63): -At a rough level, here's [what happens inside Trainer](https://github.com/williamFalcon/pytorch-lightning/blob/master/pytorch_lightning/root_module/model_saving.py#L63): ```python self.global_step = checkpoint['global_step'] @@ -79,6 +96,6 @@ lr_schedulers = checkpoint['lr_schedulers'] for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): scheduler.load_state_dict(lrs_state) -# uses the model you passed into trainer +# uses the model you passed into trainer model.load_state_dict(checkpoint['state_dict']) -``` +``` diff --git a/docs/examples/Examples.md b/docs/examples/Examples.md index 49e83e7026c86..354fae5a80f2d 100644 --- a/docs/examples/Examples.md +++ b/docs/examples/Examples.md @@ -6,14 +6,15 @@ In 99% of cases you want to just copy [one of the examples](https://github.com/w wget https://raw.githubusercontent.com/williamFalcon/pytorch-lightning/master/pl_examples/new_project_templates/lightning_module_template.py ``` ---- -### Trainer Example +--- + +### Trainer Example -** \_\_main__ function** +** \_\_main\_\_ function** -Normally, we want to let the \_\_main__ function start the training. -Inside the main we parse training arguments with whatever hyperparameters we want. Your LightningModule will have a -chance to add hyperparameters. +Normally, we want to let the \_\_main\_\_ function start the training. +Inside the main we parse training arguments with whatever hyperparameters we want. Your LightningModule will have a +chance to add hyperparameters. ```{.python} from test_tube import HyperOptArgumentParser @@ -32,13 +33,15 @@ if __name__ == '__main__': # train model main(hyperparams) ``` -**Main Function** + +**Main Function** The main function is your entry into the program. This is where you init your model, checkpoint directory, and launch the training. -The main function should have 3 arguments: -- hparams: a configuration of hyperparameters. +The main function should have 3 arguments: + +- hparams: a configuration of hyperparameters. - slurm_manager: Slurm cluster manager object (can be None) -- dict: for you to return any values you want (useful in meta-learning, otherwise set to _) +- dict: for you to return any values you want (useful in meta-learning, otherwise set to \_) ```python def main(hparams, cluster, results_dict): @@ -62,13 +65,15 @@ The __main__ function will start training on your **main** function. If you use in hyper parameter optimization mode, this main function will get one set of hyperparameters. If you use it as a simple argument parser you get the default arguments in the argument parser. -So, calling main(hyperparams) runs the model with the default argparse arguments. +So, calling main(hyperparams) runs the model with the default argparse arguments. + ```{.python} main(hyperparams) ``` --- -#### CPU hyperparameter search + +#### CPU hyperparameter search ```{.python} # run a grid search over 20 hyperparameter combinations. @@ -80,7 +85,9 @@ hyperparams.optimize_parallel_cpu( ``` --- -#### Hyperparameter search on a single or multiple GPUs + +#### Hyperparameter search on a single or multiple GPUs + ```{.python} # run a grid search over 20 hyperparameter combinations. 
hyperparams.optimize_parallel_gpu( @@ -92,8 +99,10 @@ hyperparams.optimize_parallel_gpu( ``` --- -#### Hyperparameter search on a SLURM HPC cluster -```{.python} + +#### Hyperparameter search on a SLURM HPC cluster + +```{.python} def optimize_on_cluster(hyperparams): # enable cluster training cluster = SlurmCluster( @@ -126,6 +135,6 @@ def optimize_on_cluster(hyperparams): job_name=job_display_name ) -# run cluster hyperparameter search +# run cluster hyperparameter search optimize_on_cluster(hyperparams) ``` diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py index 1b35bfc5d48fe..a08f359aed111 100644 --- a/pytorch_lightning/callbacks/pt_callbacks.py +++ b/pytorch_lightning/callbacks/pt_callbacks.py @@ -158,11 +158,17 @@ class ModelCheckpoint(Callback): filepath: string, path to save the model file. monitor: quantity to monitor. verbose: verbosity mode, 0 or 1. - save_best_only: if `save_best_only=True`, - the latest best model according to - the quantity monitored will not be overwritten. + save_top_k: if `save_top_k == k`, + the best k models according to + the quantity monitored will be saved. + if `save_top_k == 0`, no models are saved. + if `save_top_k == -1`, all models are saved. + Please note that the monitors are checked every `period` epochs. + if `save_top_k >= 2` and the callback is called multiple + times inside an epoch, the name of the saved file will be + appended with a version count starting with `v0`. mode: one of {auto, min, max}. - If `save_best_only=True`, the decision + If `save_top_k != 0`, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For `val_acc`, @@ -176,27 +182,33 @@ class ModelCheckpoint(Callback): """ def __init__(self, filepath, monitor='val_loss', verbose=0, - save_best_only=True, save_weights_only=False, + save_top_k=1, save_weights_only=False, mode='auto', period=1, prefix=''): super(ModelCheckpoint, self).__init__() if ( - save_best_only and + save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0 ): warnings.warn( - f"Checkpoint directory {filepath} exists and is not empty with save_best_only=True." + f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0." "All files in this directory will be deleted when a checkpoint is saved!" 
) self.monitor = monitor self.verbose = verbose self.filepath = filepath - self.save_best_only = save_best_only + if not os.path.exists(filepath): + os.makedirs(filepath) + self.save_top_k = save_top_k self.save_weights_only = save_weights_only self.period = period - self.epochs_since_last_save = 0 + self.epochs_since_last_check = 0 self.prefix = prefix + self.best_k_models = {} + # {filename: monitor} + self.kth_best_model = '' + self.best = 0 if mode not in ['auto', 'min', 'max']: warnings.warn( @@ -206,66 +218,112 @@ def __init__(self, filepath, monitor='val_loss', verbose=0, if mode == 'min': self.monitor_op = np.less - self.best = np.Inf + self.kth_value = np.Inf + self.mode = 'min' elif mode == 'max': self.monitor_op = np.greater - self.best = -np.Inf + self.kth_value = -np.Inf + self.mode = 'max' else: if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): self.monitor_op = np.greater - self.best = -np.Inf + self.kth_value = -np.Inf + self.mode = 'max' else: self.monitor_op = np.less - self.best = np.Inf + self.kth_value = np.Inf + self.mode = 'min' - def save_model(self, filepath, overwrite): - dirpath = '/'.join(filepath.split('/')[:-1]) + def _del_model(self, filepath): + dirpath = os.path.dirname(filepath) # make paths - os.makedirs(os.path.dirname(filepath), exist_ok=True) + os.makedirs(dirpath, exist_ok=True) - if overwrite: - for filename in os.listdir(dirpath): - if self.prefix in filename: - path_to_delete = os.path.join(dirpath, filename) - try: - shutil.rmtree(path_to_delete) - except OSError: - os.remove(path_to_delete) + try: + shutil.rmtree(filepath) + except OSError: + os.remove(filepath) + + def _save_model(self, filepath): + dirpath = os.path.dirname(filepath) + + # make paths + os.makedirs(dirpath, exist_ok=True) # delegate the saving to the model self.save_function(filepath) + def check_monitor_top_k(self, current): + less_than_k_models = len(self.best_k_models.keys()) < self.save_top_k + if less_than_k_models: + return True + return self.monitor_op(current, self.best_k_models[self.kth_best_model]) + def on_epoch_end(self, epoch, logs=None): logs = logs or {} - self.epochs_since_last_save += 1 - if self.epochs_since_last_save >= self.period: - self.epochs_since_last_save = 0 - filepath = '{}/{}_ckpt_epoch_{}.ckpt'.format(self.filepath, self.prefix, epoch + 1) - if self.save_best_only: + self.epochs_since_last_check += 1 + + if self.save_top_k == 0: + # no models are saved + return + if self.epochs_since_last_check >= self.period: + self.epochs_since_last_check = 0 + filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt' + version_cnt = 0 + while os.path.isfile(filepath): + # this epoch called before + filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt' + version_cnt += 1 + + print(filepath) + + if self.save_top_k != -1: current = logs.get(self.monitor) + if current is None: warnings.warn( f'Can save best model only with {self.monitor} available,' ' skipping.', RuntimeWarning) else: - if self.monitor_op(current, self.best): + if self.check_monitor_top_k(current): + + # remove kth + if len(self.best_k_models.keys()) == self.save_top_k: + delpath = self.kth_best_model + self.best_k_models.pop(self.kth_best_model) + self._del_model(delpath) + + self.best_k_models[filepath] = current + if len(self.best_k_models.keys()) == self.save_top_k: + # monitor dict has reached k elements + if self.mode == 'min': + self.kth_best_model = max(self.best_k_models, key=self.best_k_models.get) + else: + self.kth_best_model = 
min(self.best_k_models, key=self.best_k_models.get) + self.kth_value = self.best_k_models[self.kth_best_model] + + if self.mode == 'min': + self.best = min(self.best_k_models.values()) + else: + self.best = max(self.best_k_models.values()) if self.verbose > 0: logging.info( - f'\nEpoch {epoch + 1:05d}: {self.monitor} improved' - f' from {self.best:0.5f} to {current:0.5f},' - f' saving model to {filepath}') - self.best = current - self.save_model(filepath, overwrite=True) + f'\nEpoch {epoch:05d}: {self.monitor} reached', + f'{current:0.5f} (best {self.best:0.5f}), saving model to', + f'{filepath} as top {self.save_top_k}') + self._save_model(filepath) else: if self.verbose > 0: logging.info( - f'\nEpoch {epoch + 1:05d}: {self.monitor} did not improve') + f'\nEpoch {epoch:05d}: {self.monitor}', + f'was not in top {self.save_top_k}') + else: if self.verbose > 0: - logging.info(f'\nEpoch {epoch + 1:05d}: saving model to {filepath}') - self.save_model(filepath, overwrite=False) + logging.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') + self._save_model(filepath) class GradientAccumulationScheduler(Callback): diff --git a/tests/test_a_restore_models.py b/tests/test_a_restore_models.py index 22ef9f116b8c1..8e366c8806f5f 100644 --- a/tests/test_a_restore_models.py +++ b/tests/test_a_restore_models.py @@ -131,7 +131,7 @@ def test_load_model_from_checkpoint(): # correct result and ok accuracy assert result == 1, 'training failed to complete' pretrained_model = LightningTestModel.load_from_checkpoint( - os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_1.ckpt") + os.path.join(trainer.checkpoint_callback.filepath, "_ckpt_epoch_0.ckpt") ) # test that hparams loaded correctly diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 667143a88999c..8f2d830c02ee0 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -232,6 +232,118 @@ def test_dp_output_reduce(): assert reduced['b']['c'] == out['b']['c'] +def test_model_checkpoint_options(): + """ + Test ModelCheckpoint options + :return: + """ + def mock_save_function(filepath): + open(filepath, 'a').close() + + hparams = testing_utils.get_hparams() + model = LightningTestModel(hparams) + + # simulated losses + save_dir = testing_utils.init_save_dir() + losses = [10, 9, 2.8, 5, 2.5] + + # ----------------- + # CASE K=-1 (all) + w = ModelCheckpoint(save_dir, save_top_k=-1, verbose=1) + w.save_function = mock_save_function + for i, loss in enumerate(losses): + w.on_epoch_end(i, logs={'val_loss': loss}) + + file_lists = set(os.listdir(save_dir)) + + assert len(file_lists) == len(losses), "Should save all models when save_top_k=-1" + + # verify correct naming + for i in range(0, len(losses)): + assert f'_ckpt_epoch_{i}.ckpt' in file_lists + + testing_utils.clear_save_dir() + + # ----------------- + # CASE K=0 (none) + w = ModelCheckpoint(save_dir, save_top_k=0, verbose=1) + w.save_function = mock_save_function + for i, loss in enumerate(losses): + w.on_epoch_end(i, logs={'val_loss': loss}) + + file_lists = os.listdir(save_dir) + + assert len(file_lists) == 0, "Should save 0 models when save_top_k=0" + + testing_utils.clear_save_dir() + + # ----------------- + # CASE K=1 (2.5, epoch 4) + w = ModelCheckpoint(save_dir, save_top_k=1, verbose=1, prefix='test_prefix') + w.save_function = mock_save_function + for i, loss in enumerate(losses): + w.on_epoch_end(i, logs={'val_loss': loss}) + + file_lists = set(os.listdir(save_dir)) + + assert len(file_lists) == 1, "Should save 1 model when save_top_k=1" + assert 
'test_prefix_ckpt_epoch_4.ckpt' in file_lists + + testing_utils.clear_save_dir() + + # ----------------- + # CASE K=2 (2.5 epoch 4, 2.8 epoch 2) + # make sure other files don't get deleted + + w = ModelCheckpoint(save_dir, save_top_k=2, verbose=1) + open(f'{save_dir}/other_file.ckpt', 'a').close() + w.save_function = mock_save_function + for i, loss in enumerate(losses): + w.on_epoch_end(i, logs={'val_loss': loss}) + + file_lists = set(os.listdir(save_dir)) + + assert len(file_lists) == 3, 'Should save 2 model when save_top_k=2' + assert '_ckpt_epoch_4.ckpt' in file_lists + assert '_ckpt_epoch_2.ckpt' in file_lists + assert 'other_file.ckpt' in file_lists + + testing_utils.clear_save_dir() + + # ----------------- + # CASE K=4 (save all 4 models) + # multiple checkpoints within same epoch + + w = ModelCheckpoint(save_dir, save_top_k=4, verbose=1) + w.save_function = mock_save_function + for loss in losses: + w.on_epoch_end(0, logs={'val_loss': loss}) + + file_lists = set(os.listdir(save_dir)) + + assert len(file_lists) == 4, 'Should save all 4 models when save_top_k=4 within same epoch' + + testing_utils.clear_save_dir() + + # ----------------- + # CASE K=3 (save the 2nd, 3rd, 4th model) + # multiple checkpoints within same epoch + + w = ModelCheckpoint(save_dir, save_top_k=3, verbose=1) + w.save_function = mock_save_function + for loss in losses: + w.on_epoch_end(0, logs={'val_loss': loss}) + + file_lists = set(os.listdir(save_dir)) + + assert len(file_lists) == 3, 'Should save 3 models when save_top_k=3' + assert '_ckpt_epoch_0_v2.ckpt' in file_lists + assert '_ckpt_epoch_0_v1.ckpt' in file_lists + assert '_ckpt_epoch_0.ckpt' in file_lists + + testing_utils.clear_save_dir() + + def test_model_freeze_unfreeze(): testing_utils.reset_seed() diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 86d705f9bb741..a956ffad0f137 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -149,7 +149,7 @@ def init_save_dir(): def clear_save_dir(): root_dir = os.path.dirname(os.path.realpath(__file__)) - save_dir = os.path.join(root_dir, 'save_dir') + save_dir = os.path.join(root_dir, 'tests', 'save_dir') if os.path.exists(save_dir): n = RANDOM_FILE_PATHS.pop() shutil.move(save_dir, save_dir + f'_{n}')
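For users migrating from the old flag: judging from the defaults changed above, `save_best_only=True` corresponds to `save_top_k=1`, and `save_best_only=False` maps roughly to `save_top_k=-1`. A minimal usage sketch of the new argument; the checkpoint directory is hypothetical and the `Trainer` import path is assumed to work as in the rest of the docs:

```python
from pytorch_lightning import Trainer  # assumed import path
from pytorch_lightning.callbacks import ModelCheckpoint

# keep the 3 best checkpoints, ranked by validation loss (lower is better)
checkpoint_callback = ModelCheckpoint(
    filepath='/path/to/checkpoints',  # hypothetical directory; created if it does not exist
    save_top_k=3,                     # 0 = save nothing, -1 = save every checkpoint
    verbose=True,
    monitor='val_loss',
    mode='min',
    period=1,                         # the monitor is only checked every `period` epochs
    prefix='',
)

trainer = Trainer(checkpoint_callback=checkpoint_callback)
```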
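The behaviour exercised by the new `test_model_checkpoint_options` test can also be reproduced outside the test suite. A small sketch, assuming the same mock-save pattern the test uses: a hypothetical scratch directory, and `save_function` stubbed out by hand because it is normally wired up by the `Trainer`:

```python
import os
import tempfile

from pytorch_lightning.callbacks import ModelCheckpoint

save_dir = tempfile.mkdtemp()  # hypothetical scratch directory

# keep only the 2 best checkpoints, ranked by val_loss
checkpoint = ModelCheckpoint(save_dir, save_top_k=2, verbose=1)

# stand-in for the Trainer's real save routine, as in the test's mock_save_function
checkpoint.save_function = lambda filepath: open(filepath, 'a').close()

# the same simulated losses the test uses, one call per "epoch"
for epoch, val_loss in enumerate([10, 9, 2.8, 5, 2.5]):
    checkpoint.on_epoch_end(epoch, logs={'val_loss': val_loss})

# only the checkpoints for epoch 2 (2.8) and epoch 4 (2.5) should survive
print(sorted(os.listdir(save_dir)))  # expected: ['_ckpt_epoch_2.ckpt', '_ckpt_epoch_4.ckpt']
```

Calling `on_epoch_end` several times with the same epoch number instead exercises the `_v0`, `_v1`, ... filename suffixing that the `save_top_k=3` and `save_top_k=4` cases of the test assert.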