
Bugfix/lr finder #1676

Merged 7 commits on May 4, 2020

5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))

- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))

### Deprecated

### Removed
@@ -30,6 +32,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed wandb logger `global_step` affects other loggers ([#1492](https://github.com/PyTorchLightning/pytorch-lightning/issues/1485))

- Fixed bugs that prevented the lr finder from being used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676))

## [0.7.5] - 2020-04-27

### Changed
@@ -74,7 +78,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Defines shared proc. rank, remove rank from instances (e.g. loggers) ([#1408](https://github.com/PyTorchLightning/pytorch-lightning/pull/1408))
- Updated semantic segmentation example with custom U-Net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
- Disabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
- Updated LightningTemplateModel to look more like Colab example ([#1546](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))

### Deprecated

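For context, a minimal usage sketch of the scenario the fixed entry above describes: running the learning rate finder on a Trainer configured with early stopping while passing validation dataloaders explicitly. `MyLightningModule`, `train_loader` and `val_loader` are placeholders, and the Trainer/`EarlyStopping` arguments are assumed from the 0.7.x API rather than taken from this diff.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

model = MyLightningModule()   # placeholder LightningModule
train_loader = ...            # placeholder training DataLoader
val_loader = ...              # placeholder validation DataLoader

trainer = Trainer(early_stop_callback=EarlyStopping(monitor='val_loss'))

# With this fix, lr_find no longer conflicts with the early stopping callback
# and forwards the validation dataloaders to its internal fit call.
lr_finder = trainer.lr_find(model,
                            train_dataloader=train_loader,
                            val_dataloaders=val_loader)
new_lr = lr_finder.suggestion()   # pick a learning rate from the range test
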
21 changes: 15 additions & 6 deletions pytorch_lightning/trainer/lr_finder.py
@@ -53,6 +53,7 @@ def _run_lr_finder_internally(self, model: LightningModule):
def lr_find(self,
model: LightningModule,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Optional[DataLoader] = None,
min_lr: float = 1e-8,
max_lr: float = 1,
num_training: int = 100,
@@ -105,7 +106,7 @@ def lr_find(self,
"""
save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

self._dump_params(model)
self._lr_finder_dump_params(model)

# Prevent going into infinite loop
self.auto_lr_find = False
@@ -129,8 +130,10 @@ def lr_find(self,
# Accumulation of gradients
self.accumulate_grad_batches = num_accumulation_steps

# Disable standard checkpoint
# Disable standard checkpoint & early stopping
self.checkpoint_callback = False
self.early_stop_callback = None
self.enable_early_stop = False

# Required for saving the model
self.optimizers, self.schedulers = [], [],
@@ -150,7 +153,9 @@ def lr_find(self,
model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

# Fit, lr & loss logged in callback
self.fit(model, train_dataloader=train_dataloader)
self.fit(model,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders)

# Prompt if we stopped early
if self.global_step != num_training:
@@ -165,13 +170,13 @@ def lr_find(self,
os.remove(save_path)

# Finish by resetting variables so trainer is ready to fit model
self._restore_params(model)
self._lr_finder_restore_params(model)
if self.progress_bar_callback:
self.progress_bar_callback.enable()

return lr_finder

def _dump_params(self, model):
def _lr_finder_dump_params(self, model):
# Prevent going into infinite loop
self._params = {
'auto_lr_find': self.auto_lr_find,
@@ -181,18 +186,22 @@ def _dump_params(self, model):
'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
'accumulate_grad_batches': self.accumulate_grad_batches,
'checkpoint_callback': self.checkpoint_callback,
'early_stop_callback': self.early_stop_callback,
'enable_early_stop': self.enable_early_stop,
'progress_bar_callback': self.progress_bar_callback,
'configure_optimizers': model.configure_optimizers,
}

def _restore_params(self, model):
def _lr_finder_restore_params(self, model):
self.auto_lr_find = self._params['auto_lr_find']
self.logger = self._params['logger']
self.callbacks = self._params['callbacks']
self.max_steps = self._params['max_steps']
self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self._params['accumulate_grad_batches']
self.checkpoint_callback = self._params['checkpoint_callback']
self.early_stop_callback = self._params['early_stop_callback']
self.enable_early_stop = self._params['enable_early_stop']
self.progress_bar_callback = self._params['progress_bar_callback']
model.configure_optimizers = self._params['configure_optimizers']

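The rename from `_dump_params`/`_restore_params` to `_lr_finder_dump_params`/`_lr_finder_restore_params` presumably avoids clashing with other Trainer helpers that snapshot state in the same way. The underlying pattern is simply an attribute snapshot and restore; a standalone sketch with illustrative names, not Lightning's exact code:

def dump_params(trainer, attribute_names):
    """Snapshot the listed trainer attributes before lr_find mutates them."""
    return {name: getattr(trainer, name) for name in attribute_names}

def restore_params(trainer, saved):
    """Put every snapshotted attribute back once the range test is done."""
    for name, value in saved.items():
        setattr(trainer, name, value)

# saved = dump_params(trainer, ['max_steps', 'checkpoint_callback', 'early_stop_callback'])
# ... temporarily override those attributes and run the LR range test ...
# restore_params(trainer, saved)
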
4 changes: 2 additions & 2 deletions tests/trainer/test_lr_finder.py
@@ -82,8 +82,8 @@ class CurrentTestModel(
)

changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
'progress_bar_refresh_rate',
'accumulate_grad_batches',
'progress_bar_refresh_rate', 'early_stop_callback',
'accumulate_grad_batches', 'enable_early_stop',
'checkpoint_callback']
attributes_before = {}
for ca in changed_attributes:
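For reference, the remainder of this test (not shown in the diff) presumably records each listed attribute before calling `lr_find` and asserts it is restored afterwards. A condensed sketch of that round-trip check, with the exact assertions assumed rather than copied from the test file:

attributes_before = {ca: getattr(trainer, ca) for ca in changed_attributes}

lr_finder = trainer.lr_find(model)

attributes_after = {ca: getattr(trainer, ca) for ca in changed_attributes}
for ca in changed_attributes:
    assert attributes_before[ca] == attributes_after[ca], \
        f'attribute {ca} was not restored after lr_find'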