Skip to content

Commit

Permalink
update based on review
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicki Skafte authored and Borda committed May 13, 2020
1 parent 3d6efb6 commit f118882
Showing 1 changed file with 49 additions and 28 deletions.
77 changes: 49 additions & 28 deletions pytorch_lightning/trainer/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,10 @@ def lr_find(self,
"""
if num_accumulation_steps is not None:
rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated."
"Please set trainer argument `accumulate_grad_batches`"
" instead.", DeprecationWarning)
rank_zero_warn("Argument `num_accumulation_steps` has been deprepecated"
" since v0.8.0 and will be removed in v1.0.0. Please"
" set trainer argument `accumulate_grad_batches` instead.",
DeprecationWarning)

save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

Expand Down Expand Up @@ -328,15 +329,31 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]
except Exception:
log.warning('Failed to compute suggesting for `lr`.'
' There might not be enough points.')
log.exception('Failed to compute suggesting for `lr`.'
' There might not be enough points.')
self._optimal_idx = None


class _LRCallback(Callback):
""" Special callback used by the learning rate finder. This callbacks log
the learning rate before each batch and log the corresponding loss after
each batch. """
each batch.
Args:
num_training: number of iterations done by the learning rate finder
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than early_stop_threshold*best_loss
then the search is stopped. To disable, set to None.
progress_bar_refresh_rate: rate to refresh the progress bar for
the learning rate finder
beta: smoothing value, the loss being logged is a running average of
loss values logged until now. beta controls the forget rate i.e.
if beta=0 all past information is ignored.
"""
def __init__(self, num_training: int,
early_stop_threshold: float = 4.0,
progress_bar_refresh_rate: bool = False,
Expand All @@ -353,37 +370,41 @@ def __init__(self, num_training: int,

def on_batch_start(self, trainer, pl_module):
""" Called before each training batch, logs the lr that will be used """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0:
if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])

def on_batch_end(self, trainer, pl_module):
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches == 0:
if self.progress_bar:
self.progress_bar.update()
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
return

if self.progress_bar:
self.progress_bar.update()

current_loss = trainer.running_loss.last().item()
current_step = trainer.global_step + 1 # remove the +1 in 1.0
current_loss = trainer.running_loss.last().item()
current_step = trainer.global_step + 1 # remove the +1 in 1.0

# Avg loss (loss with momentum) + smoothing
self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
smoothed_loss = self.avg_loss / (1 - self.beta**current_step)
# Avg loss (loss with momentum) + smoothing
self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
smoothed_loss = self.avg_loss / (1 - self.beta**current_step)

# Check if we diverging
if self.early_stop_threshold is not None:
if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
trainer.max_steps = current_step # stop signal
if self.progress_bar:
self.progress_bar.close()
# Check if we diverging
if self.early_stop_threshold is not None:
if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
trainer.max_steps = current_step # stop signal
if self.progress_bar:
self.progress_bar.close()

# Save best loss for diverging checking
if smoothed_loss < self.best_loss or current_step == 1:
self.best_loss = smoothed_loss
# Save best loss for diverging checking
if smoothed_loss < self.best_loss or current_step == 1:
self.best_loss = smoothed_loss

self.losses.append(smoothed_loss)
self.losses.append(smoothed_loss)


class _LinearLR(_LRScheduler):
Expand Down

0 comments on commit f118882

Please sign in to comment.