diff --git a/CHANGELOG.md b/CHANGELOG.md
index 270cf30ed5704..5f402d7b8bc49 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,8 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))
 
-- Fixed missing profiler attribute in add_argparse_args() ArgumentParser ([#1794](https://github.com/PyTorchLightning/pytorch-lightning/pull/1794))
-
+- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801))
 
 ## [0.7.5] - 2020-04-27
 
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index 7de7b298f052b..0ca41d2d54f99 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -15,6 +15,7 @@
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 class TrainerLRFinderMixin(ABC):
@@ -58,7 +59,8 @@ def lr_find(self,
                 max_lr: float = 1,
                 num_training: int = 100,
                 mode: str = 'exponential',
-                num_accumulation_steps: int = 1):
+                early_stop_threshold: float = 4.0,
+                num_accumulation_steps=None):
         r"""
         lr_find enables the user to do a range test of good initial learning rates,
         to reduce the amount of guesswork in picking a good starting learning rate.
@@ -81,7 +83,12 @@ def lr_find(self,
                 after each batch. If set to 'exponential', will increase
                 learning rate exponentially.
 
-            num_accumulation_steps: number of batches to calculate loss over.
+            early_stop_threshold: threshold for stopping the search. If the
+                loss at any point is larger than early_stop_threshold*best_loss
+                then the search is stopped. To disable, set to None.
+
+            num_accumulation_steps: deprecated, number of batches to calculate loss over.
+                Set trainer argument ``accumulate_grad_batches`` instead.
 
         Example::
 
@@ -104,6 +111,12 @@ def lr_find(self,
             trainer.fit(model)
 
         """
+        if num_accumulation_steps is not None:
+            rank_zero_warn("Argument `num_accumulation_steps` has been deprecated"
+                           " since v0.7.6 and will be removed in 0.9. Please"
+                           " set trainer argument `accumulate_grad_batches` instead.",
+                           DeprecationWarning)
+
         save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')
 
         self.__lr_finder_dump_params(model)
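(Annotation, not part of the patch.) After this change, gradient accumulation for the finder comes from the Trainer's own `accumulate_grad_batches` argument, and the divergence check is tunable through `early_stop_threshold`. A minimal usage sketch, assuming `model` is any `LightningModule` defined elsewhere:

```python
from pytorch_lightning import Trainer

# Accumulation is now configured on the Trainer itself; the lr finder
# respects it instead of taking its own `num_accumulation_steps` argument.
trainer = Trainer(accumulate_grad_batches=2)

# `early_stop_threshold` replaces the previously hard-coded 4x factor;
# passing None sweeps all `num_training` steps without early stopping.
# (`model` is assumed to be a LightningModule created elsewhere.)
lr_finder = trainer.lr_find(model, early_stop_threshold=None)
new_lr = lr_finder.suggestion()
```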
@@ -115,7 +128,9 @@ def lr_find(self,
         lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
 
         # Use special lr logger callback
-        self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)]
+        self.callbacks = [_LRCallback(num_training,
+                                      early_stop_threshold,
+                                      progress_bar_refresh_rate=1)]
 
         # No logging
         self.logger = None
@@ -127,9 +142,6 @@ def lr_find(self,
         if self.progress_bar_callback:
             self.progress_bar_callback.disable()
 
-        # Accumulation of gradients
-        self.accumulate_grad_batches = num_accumulation_steps
-
         # Disable standard checkpoint & early stopping
         self.checkpoint_callback = False
         self.early_stop_callback = None
@@ -149,7 +161,6 @@ def lr_find(self,
             raise MisconfigurationException(
                 f'`model.configure_optimizers()` returned {len(optimizers)}, but'
                 ' learning rate finder only works with single optimizer')
-
         configure_optimizers = model.configure_optimizers
         model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
 
         # Fit, lr & loss logged in callback
@@ -164,6 +175,7 @@ def lr_find(self,
         # Transfer results from callback to lr finder object
         lr_finder.results.update({'lr': self.callbacks[0].lrs,
                                   'loss': self.callbacks[0].losses})
+        lr_finder._total_batch_idx = self.total_batch_idx  # for debug purposes
 
         # Reset model state
         self.restore(str(save_path), on_gpu=self.on_gpu)
@@ -184,7 +196,6 @@ def __lr_finder_dump_params(self, model):
             'logger': self.logger,
             'max_steps': self.max_steps,
             'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
-            'accumulate_grad_batches': self.accumulate_grad_batches,
             'checkpoint_callback': self.checkpoint_callback,
             'early_stop_callback': self.early_stop_callback,
             'enable_early_stop': self.enable_early_stop,
@@ -198,7 +209,6 @@ def __lr_finder_restore_params(self, model):
         self.callbacks = self.__dumped_params['callbacks']
         self.max_steps = self.__dumped_params['max_steps']
         self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
-        self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
         self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
         self.early_stop_callback = self.__dumped_params['early_stop_callback']
         self.enable_early_stop = self.__dumped_params['enable_early_stop']
@@ -242,6 +252,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
         self.num_training = num_training
 
         self.results = {}
+        self._total_batch_idx = 0  # for debug purposes
 
     def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
         """ Construct a new `configure_optimizers()` method, that has a optimizer
@@ -298,30 +309,49 @@ def plot(self, suggest: bool = False, show: bool = False):
 
         return fig
 
-    def suggestion(self):
+    def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
         """ This will propose a suggestion for choice of initial learning rate
             as the point with the steepest negative gradient.
 
         Returns:
             lr: suggested initial learning rate to use
+
+        Args:
+            skip_begin: how many samples to skip in the beginning. Prevents too naive estimates
+            skip_end: how many samples to skip in the end. Prevents too optimistic estimates
 
         """
         try:
-            min_grad = (np.gradient(np.array(self.results["loss"]))).argmin()
-            self._optimal_idx = min_grad
-            return self.results["lr"][min_grad]
+            loss = self.results["loss"][skip_begin:-skip_end]
+            min_grad = (np.gradient(np.array(loss))).argmin()
+            self._optimal_idx = min_grad + skip_begin
+            return self.results["lr"][self._optimal_idx]
         except Exception:
-            log.warning('Failed to compute suggesting for `lr`.'
-                        ' There might not be enough points.')
+            log.exception('Failed to compute suggestion for `lr`. There might not be enough points.')
             self._optimal_idx = None
 
 
 class _LRCallback(Callback):
     """ Special callback used by the learning rate finder. This callbacks log the
         learning rate before each batch and log the corresponding loss after
-        each batch. """
-    def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
+        each batch.
+
+    Args:
+        num_training: number of iterations done by the learning rate finder
+        early_stop_threshold: threshold for stopping the search. If the
+            loss at any point is larger than ``early_stop_threshold*best_loss``
+            then the search is stopped. To disable, set to ``None``.
+        progress_bar_refresh_rate: rate to refresh the progress bar for
+            the learning rate finder
+        beta: smoothing value, the loss being logged is a running average of
+            loss values logged until now. ``beta`` controls the forget rate i.e.
+            if ``beta=0`` all past information is ignored.
+
+    """
+    def __init__(self, num_training: int,
+                 early_stop_threshold: float = 4.0,
+                 progress_bar_refresh_rate: bool = False,
+                 beta: float = 0.98):
         self.num_training = num_training
+        self.early_stop_threshold = early_stop_threshold
         self.beta = beta
         self.losses = []
         self.lrs = []
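An aside on the `suggestion()` change above: the skip windows plus `np.gradient(...).argmin()` are easy to sanity-check outside Lightning. A standalone sketch with a toy loss curve (the `suggest_lr` helper and the synthetic data are illustrative, not from the patch):

```python
import numpy as np

def suggest_lr(lrs, losses, skip_begin=10, skip_end=1):
    """Pick the lr at the steepest negative slope of the loss curve,
    ignoring the noisy head and the (often diverging) tail of the sweep."""
    loss = np.array(losses[skip_begin:-skip_end])
    # np.gradient gives the local slope; argmin finds the steepest descent.
    # Re-offset by skip_begin to index into the full sweep again.
    idx = np.gradient(loss).argmin() + skip_begin
    return lrs[idx]

# Toy U-shaped loss curve over a log-spaced lr sweep
lrs = np.logspace(-6, 1, 100)
losses = (np.log10(lrs) + 3) ** 2
print(suggest_lr(lrs, losses))
```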
@@ -332,6 +362,9 @@ def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
 
     def on_batch_start(self, trainer, pl_module):
         """ Called before each training batch, logs the lr that will be used """
+        if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
+            return
+
         if self.progress_bar_refresh_rate and self.progress_bar is None:
             self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
 
@@ -339,6 +372,9 @@ def on_batch_start(self, trainer, pl_module):
 
     def on_batch_end(self, trainer, pl_module):
         """ Called when the training batch ends, logs the calculated loss """
+        if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
+            return
+
         if self.progress_bar:
             self.progress_bar.update()
 
@@ -350,10 +386,11 @@ def on_batch_end(self, trainer, pl_module):
         smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
 
         # Check if we diverging
-        if current_step > 1 and smoothed_loss > 4 * self.best_loss:
-            trainer.max_steps = current_step  # stop signal
-            if self.progress_bar:
-                self.progress_bar.close()
+        if self.early_stop_threshold is not None:
+            if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
+                trainer.max_steps = current_step  # stop signal
+                if self.progress_bar:
+                    self.progress_bar.close()
 
         # Save best loss for diverging checking
         if smoothed_loss < self.best_loss or current_step == 1:
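Before moving on to the tests, a note on the smoothing used by the divergence check above: `avg_loss` is an exponentially weighted average, and dividing by `1 - beta**step` is the usual bias correction for its zero initialization. A minimal standalone sketch (the `smoothed_losses` helper is illustrative only):

```python
def smoothed_losses(losses, beta=0.98):
    """Exponentially weighted moving average with bias correction,
    mirroring how the lr finder smooths the loss before the divergence test."""
    avg, out = 0.0, []
    for step, loss in enumerate(losses, start=1):
        avg = beta * avg + (1 - beta) * loss
        out.append(avg / (1 - beta ** step))  # de-bias the zero init
    return out

# The sweep stops once a smoothed loss exceeds
# early_stop_threshold * best_loss seen so far
# (and runs to completion when early_stop_threshold is None).
print(smoothed_losses([1.0, 1.0, 4.0, 9.0]))
```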
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index f46154d8cba72..b0f6b6d3a0829 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -76,7 +76,7 @@ def test_trainer_reset_correctly(tmpdir):
 
 
 def test_trainer_arg_bool(tmpdir):
-
+    """ Test that setting trainer arg to bool works """
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(hparams)
 
     before_lr = hparams.learning_rate
@@ -84,7 +84,7 @@ def test_trainer_arg_bool(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
+        max_epochs=5,
         auto_lr_find=True
     )
 
@@ -95,7 +95,7 @@ def test_trainer_arg_str(tmpdir):
-
+    """ Test that setting trainer arg to string works """
     hparams = EvalModelTemplate.get_default_hparams()
     hparams.__dict__['my_fancy_lr'] = 1.0  # update with non-standard field
     model = EvalModelTemplate(hparams)
@@ -104,7 +104,7 @@ def test_trainer_arg_str(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
+        max_epochs=5,
         auto_lr_find='my_fancy_lr'
     )
 
@@ -115,6 +115,7 @@ def test_call_to_trainer_method(tmpdir):
+    """ Test that directly calling the trainer method works """
     hparams = EvalModelTemplate.get_default_hparams()
     model = EvalModelTemplate(hparams)
 
@@ -123,7 +124,7 @@ def test_call_to_trainer_method(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=1,
+        max_epochs=5,
     )
 
     lrfinder = trainer.lr_find(model, mode='linear')
@@ -133,3 +134,48 @@ def test_call_to_trainer_method(tmpdir):
 
     assert before_lr != after_lr, \
         'Learning rate was not altered after running learning rate finder'
+
+
+def test_accumulation_and_early_stopping(tmpdir):
+    """ Test that early stopping of learning rate finder works, and that
+        accumulation also works for this feature """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    before_lr = hparams.learning_rate
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        accumulate_grad_batches=2
+    )
+
+    lrfinder = trainer.lr_find(model, early_stop_threshold=None)
+    after_lr = lrfinder.suggestion()
+
+    assert before_lr != after_lr, \
+        'Learning rate was not altered after running learning rate finder'
+    assert len(lrfinder.results['lr']) == 100, \
+        'Early stopping for learning rate finder did not work'
+    assert lrfinder._total_batch_idx == 100 * 2, \
+        'Accumulation parameter did not work'
+
+
+def test_suggestion_parameters_work(tmpdir):
+    """ Test that the suggestion skip parameters influence the result """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=10,
+    )
+
+    lrfinder = trainer.lr_find(model)
+    lr1 = lrfinder.suggestion(skip_begin=10)  # default
+    lr2 = lrfinder.suggestion(skip_begin=80)  # way too high, should have an impact
+
+    assert lr1 != lr2, \
+        'Skipping parameter did not influence learning rate'
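One closing observation on the accumulation guard tested above: with `accumulate_grad_batches=n`, the callback records lr/loss only on the last batch of each accumulation window, which is why `test_accumulation_and_early_stopping` expects 200 total batches for 100 recorded steps. An illustrative sketch (the `logged_steps` helper is hypothetical, not part of the patch):

```python
def logged_steps(num_batches, accumulate_grad_batches):
    """0-based batch indices at which the finder callback logs lr/loss:
    only the final batch of each accumulation window."""
    return [i for i in range(num_batches)
            if (i + 1) % accumulate_grad_batches == 0]

print(logged_steps(8, 2))  # [1, 3, 5, 7] -> one record per optimizer step
```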