
Commit

Bugfix/lr finder (#1676)
* fix early stopping bug

* allow val dataloader

* update CHANGELOG.md

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
SkafteNicki and Nicki Skafte committed May 4, 2020
1 parent 1077159 commit e865b04
Showing 3 changed files with 21 additions and 9 deletions.
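For context, here is a minimal usage sketch of the scenario this commit fixes: running the learning-rate finder while early stopping is enabled and a validation dataloader is supplied. It assumes the 0.7.x-era Trainer API shown in the diffs below; the `LitModel` module, the random dataset, and all hyperparameters are illustrative and not part of the commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Tiny classifier, only here to exercise lr_find (hypothetical, not from the diff)."""

    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': torch.nn.functional.cross_entropy(self(x), y)}

    def validation_epoch_end(self, outputs):
        return {'val_loss': torch.stack([o['val_loss'] for o in outputs]).mean()}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def make_loader():
    # Random data purely for illustration
    x, y = torch.randn(256, 32), torch.randint(0, 2, (256,))
    return DataLoader(TensorDataset(x, y), batch_size=32)


model = LitModel()
trainer = pl.Trainer(max_epochs=5, early_stop_callback=True)

# Before this fix, early stopping interfered with the temporary lr-finder fit
# and lr_find had no way to accept a validation dataloader.
lr_finder = trainer.lr_find(model,
                            train_dataloader=make_loader(),
                            val_dataloaders=make_loader())
model.lr = lr_finder.suggestion()

trainer.fit(model, train_dataloader=make_loader(), val_dataloaders=make_loader())
```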
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
 
+- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
+
 ### Deprecated
 
 ### Removed
@@ -32,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed wandb logger `global_step` affects other loggers ([#1492](https://github.com/PyTorchLightning/pytorch-lightning/issues/1485))
 
+- Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676))
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed
@@ -76,7 +80,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Defines shared proc. rank, remove rank from instances (e.g. loggers) ([#1408](https://github.com/PyTorchLightning/pytorch-lightning/pull/1408))
 - Updated semantic segmentation example with custom U-Net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
 - Disabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
-- Updated LightningTemplateModel to look more like Colab example ([#1546](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
 
 ### Deprecated
 
21 changes: 15 additions & 6 deletions pytorch_lightning/trainer/lr_finder.py
@@ -53,6 +53,7 @@ def _run_lr_finder_internally(self, model: LightningModule):
     def lr_find(self,
                 model: LightningModule,
                 train_dataloader: Optional[DataLoader] = None,
+                val_dataloaders: Optional[DataLoader] = None,
                 min_lr: float = 1e-8,
                 max_lr: float = 1,
                 num_training: int = 100,
@@ -105,7 +106,7 @@ def lr_find(self,
         """
         save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')
 
-        self._dump_params(model)
+        self._lr_finder_dump_params(model)
 
         # Prevent going into infinite loop
         self.auto_lr_find = False
@@ -129,8 +130,10 @@ def lr_find(self,
         # Accumulation of gradients
         self.accumulate_grad_batches = num_accumulation_steps
 
-        # Disable standard checkpoint
+        # Disable standard checkpoint & early stopping
         self.checkpoint_callback = False
+        self.early_stop_callback = None
+        self.enable_early_stop = False
 
         # Required for saving the model
         self.optimizers, self.schedulers = [], [],
@@ -150,7 +153,9 @@ def lr_find(self,
         model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
 
         # Fit, lr & loss logged in callback
-        self.fit(model, train_dataloader=train_dataloader)
+        self.fit(model,
+                 train_dataloader=train_dataloader,
+                 val_dataloaders=val_dataloaders)
 
         # Prompt if we stopped early
         if self.global_step != num_training:
@@ -165,13 +170,13 @@ def lr_find(self,
         os.remove(save_path)
 
         # Finish by resetting variables so trainer is ready to fit model
-        self._restore_params(model)
+        self._lr_finder_restore_params(model)
         if self.progress_bar_callback:
             self.progress_bar_callback.enable()
 
         return lr_finder
 
-    def _dump_params(self, model):
+    def _lr_finder_dump_params(self, model):
         # Prevent going into infinite loop
         self._params = {
             'auto_lr_find': self.auto_lr_find,
@@ -181,18 +186,22 @@ def _dump_params(self, model):
             'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
             'accumulate_grad_batches': self.accumulate_grad_batches,
             'checkpoint_callback': self.checkpoint_callback,
+            'early_stop_callback': self.early_stop_callback,
+            'enable_early_stop': self.enable_early_stop,
             'progress_bar_callback': self.progress_bar_callback,
             'configure_optimizers': model.configure_optimizers,
         }
 
-    def _restore_params(self, model):
+    def _lr_finder_restore_params(self, model):
         self.auto_lr_find = self._params['auto_lr_find']
         self.logger = self._params['logger']
         self.callbacks = self._params['callbacks']
         self.max_steps = self._params['max_steps']
         self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate']
         self.accumulate_grad_batches = self._params['accumulate_grad_batches']
         self.checkpoint_callback = self._params['checkpoint_callback']
+        self.early_stop_callback = self._params['early_stop_callback']
+        self.enable_early_stop = self._params['enable_early_stop']
         self.progress_bar_callback = self._params['progress_bar_callback']
         model.configure_optimizers = self._params['configure_optimizers']
 
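The renamed helpers above implement a simple snapshot-and-restore pattern: every trainer attribute the LR finder mutates (now including the two early-stopping fields) is saved into a dict before the temporary fit and written back afterwards, so a later call to `fit()` sees the user's original configuration. Below is a simplified, generic illustration of that pattern; the `TrainerStateSnapshot` class is hypothetical and not part of the library.

```python
class TrainerStateSnapshot:
    """Snapshot and restore a fixed set of attributes on a trainer-like object."""

    FIELDS = ('auto_lr_find', 'callbacks', 'logger', 'max_steps',
              'checkpoint_callback', 'early_stop_callback', 'enable_early_stop',
              'accumulate_grad_batches', 'progress_bar_refresh_rate',
              'progress_bar_callback')

    def __init__(self, trainer):
        self._trainer = trainer
        self._saved = {}

    def dump(self):
        # Record the current value of every attribute we may overwrite
        self._saved = {name: getattr(self._trainer, name) for name in self.FIELDS}

    def restore(self):
        # Put everything back exactly as it was before the temporary fit
        for name, value in self._saved.items():
            setattr(self._trainer, name, value)
```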
4 changes: 2 additions & 2 deletions tests/trainer/test_lr_finder.py
@@ -82,8 +82,8 @@ class CurrentTestModel(
     )
 
     changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
-                          'progress_bar_refresh_rate',
-                          'accumulate_grad_batches',
+                          'progress_bar_refresh_rate', 'early_stop_callback',
+                          'accumulate_grad_batches', 'enable_early_stop',
                           'checkpoint_callback']
     attributes_before = {}
     for ca in changed_attributes:
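The test excerpt ends mid-function here. A condensed sketch of how such a check typically continues (the remainder of the test is not shown in this diff, so the exact assertions below are an assumption): snapshot the listed attributes, run `lr_find`, snapshot again, and assert nothing leaked.

```python
attributes_before = {ca: getattr(trainer, ca) for ca in changed_attributes}

lr_finder = trainer.lr_find(model)  # short temporary fit

attributes_after = {ca: getattr(trainer, ca) for ca in changed_attributes}
for key in changed_attributes:
    assert attributes_before[key] == attributes_after[key], \
        f'Attribute {key} was not reset to its previous value after lr_find'
```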
