Bugfix: accumulation and suggestion for learning rate finder (#1801)
* fix suggestion being too naive

* fix accumulation error and added new tests

* fix styling

* update CHANGELOG.md

* update based on review

* fix tests

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people committed May 13, 2020
1 parent aefc531 commit 663b900
Showing 3 changed files with 110 additions and 28 deletions.
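
The diff below makes three user-visible changes: gradient accumulation during the range test now follows the Trainer's `accumulate_grad_batches` argument (the `num_accumulation_steps` argument of `lr_find` is deprecated), the divergence check is controlled by a new `early_stop_threshold` argument, and `suggestion()` can skip unreliable points at either end of the sweep. A hedged usage sketch, assuming a placeholder `MyModel` LightningModule that reads its learning rate from `hparams.learning_rate` (as the repo's test template does):

```python
from pytorch_lightning import Trainer

model = MyModel()  # placeholder LightningModule

# Accumulation now comes from the Trainer argument, not from lr_find().
trainer = Trainer(accumulate_grad_batches=2)

# early_stop_threshold replaces the hard-coded `4 * best_loss` divergence check;
# pass None to disable early stopping of the search entirely.
lr_finder = trainer.lr_find(model, early_stop_threshold=4.0)

# Skip noisy points at the start/end of the sweep before asking for a suggestion.
new_lr = lr_finder.suggestion(skip_begin=10, skip_end=1)
model.hparams.learning_rate = new_lr
```
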
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -70,8 +70,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed saving native AMP scaler state (introduced in [#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))

- Fixed missing profiler attribute in add_argparse_args() ArgumentParser ([#1794](https://github.com/PyTorchLightning/pytorch-lightning/pull/1794))

- Fixed accumulation parameter and suggestion method for learning rate finder ([#1801](https://github.com/PyTorchLightning/pytorch-lightning/pull/1801))

## [0.7.5] - 2020-04-27

79 changes: 58 additions & 21 deletions pytorch_lightning/trainer/lr_finder.py
@@ -15,6 +15,7 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn


class TrainerLRFinderMixin(ABC):
@@ -58,7 +59,8 @@ def lr_find(self,
max_lr: float = 1,
num_training: int = 100,
mode: str = 'exponential',
num_accumulation_steps: int = 1):
early_stop_threshold: float = 4.0,
num_accumulation_steps=None):
r"""
lr_find enables the user to do a range test of good initial learning rates,
to reduce the amount of guesswork in picking a good starting learning rate.
@@ -81,7 +83,12 @@ after each batch. If set to 'exponential', will increase learning
after each batch. If set to 'exponential', will increase learning
rate exponentially.
num_accumulation_steps: number of batches to calculate loss over.
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
num_accumulation_steps: deprecated, number of batches to calculate loss over.
Set trainer argument ``accumulate_grad_batches`` instead.
Example::
@@ -104,6 +111,12 @@
trainer.fit(model)
"""
if num_accumulation_steps is not None:
rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
" since v0.7.6 and will be removed in 0.9. Please"
" set trainer argument `accumulate_grad_batches` instead.",
DeprecationWarning)

save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

self.__lr_finder_dump_params(model)
@@ -115,7 +128,9 @@
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

# Use special lr logger callback
self.callbacks = [_LRCallback(num_training, progress_bar_refresh_rate=1)]
self.callbacks = [_LRCallback(num_training,
early_stop_threshold,
progress_bar_refresh_rate=1)]

# No logging
self.logger = None
@@ -127,9 +142,6 @@ def lr_find(self,
if self.progress_bar_callback:
self.progress_bar_callback.disable()

# Accumulation of gradients
self.accumulate_grad_batches = num_accumulation_steps

# Disable standard checkpoint & early stopping
self.checkpoint_callback = False
self.early_stop_callback = None
@@ -149,7 +161,6 @@ def lr_find(self,
raise MisconfigurationException(
f'`model.configure_optimizers()` returned {len(optimizers)}, but'
' learning rate finder only works with single optimizer')
configure_optimizers = model.configure_optimizers
model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

# Fit, lr & loss logged in callback
@@ -164,6 +175,7 @@
# Transfer results from callback to lr finder object
lr_finder.results.update({'lr': self.callbacks[0].lrs,
'loss': self.callbacks[0].losses})
lr_finder._total_batch_idx = self.total_batch_idx # for debug purposes

# Reset model state
self.restore(str(save_path), on_gpu=self.on_gpu)
@@ -184,7 +196,6 @@ def __lr_finder_dump_params(self, model):
'logger': self.logger,
'max_steps': self.max_steps,
'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
'accumulate_grad_batches': self.accumulate_grad_batches,
'checkpoint_callback': self.checkpoint_callback,
'early_stop_callback': self.early_stop_callback,
'enable_early_stop': self.enable_early_stop,
@@ -198,7 +209,6 @@ def __lr_finder_restore_params(self, model):
self.callbacks = self.__dumped_params['callbacks']
self.max_steps = self.__dumped_params['max_steps']
self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
@@ -242,6 +252,7 @@ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
self.num_training = num_training

self.results = {}
self._total_batch_idx = 0 # for debug purposes

def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
""" Construct a new `configure_optimizers()` method, that has a optimizer
@@ -298,30 +309,49 @@ def plot(self, suggest: bool = False, show: bool = False):

return fig

def suggestion(self):
def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
""" This will propose a suggestion for choice of initial learning rate
as the point with the steepest negative gradient.
Args:
skip_begin: how many samples to skip in the beginning. Prevents too naive estimates
skip_end: how many samples to skip in the end. Prevents too optimistic estimates
Returns:
lr: suggested initial learning rate to use
"""
try:
min_grad = (np.gradient(np.array(self.results["loss"]))).argmin()
self._optimal_idx = min_grad
return self.results["lr"][min_grad]
loss = self.results["loss"][skip_begin:-skip_end]
min_grad = (np.gradient(np.array(loss))).argmin()
self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]
except Exception:
log.warning('Failed to compute suggesting for `lr`.'
' There might not be enough points.')
log.exception('Failed to compute suggestion for `lr`. There might not be enough points.')
self._optimal_idx = None


class _LRCallback(Callback):
""" Special callback used by the learning rate finder. This callbacks log
the learning rate before each batch and log the corresponding loss after
each batch. """
def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, beta: float = 0.98):
each batch.
Args:
num_training: number of iterations done by the learning rate finder
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than ``early_stop_threshold*best_loss``
then the search is stopped. To disable, set to ``None``.
progress_bar_refresh_rate: rate to refresh the progress bar for
the learning rate finder
beta: smoothing value, the loss being logged is a running average of
loss values logged until now. ``beta`` controls the forget rate i.e.
if ``beta=0`` all past information is ignored.
"""
def __init__(self, num_training: int,
early_stop_threshold: float = 4.0,
progress_bar_refresh_rate: bool = False,
beta: float = 0.98):
self.num_training = num_training
self.early_stop_threshold = early_stop_threshold
self.beta = beta
self.losses = []
self.lrs = []
@@ -332,13 +362,19 @@ def __init__(self, num_training: int, progress_bar_refresh_rate: bool = False, b

def on_batch_start(self, trainer, pl_module):
""" Called before each training batch, logs the lr that will be used """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

def on_batch_end(self, trainer, pl_module):
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

if self.progress_bar:
self.progress_bar.update()

@@ -350,10 +386,11 @@ def on_batch_end(self, trainer, pl_module):
smoothed_loss = self.avg_loss / (1 - self.beta**current_step)

# Check if we are diverging
if current_step > 1 and smoothed_loss > 4 * self.best_loss:
trainer.max_steps = current_step # stop signal
if self.progress_bar:
self.progress_bar.close()
if self.early_stop_threshold is not None:
if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
trainer.max_steps = current_step # stop signal
if self.progress_bar:
self.progress_bar.close()

# Save best loss for diverging checking
if smoothed_loss < self.best_loss or current_step == 1:
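
For reference, a minimal standalone sketch of the suggestion logic introduced above, assuming numpy and a results dict shaped like `_LRFinder.results`; the helper name `suggest_lr` and the synthetic sweep are illustrative only, not part of the PR:

```python
import numpy as np

def suggest_lr(results, skip_begin=10, skip_end=1):
    """Pick the lr at the steepest negative loss gradient, ignoring the first
    `skip_begin` and last `skip_end` points, which tend to give poor estimates."""
    loss = np.array(results["loss"][skip_begin:-skip_end])
    min_grad = np.gradient(loss).argmin()        # steepest descent inside the window
    return results["lr"][min_grad + skip_begin]  # map back to an index in the full sweep

# Synthetic sweep: the loss drops fastest near the middle of the lr range,
# so the suggestion should land roughly at the mid-range learning rate.
lrs = list(np.logspace(-5, 0, 100))
losses = list(1.0 / (1.0 + np.exp(10 * (np.linspace(0, 1, 100) - 0.5))))
print(suggest_lr({"lr": lrs, "loss": losses}))
```
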
56 changes: 51 additions & 5 deletions tests/trainer/test_lr_finder.py
@@ -76,15 +76,15 @@ def test_trainer_reset_correctly(tmpdir):


def test_trainer_arg_bool(tmpdir):

""" Test that setting trainer arg to bool works """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
before_lr = hparams.learning_rate

# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
max_epochs=5,
auto_lr_find=True
)

@@ -95,7 +95,7 @@ def test_trainer_arg_bool(tmpdir):


def test_trainer_arg_str(tmpdir):

""" Test that setting trainer arg to string works """
hparams = EvalModelTemplate.get_default_hparams()
hparams.__dict__['my_fancy_lr'] = 1.0 # update with non-standard field
model = EvalModelTemplate(hparams)
Expand All @@ -104,7 +104,7 @@ def test_trainer_arg_str(tmpdir):
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
max_epochs=5,
auto_lr_find='my_fancy_lr'
)

@@ -115,6 +115,7 @@


def test_call_to_trainer_method(tmpdir):
""" Test that directly calling the trainer method works """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)
Expand All @@ -123,7 +124,7 @@ def test_call_to_trainer_method(tmpdir):
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
max_epochs=5,
)

lrfinder = trainer.lr_find(model, mode='linear')
@@ -133,3 +134,48 @@ def test_call_to_trainer_method(tmpdir):

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'


def test_accumulation_and_early_stopping(tmpdir):
""" Test that early stopping of learning rate finder works, and that
accumulation also works for this feature """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)

before_lr = hparams.learning_rate
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
accumulate_grad_batches=2
)

lrfinder = trainer.lr_find(model, early_stop_threshold=None)
after_lr = lrfinder.suggestion()

assert before_lr != after_lr, \
'Learning rate was not altered after running learning rate finder'
assert len(lrfinder.results['lr']) == 100, \
'Early stopping for learning rate finder did not work'
assert lrfinder._total_batch_idx == 100 * 2, \
'Accumulation parameter did not work'


def test_suggestion_parameters_work(tmpdir):
""" Test that default skipping does not alter results in basic case """

hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(hparams)

# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=10,
)

lrfinder = trainer.lr_find(model)
lr1 = lrfinder.suggestion(skip_begin=10) # default
lr2 = lrfinder.suggestion(skip_begin=80) # way too high, should have an impact

assert lr1 != lr2, \
'Skipping parameter did not influence learning rate'
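
As a side note on the `100 * 2` assertion in `test_accumulation_and_early_stopping` above, a small illustrative arithmetic check (assumption: the finder records one lr/loss pair per optimizer step, while each step consumes `accumulate_grad_batches` batches):

```python
num_training = 100             # optimizer steps requested by lr_find (its default)
accumulate_grad_batches = 2    # Trainer-level accumulation used in the test above
total_batches = num_training * accumulate_grad_batches
assert total_batches == 200    # hence `lrfinder._total_batch_idx == 100 * 2`
```
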
