[WIP]: ddp pickle fix for learning rate finder #1834

Closed
113 changes: 41 additions & 72 deletions pytorch_lightning/trainer/lr_finder.py
@@ -15,7 +15,7 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only


class TrainerLRFinderMixin(ABC):
@@ -124,13 +124,12 @@ def lr_find(self,
# Prevent going into infinite loop
self.auto_lr_find = False

# Initialize lr finder object (stores results)
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
# Initialize lr finder callback
lr_finder = LRFinderCallback(mode, min_lr, max_lr, num_training,
early_stop_threshold)

# Use special lr logger callback
self.callbacks = [_LRCallback(num_training,
early_stop_threshold,
progress_bar_refresh_rate=1)]
self.callbacks = [lr_finder]

# No logging
self.logger = None
@@ -161,22 +160,19 @@ def lr_find(self,
raise MisconfigurationException(
f'`model.configure_optimizers()` returned {len(optimizers)}, but'
' learning rate finder only works with single optimizer')
model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
model.configure_optimizers = PatchOptimizer(optimizers[0], min_lr, max_lr,
num_training, mode)

# Fit, lr & loss logged in callback
self.fit(model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders)
lr_finder._total_batch_idx = self.total_batch_idx # for debug purpose

# Prompt if we stopped early
if self.global_step != num_training:
log.info('LR finder stopped early due to diverging loss.')

# Transfer results from callback to lr finder object
lr_finder.results.update({'lr': self.callbacks[0].lrs,
'loss': self.callbacks[0].losses})
lr_finder._total_batch_idx = self.total_batch_idx # for debug purpose

# Reset model state
self.restore(str(save_path), on_gpu=self.on_gpu)
os.remove(save_path)
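
Note on the motivation for this change: with DDP, spawning worker processes pickles the model, including its patched `configure_optimizers`. The old `_get_new_optimizer` returned a locally defined closure, which the pickle module cannot serialize; an instance of a module-level callable class can be. A minimal sketch of the difference (illustrative names, not part of the patch):

import pickle

def make_closure():
    def configure_optimizers():  # local function: pickle cannot serialize it
        return None
    return configure_optimizers

class PicklableCallable:  # module-level class: its instances pickle fine
    def __call__(self):
        return None

pickle.dumps(PicklableCallable())  # works
pickle.dumps(make_closure())  # raises "Can't pickle local object"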
@@ -217,8 +213,8 @@ def __lr_finder_restore_params(self, model):
del self.__dumped_params


class _LRFinder(object):
""" LR finder object. This object stores the results of Trainer.lr_find().
class LRFinderCallback(Callback):
""" LR finder callback. This object stores the results of Trainer.lr_find().

Args:
mode: either `linear` or `exponential`, how to increase lr after each step
@@ -242,40 +238,26 @@ class _LRFinder(object):
# Get suggestion
lr = lr_finder.suggestion()
"""
def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
def __init__(self, mode: str, lr_min: float, lr_max: float,
num_training: int, early_stop_threshold: float = 4.0,
beta: float = 0.98, progress_bar_refresh_rate: bool = True):
assert mode in ('linear', 'exponential'), \
'mode should be either `linear` or `exponential`'

self.mode = mode
self.lr_min = lr_min
self.lr_max = lr_max
self.num_training = num_training
self.early_stop_threshold = early_stop_threshold
self.beta = beta

self.results = {}
self.results = {'lr': [], 'loss': []}
self._total_batch_idx = 0 # for debug purpose

def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
""" Construct a new `configure_optimizers()` method, that has a optimizer
with initial lr set to lr_min and a scheduler that will either
linearly or exponentially increase the lr to lr_max in num_training steps.

Args:
optimizer: instance of `torch.optim.Optimizer`

"""
new_lrs = [self.lr_min] * len(optimizer.param_groups)
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
param_group["lr"] = new_lr
param_group["initial_lr"] = new_lr

args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)

def configure_optimizers():
return [optimizer], [{'scheduler': scheduler,
'interval': 'step'}]

return configure_optimizers
self.avg_loss = 0.0
self.best_loss = 0.0
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None

def plot(self, suggest: bool = False, show: bool = False):
""" Plot results from lr_find run
@@ -328,38 +310,7 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
log.exception('Failed to compute suggestion for `lr`. There might not be enough points.')
self._optimal_idx = None


class _LRCallback(Callback):
""" Special callback used by the learning rate finder. This callbacks log
the learning rate before each batch and log the corresponding loss after
each batch.

Args:
num_training: number of iterations done by the learning rate finder
early_stop_threshold: threshold for stopping the search. If the
loss at any point is larger than ``early_stop_threshold*best_loss``
then the search is stopped. To disable, set to ``None``.
progress_bar_refresh_rate: rate to refresh the progress bar for
the learning rate finder
beta: smoothing value, the loss being logged is a running average of
loss values logged until now. ``beta`` controls the forget rate i.e.
if ``beta=0`` all past information is ignored.

"""
def __init__(self, num_training: int,
early_stop_threshold: float = 4.0,
progress_bar_refresh_rate: bool = False,
beta: float = 0.98):
self.num_training = num_training
self.early_stop_threshold = early_stop_threshold
self.beta = beta
self.losses = []
self.lrs = []
self.avg_loss = 0.0
self.best_loss = 0.0
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar = None

@rank_zero_only
def on_batch_start(self, trainer, pl_module):
""" Called before each training batch, logs the lr that will be used """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
@@ -368,8 +319,9 @@ def on_batch_start(self, trainer, pl_module):
if self.progress_bar_refresh_rate and self.progress_bar is None:
self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

self.lrs.append(trainer.lr_schedulers[0]['scheduler'].lr[0])
self.results['lr'].append(trainer.lr_schedulers[0]['scheduler'].lr[0])

@rank_zero_only
def on_batch_end(self, trainer, pl_module):
""" Called when the training batch ends, logs the calculated loss """
if (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
@@ -396,7 +348,24 @@ def on_batch_end(self, trainer, pl_module):
if smoothed_loss < self.best_loss or current_step == 1:
self.best_loss = smoothed_loss

self.losses.append(smoothed_loss)
self.results['loss'].append(smoothed_loss)
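
The computation of `smoothed_loss` (elided above) is an exponentially weighted running average with bias correction, controlled by `beta` from `__init__`. A sketch of the standard formula, not the exact diff contents:

def smooth(avg_loss, current_loss, step, beta=0.98):
    # Running average of the loss; dividing by (1 - beta**step) corrects
    # the bias from initializing avg_loss at 0.0 (see __init__ above).
    avg_loss = beta * avg_loss + (1 - beta) * current_loss
    return avg_loss, avg_loss / (1 - beta ** step)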


class PatchOptimizer(object):
    """ Picklable replacement for the `configure_optimizers` closure.

    The closure previously returned by `_get_new_optimizer` could not be
    pickled when DDP spawned worker processes; a module-level callable
    class can be. Sets the initial lr of every param group to min_lr and
    attaches a scheduler that increases the lr towards max_lr over
    num_training steps.
    """
    def __init__(self, optimizer, min_lr, max_lr, num_training, mode):
        # Reset every param group to the starting lr of the search
        new_lrs = [min_lr] * len(optimizer.param_groups)
        for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
            param_group["lr"] = new_lr
            param_group["initial_lr"] = new_lr

        args = (optimizer, max_lr, num_training)
        self.optimizer = optimizer
        self.scheduler = _LinearLR(*args) if mode == 'linear' else _ExponentialLR(*args)

        # Code marker so Lightning's overridden-method checks recognize the patch
        self.patch_loader_code = str(self.__call__.__code__)

    def __call__(self):
        return [self.optimizer], [{'scheduler': self.scheduler, 'interval': 'step'}]
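
A quick sanity check of the new class (a sketch: torch.optim.SGD and the toy parameter are stand-ins, and this assumes PatchOptimizer is importable at module level):

import pickle
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
patched = PatchOptimizer(torch.optim.SGD(params, lr=0.1),
                         min_lr=1e-6, max_lr=1.0, num_training=100, mode='linear')
optimizers, schedulers = patched()  # same structure the old closure returned
pickle.dumps(patched)  # should now succeed, which is the point of this PR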


class _LinearLR(_LRScheduler):