From 59a0a8972425f3c867d0696916b853d4c7ec3215 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 12 May 2020 12:06:32 +0200 Subject: [PATCH 01/10] fix suggestion being too naive --- pytorch_lightning/trainer/lr_finder.py | 11 +++++++++-- tests/trainer/test_lr_finder.py | 6 +++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 7de7b298f052b..59605badf34ad 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -298,16 +298,23 @@ def plot(self, suggest: bool = False, show: bool = False): return fig - def suggestion(self): + def suggestion(self, skip_begin: int = 10, skip_end: int = 1): """ This will propose a suggestion for choice of initial learning rate as the point with the steepest negative gradient. Returns: lr: suggested initial learning rate to use + skip_begin: how many samples to skip in the beginning. + Prevent too naive estimates + + skip_end: how many samples to skip in the end. + Prevent too optimistic estimates + """ try: - min_grad = (np.gradient(np.array(self.results["loss"]))).argmin() + loss = self.results["loss"][skip_begin:-skip_end] + min_grad = (np.gradient(np.array(loss))).argmin() self._optimal_idx = min_grad return self.results["lr"][min_grad] except Exception: diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index f46154d8cba72..1964c745140e1 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -84,7 +84,7 @@ def test_trainer_arg_bool(tmpdir): # logger file to get meta trainer = Trainer( default_save_path=tmpdir, - max_epochs=1, + max_epochs=5, auto_lr_find=True ) @@ -104,7 +104,7 @@ def test_trainer_arg_str(tmpdir): # logger file to get meta trainer = Trainer( default_save_path=tmpdir, - max_epochs=1, + max_epochs=5, auto_lr_find='my_fancy_lr' ) @@ -123,7 +123,7 @@ def test_call_to_trainer_method(tmpdir): # logger file to get meta trainer = Trainer( default_save_path=tmpdir, - max_epochs=1, + max_epochs=5, ) lrfinder = trainer.lr_find(model, mode='linear') From ce010f8d9a09072a4d0e686d01bbb7a94cddf0fd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 12 May 2020 16:33:50 +0200 Subject: [PATCH 02/10] fix accumulation error and added new tests --- pytorch_lightning/trainer/lr_finder.py | 85 +++++++++++++++----------- tests/trainer/test_lr_finder.py | 56 ++++++++++++++++- 2 files changed, 104 insertions(+), 37 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 59605badf34ad..6361ecb521a69 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -15,6 +15,7 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning import _logger as log from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities import rank_zero_warn class TrainerLRFinderMixin(ABC): @@ -58,7 +59,8 @@ def lr_find(self, max_lr: float = 1, num_training: int = 100, mode: str = 'exponential', - num_accumulation_steps: int = 1): + early_stop_threshold: float = 4.0, + num_accumulation_steps=None): r""" lr_find enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking a good starting learning rate. @@ -81,7 +83,12 @@ def lr_find(self, after each batch. If set to 'exponential', will increase learning rate exponentially. 
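For intuition, here is a minimal sketch of how the two modes could space the candidate learning rates between `min_lr` and `max_lr`. The exact schedules are implemented by dedicated scheduler classes later in this file (e.g. `_LinearLR`); the formulas below are an illustration under that assumption, not a quote of the implementation.

    import numpy as np

    min_lr, max_lr, num_training = 1e-8, 1.0, 100
    steps = np.arange(1, num_training + 1) / num_training

    # 'linear': candidates evenly spaced between min_lr and max_lr
    linear_lrs = min_lr + (max_lr - min_lr) * steps
    # 'exponential': candidates evenly spaced in log-space
    exponential_lrs = min_lr * (max_lr / min_lr) ** steps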
- num_accumulation_steps: number of batches to calculate loss over. + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + num_accumulation_steps: deprepecated, number of batches to calculate loss over. + Set trainer argument `accumulate_grad_batches` instead. Example:: @@ -104,6 +111,11 @@ def lr_find(self, trainer.fit(model) """ + if num_accumulation_steps is not None: + rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated." + "Please set trainer argument `accumulate_grad_batches`" + " instead.", DeprecationWarning) + save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt') self.__lr_finder_dump_params(model) @@ -115,7 +127,9 @@ def lr_find(self, lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) # Use special lr logger callback - self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)] + self.callbacks = [_LRCallback(num_training, + early_stop_threshold, + progress_bar_refresh_rate=1)] # No logging self.logger = None @@ -127,9 +141,6 @@ def lr_find(self, if self.progress_bar_callback: self.progress_bar_callback.disable() - # Accumulation of gradients - self.accumulate_grad_batches = num_accumulation_steps - # Disable standard checkpoint & early stopping self.checkpoint_callback = False self.early_stop_callback = None @@ -149,7 +160,6 @@ def lr_find(self, raise MisconfigurationException( f'`model.configure_optimizers()` returned {len(optimizers)}, but' ' learning rate finder only works with single optimizer') - configure_optimizers = model.configure_optimizers model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0]) # Fit, lr & loss logged in callback @@ -164,7 +174,8 @@ def lr_find(self, # Transfer results from callback to lr finder object lr_finder.results.update({'lr': self.callbacks[0].lrs, 'loss': self.callbacks[0].losses}) - + lr_finder._total_batch_idx = self.total_batch_idx # for debug purpose + # Reset model state self.restore(str(save_path), on_gpu=self.on_gpu) os.remove(save_path) @@ -184,7 +195,6 @@ def __lr_finder_dump_params(self, model): 'logger': self.logger, 'max_steps': self.max_steps, 'progress_bar_refresh_rate': self.progress_bar_refresh_rate, - 'accumulate_grad_batches': self.accumulate_grad_batches, 'checkpoint_callback': self.checkpoint_callback, 'early_stop_callback': self.early_stop_callback, 'enable_early_stop': self.enable_early_stop, @@ -198,7 +208,6 @@ def __lr_finder_restore_params(self, model): self.callbacks = self.__dumped_params['callbacks'] self.max_steps = self.__dumped_params['max_steps'] self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate'] - self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches'] self.checkpoint_callback = self.__dumped_params['checkpoint_callback'] self.early_stop_callback = self.__dumped_params['early_stop_callback'] self.enable_early_stop = self.__dumped_params['enable_early_stop'] @@ -240,8 +249,9 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): self.lr_min = lr_min self.lr_max = lr_max self.num_training = num_training - + self.results = {} + self._total_batch_idx = 0 # for debug purpose def _get_new_optimizer(self, optimizer: torch.optim.Optimizer): """ Construct a new `configure_optimizers()` method, that has a optimizer @@ -307,7 +317,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): skip_begin: how many samples to skip 
in the beginning. Prevent too naive estimates - + skip_end: how many samples to skip in the end. Prevent too optimistic estimates @@ -315,8 +325,8 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): try: loss = self.results["loss"][skip_begin:-skip_end] min_grad = (np.gradient(np.array(loss))).argmin() - self._optimal_idx = min_grad - return self.results["lr"][min_grad] + self._optimal_idx = min_grad + skip_begin + return self.results["lr"][self._optimal_idx] except Exception: log.warning('Failed to compute suggesting for `lr`.' ' There might not be enough points.') @@ -327,8 +337,12 @@ class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after each batch. """ - def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98): + def __init__(self, num_training: int, + early_stop_threshold: float = 4.0, + progress_bar_refresh_rate: bool = False, + beta: float = 0.98): self.num_training = num_training + self.early_stop_threshold = early_stop_threshold self.beta = beta self.losses = [] self.lrs = [] @@ -339,34 +353,37 @@ def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, b def on_batch_start(self, trainer, pl_module): """ Called before each training batch, logs the lr that will be used """ - if self.progress_bar_refresh_rate and self.progress_bar is None: - self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training) + if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0: + if self.progress_bar_refresh_rate and self.progress_bar is None: + self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training) - self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) + self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) def on_batch_end(self, trainer, pl_module): """ Called when the training batch ends, logs the calculated loss """ - if self.progress_bar: - self.progress_bar.update() + if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0: + if self.progress_bar: + self.progress_bar.update() - current_loss = trainer.running_loss.last().item() - current_step = trainer.global_step + 1 # remove the +1 in 1.0 + current_loss = trainer.running_loss.last().item() + current_step = trainer.global_step + 1 # remove the +1 in 1.0 - # Avg loss (loss with momentum) + smoothing - self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss - smoothed_loss = self.avg_loss / (1 - self.beta**current_step) + # Avg loss (loss with momentum) + smoothing + self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss + smoothed_loss = self.avg_loss / (1 - self.beta**current_step) - # Check if we diverging - if current_step > 1 and smoothed_loss > 4 * self.best_loss: - trainer.max_steps = current_step # stop signal - if self.progress_bar: - self.progress_bar.close() + # Check if we diverging + if self.early_stop_threshold is not None: + if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: + trainer.max_steps = current_step # stop signal + if self.progress_bar: + self.progress_bar.close() - # Save best loss for diverging checking - if smoothed_loss < self.best_loss or current_step == 1: - self.best_loss = smoothed_loss + # Save best loss for diverging checking + if smoothed_loss < self.best_loss or current_step == 1: + self.best_loss = smoothed_loss - self.losses.append(smoothed_loss) + 
self.losses.append(smoothed_loss) class _LinearLR(_LRScheduler): diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 1964c745140e1..ece59ae894b09 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -76,7 +76,7 @@ def test_trainer_reset_correctly(tmpdir): def test_trainer_arg_bool(tmpdir): - + """ Test that setting trainer arg to bool works """ hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(hparams) before_lr = hparams.learning_rate @@ -95,7 +95,7 @@ def test_trainer_arg_bool(tmpdir): def test_trainer_arg_str(tmpdir): - + """ Test that setting trainer arg to string works """ hparams = EvalModelTemplate.get_default_hparams() hparams.__dict__['my_fancy_lr'] = 1.0 # update with non-standard field model = EvalModelTemplate(hparams) @@ -115,7 +115,8 @@ def test_trainer_arg_str(tmpdir): def test_call_to_trainer_method(tmpdir): - + """ Test that directly calling the trainer method works """ + hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(hparams) @@ -133,3 +134,52 @@ def test_call_to_trainer_method(tmpdir): assert before_lr != after_lr, \ 'Learning rate was not altered after running learning rate finder' + + +def test_accumulation_and_early_stopping(tmpdir): + """ Test that early stopping of learning rate finder works, and that + accumulation also works for this feature """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(hparams) + + before_lr = hparams.learning_rate + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + accumulate_grad_batches=2 + ) + + lrfinder = trainer.lr_find(model, early_stop_threshold=None) + after_lr = lrfinder.suggestion() + + assert before_lr != after_lr, \ + 'Learning rate was not altered after running learning rate finder' + assert len(lrfinder.results['lr']) == 100, \ + 'Early stopping for learning rate finder did not work' + assert lrfinder._total_batch_idx == 100*2, \ + 'Accumulation parameter did not work' + + +def test_suggestion_parameters_work(tmpdir): + """ Test that default skipping does not alter results in basic case """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(hparams) + + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=10, + ) + + lrfinder = trainer.lr_find(model) + lr1 = lrfinder.suggestion() + lr2 = lrfinder.suggestion(skip_begin=0) + lr3 = lrfinder.suggestion(skip_begin=80) # way too high, should have an impact + + assert lr1 == lr2, \ + 'Default skipping parameter should not influence suggested learning rate' + assert lr1 != lr3, \ + 'Skipping parameter did not influence learning rate' + \ No newline at end of file From c502093a1e80d0c8110b7f1bcbbd37af25010af4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 12 May 2020 16:36:11 +0200 Subject: [PATCH 03/10] fix styling --- pytorch_lightning/trainer/lr_finder.py | 10 +++++----- tests/trainer/test_lr_finder.py | 11 +++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 6361ecb521a69..15fef3ec8df70 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -127,7 +127,7 @@ def lr_find(self, lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) # Use special lr logger callback - self.callbacks = [_LRCallback(num_training, + self.callbacks = [_LRCallback(num_training, early_stop_threshold, 
progress_bar_refresh_rate=1)] @@ -175,7 +175,7 @@ def lr_find(self, lr_finder.results.update({'lr': self.callbacks[0].lrs, 'loss': self.callbacks[0].losses}) lr_finder._total_batch_idx = self.total_batch_idx # for debug purpose - + # Reset model state self.restore(str(save_path), on_gpu=self.on_gpu) os.remove(save_path) @@ -249,7 +249,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int): self.lr_min = lr_min self.lr_max = lr_max self.num_training = num_training - + self.results = {} self._total_batch_idx = 0 # for debug purpose @@ -337,9 +337,9 @@ class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after each batch. """ - def __init__(self, num_training: int, + def __init__(self, num_training: int, early_stop_threshold: float = 4.0, - progress_bar_refresh_rate: bool = False, + progress_bar_refresh_rate: bool = False, beta: float = 0.98): self.num_training = num_training self.early_stop_threshold = early_stop_threshold diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index ece59ae894b09..1a7393883e887 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -116,7 +116,7 @@ def test_trainer_arg_str(tmpdir): def test_call_to_trainer_method(tmpdir): """ Test that directly calling the trainer method works """ - + hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(hparams) @@ -139,7 +139,7 @@ def test_call_to_trainer_method(tmpdir): def test_accumulation_and_early_stopping(tmpdir): """ Test that early stopping of learning rate finder works, and that accumulation also works for this feature """ - + hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(hparams) @@ -157,13 +157,13 @@ def test_accumulation_and_early_stopping(tmpdir): 'Learning rate was not altered after running learning rate finder' assert len(lrfinder.results['lr']) == 100, \ 'Early stopping for learning rate finder did not work' - assert lrfinder._total_batch_idx == 100*2, \ + assert lrfinder._total_batch_idx == 100 * 2, \ 'Accumulation parameter did not work' def test_suggestion_parameters_work(tmpdir): """ Test that default skipping does not alter results in basic case """ - + hparams = EvalModelTemplate.get_default_hparams() model = EvalModelTemplate(hparams) @@ -172,7 +172,7 @@ def test_suggestion_parameters_work(tmpdir): default_save_path=tmpdir, max_epochs=10, ) - + lrfinder = trainer.lr_find(model) lr1 = lrfinder.suggestion() lr2 = lrfinder.suggestion(skip_begin=0) @@ -182,4 +182,3 @@ def test_suggestion_parameters_work(tmpdir): 'Default skipping parameter should not influence suggested learning rate' assert lr1 != lr3, \ 'Skipping parameter did not influence learning rate' - \ No newline at end of file From 3d6efb6afd3304acc55462af0cd291574120d5e7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 12 May 2020 16:47:03 +0200 Subject: [PATCH 04/10] update CHANGELOG.md --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 270cf30ed5704..5f402d7b8bc49 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -68,8 +68,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561)) -- Fixed missing profiler attribute in add_argparse_args() ArgumentParser ([#1794](https://github.com/PyTorchLightning/pytorch-lightning/pull/1794)) - +- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801)) ## [0.7.5] - 2020-04-27 From f1188823812569adc866051f54bf94ac44d4de62 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 13 May 2020 15:41:40 +0200 Subject: [PATCH 05/10] update based on review --- pytorch_lightning/trainer/lr_finder.py | 77 ++++++++++++++++---------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 15fef3ec8df70..4036d104b48e0 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -112,9 +112,10 @@ def lr_find(self, """ if num_accumulation_steps is not None: - rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated." - "Please set trainer argument `accumulate_grad_batches`" - " instead.", DeprecationWarning) + rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated" + " since v0.8.0 and will be removed in v1.0.0. Please" + " set trainer argument `accumulate_grad_batches` instead.", + DeprecationWarning) save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt') @@ -328,15 +329,31 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): self._optimal_idx = min_grad + skip_begin return self.results["lr"][self._optimal_idx] except Exception: - log.warning('Failed to compute suggesting for `lr`.' - ' There might not be enough points.') + log.exception('Failed to compute suggesting for `lr`.' + ' There might not be enough points.') self._optimal_idx = None class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after - each batch. """ + each batch. + Args: + + num_training: number of iterations done by the learning rate finder + + early_stop_threshold: threshold for stopping the search. If the + loss at any point is larger than early_stop_threshold*best_loss + then the search is stopped. To disable, set to None. + + progress_bar_refresh_rate: rate to refresh the progress bar for + the learning rate finder + + beta: smoothing value, the loss being logged is a running average of + loss values logged until now. beta controls the forget rate i.e. + if beta=0 all past information is ignored. 
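As a rough illustration of the smoothing and early stopping that these arguments control, the sketch below mirrors the `on_batch_end` logic shown earlier in this patch: a bias-corrected running average of the raw loss, with the divergence check governed by `early_stop_threshold`. It is a simplified stand-alone function for intuition, not an extract of the callback itself.

    def smooth_losses(raw_losses, beta=0.98, early_stop_threshold=4.0):
        smoothed, avg_loss, best_loss = [], 0.0, float('inf')
        for step, loss in enumerate(raw_losses, start=1):
            # running average of the loss, with bias correction for early steps
            avg_loss = beta * avg_loss + (1 - beta) * loss
            smoothed_loss = avg_loss / (1 - beta ** step)

            diverged = (
                early_stop_threshold is not None
                and step > 1
                and smoothed_loss > early_stop_threshold * best_loss
            )
            if diverged:
                break  # mirrors setting trainer.max_steps to stop the range test

            best_loss = min(best_loss, smoothed_loss)
            smoothed.append(smoothed_loss)
        return smoothed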
+ + """ def __init__(self, num_training: int, early_stop_threshold: float = 4.0, progress_bar_refresh_rate: bool = False, @@ -353,37 +370,41 @@ def __init__(self, num_training: int, def on_batch_start(self, trainer, pl_module): """ Called before each training batch, logs the lr that will be used """ - if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0: - if self.progress_bar_refresh_rate and self.progress_bar is None: - self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training) + if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + return - self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) + if self.progress_bar_refresh_rate and self.progress_bar is None: + self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training) + + self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0]) def on_batch_end(self, trainer, pl_module): """ Called when the training batch ends, logs the calculated loss """ - if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0: - if self.progress_bar: - self.progress_bar.update() + if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0: + return + + if self.progress_bar: + self.progress_bar.update() - current_loss = trainer.running_loss.last().item() - current_step = trainer.global_step + 1 # remove the +1 in 1.0 + current_loss = trainer.running_loss.last().item() + current_step = trainer.global_step + 1 # remove the +1 in 1.0 - # Avg loss (loss with momentum) + smoothing - self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss - smoothed_loss = self.avg_loss / (1 - self.beta**current_step) + # Avg loss (loss with momentum) + smoothing + self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss + smoothed_loss = self.avg_loss / (1 - self.beta**current_step) - # Check if we diverging - if self.early_stop_threshold is not None: - if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: - trainer.max_steps = current_step # stop signal - if self.progress_bar: - self.progress_bar.close() + # Check if we diverging + if self.early_stop_threshold is not None: + if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss: + trainer.max_steps = current_step # stop signal + if self.progress_bar: + self.progress_bar.close() - # Save best loss for diverging checking - if smoothed_loss < self.best_loss or current_step == 1: - self.best_loss = smoothed_loss + # Save best loss for diverging checking + if smoothed_loss < self.best_loss or current_step == 1: + self.best_loss = smoothed_loss - self.losses.append(smoothed_loss) + self.losses.append(smoothed_loss) class _LinearLR(_LRScheduler): From 02ba1fe362c48a95ec555e134fe18f6ccb51f27f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 13 May 2020 17:34:28 +0200 Subject: [PATCH 06/10] fix tests --- tests/trainer/test_lr_finder.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py index 1a7393883e887..b0f6b6d3a0829 100755 --- a/tests/trainer/test_lr_finder.py +++ b/tests/trainer/test_lr_finder.py @@ -174,11 +174,8 @@ def test_suggestion_parameters_work(tmpdir): ) lrfinder = trainer.lr_find(model) - lr1 = lrfinder.suggestion() - lr2 = lrfinder.suggestion(skip_begin=0) - lr3 = lrfinder.suggestion(skip_begin=80) # way too high, should have an impact + lr1 = lrfinder.suggestion(skip_begin=10) # default + lr2 = lrfinder.suggestion(skip_begin=80) # way too 
high, should have an impact - assert lr1 == lr2, \ - 'Default skipping parameter should not influence suggested learning rate' - assert lr1 != lr3, \ + assert lr1 != lr2, \ 'Skipping parameter did not influence learning rate' From 8c0a254afbf8c2d74dab443b92fa823f336e32a3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 13 May 2020 19:22:38 +0200 Subject: [PATCH 07/10] Apply suggestions from code review --- pytorch_lightning/trainer/lr_finder.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 4036d104b48e0..acb7811af354c 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -113,7 +113,7 @@ def lr_find(self, """ if num_accumulation_steps is not None: rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated" - " since v0.8.0 and will be removed in v1.0.0. Please" + " since v0.7.6 and will be removed in 0.9. Please" " set trainer argument `accumulate_grad_batches` instead.", DeprecationWarning) @@ -338,20 +338,17 @@ class _LRCallback(Callback): """ Special callback used by the learning rate finder. This callbacks log the learning rate before each batch and log the corresponding loss after each batch. - Args: + Args: num_training: number of iterations done by the learning rate finder - early_stop_threshold: threshold for stopping the search. If the - loss at any point is larger than early_stop_threshold*best_loss - then the search is stopped. To disable, set to None. - + loss at any point is larger than `early_stop_threshold*best_loss` + then the search is stopped. To disable, set to `None`. progress_bar_refresh_rate: rate to refresh the progress bar for the learning rate finder - beta: smoothing value, the loss being logged is a running average of - loss values logged until now. beta controls the forget rate i.e. - if beta=0 all past information is ignored. + loss values logged until now. `beta` controls the forget rate i.e. + if `beta=0` all past information is ignored. """ def __init__(self, num_training: int, From 7c017cb78df2ffdfbeac37814cd81df110b3c359 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 13 May 2020 19:23:50 +0200 Subject: [PATCH 08/10] Apply suggestions from code review --- pytorch_lightning/trainer/lr_finder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index acb7811af354c..b5b1e74830087 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -342,8 +342,8 @@ class _LRCallback(Callback): Args: num_training: number of iterations done by the learning rate finder early_stop_threshold: threshold for stopping the search. If the - loss at any point is larger than `early_stop_threshold*best_loss` - then the search is stopped. To disable, set to `None`. + loss at any point is larger than `early_stop_threshold*best_loss` + then the search is stopped. To disable, set to `None`. 
progress_bar_refresh_rate: rate to refresh the progress bar for the learning rate finder beta: smoothing value, the loss being logged is a running average of From 5b21a0d2bd1a6a12b36944415f89bdaa0f89d23f Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 13 May 2020 19:44:48 +0200 Subject: [PATCH 09/10] Apply suggestions from code review --- pytorch_lightning/trainer/lr_finder.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index b5b1e74830087..17e35e97d9ce6 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -88,7 +88,7 @@ def lr_find(self, then the search is stopped. To disable, set to None. num_accumulation_steps: deprepecated, number of batches to calculate loss over. - Set trainer argument `accumulate_grad_batches` instead. + Set trainer argument ``accumulate_grad_batches`` instead. Example:: @@ -315,10 +315,8 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): Returns: lr: suggested initial learning rate to use - skip_begin: how many samples to skip in the beginning. Prevent too naive estimates - skip_end: how many samples to skip in the end. Prevent too optimistic estimates @@ -329,8 +327,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): self._optimal_idx = min_grad + skip_begin return self.results["lr"][self._optimal_idx] except Exception: - log.exception('Failed to compute suggesting for `lr`.' - ' There might not be enough points.') + log.exception('Failed to compute suggesting for `lr`. There might not be enough points.') self._optimal_idx = None @@ -342,13 +339,13 @@ class _LRCallback(Callback): Args: num_training: number of iterations done by the learning rate finder early_stop_threshold: threshold for stopping the search. If the - loss at any point is larger than `early_stop_threshold*best_loss` - then the search is stopped. To disable, set to `None`. + loss at any point is larger than ``early_stop_threshold*best_loss`` + then the search is stopped. To disable, set to ``None``. progress_bar_refresh_rate: rate to refresh the progress bar for the learning rate finder beta: smoothing value, the loss being logged is a running average of - loss values logged until now. `beta` controls the forget rate i.e. - if `beta=0` all past information is ignored. + loss values logged until now. ``beta`` controls the forget rate i.e. + if ``beta=0`` all past information is ignored. """ def __init__(self, num_training: int, From 04dd17942bb622b549b2b59c459a9f32dd5a7fed Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 13 May 2020 20:16:15 +0200 Subject: [PATCH 10/10] Apply suggestions from code review --- pytorch_lightning/trainer/lr_finder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py index 17e35e97d9ce6..0ca41d2d54f99 100755 --- a/pytorch_lightning/trainer/lr_finder.py +++ b/pytorch_lightning/trainer/lr_finder.py @@ -315,10 +315,8 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1): Returns: lr: suggested initial learning rate to use - skip_begin: how many samples to skip in the beginning. - Prevent too naive estimates - skip_end: how many samples to skip in the end. - Prevent too optimistic estimates + skip_begin: how many samples to skip in the beginning. Prevent too naive estimates + skip_end: how many samples to skip in the end. 
Prevents overly optimistic estimates
 
     """
     try:
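Taken together, the series moves gradient accumulation onto the Trainer, adds an `early_stop_threshold` to the range test, and makes `suggestion()` skip the first and last samples while re-offsetting the returned index by `skip_begin`. A usage sketch under the post-patch API, mirroring the tests added above (the `EvalModelTemplate` import path and the `learning_rate` hparam name are assumptions taken from the test code):

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate  # assumed import path for the test model

    model = EvalModelTemplate(EvalModelTemplate.get_default_hparams())

    # Accumulation now comes from the Trainer itself; the old
    # `num_accumulation_steps` argument of lr_find is deprecated above.
    trainer = Trainer(default_save_path='.', accumulate_grad_batches=2)

    # early_stop_threshold=None disables the divergence check, so all
    # `num_training` points are collected (the new test asserts exactly this).
    lr_finder = trainer.lr_find(model, early_stop_threshold=None)

    # The suggestion skips the first/last samples and offsets the optimal
    # index by skip_begin, so the returned lr matches the full results curve.
    new_lr = lr_finder.suggestion(skip_begin=10, skip_end=1)
    model.hparams.learning_rate = new_lr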