diff --git a/CHANGELOG.md b/CHANGELOG.md
index 66bf99ba78a43..ad5ddbe22e49a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
 
+- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
+
 ### Deprecated
 
 ### Removed
@@ -32,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed wandb logger `global_step` affects other loggers ([#1492](https://github.com/PyTorchLightning/pytorch-lightning/issues/1485))
 
+- Fixed bugs that prevent lr finder to be used together with early stopping and validation dataloaders ([#1676](https://github.com/PyTorchLightning/pytorch-lightning/pull/1676))
+
 ## [0.7.5] - 2020-04-27
 
 ### Changed
@@ -76,7 +80,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Defines shared proc. rank, remove rank from instances (e.g. loggers) ([#1408](https://github.com/PyTorchLightning/pytorch-lightning/pull/1408))
 - Updated semantic segmentation example with custom U-Net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
 - Disabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
-- Updated LightningTemplateModel to look more like Colab example ([#1546](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
 
 ### Deprecated
 
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index b0491c818dd5b..e664fd3cc47d0 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -53,6 +53,7 @@ def _run_lr_finder_internally(self, model: LightningModule):
     def lr_find(self,
                 model: LightningModule,
                 train_dataloader: Optional[DataLoader] = None,
+                val_dataloaders: Optional[DataLoader] = None,
                 min_lr: float = 1e-8,
                 max_lr: float = 1,
                 num_training: int = 100,
@@ -105,7 +106,7 @@ def lr_find(self,
         """
         save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')
 
-        self._dump_params(model)
+        self._lr_finder_dump_params(model)
 
         # Prevent going into infinite loop
         self.auto_lr_find = False
@@ -129,8 +130,10 @@ def lr_find(self,
         # Accumulation of gradients
         self.accumulate_grad_batches = num_accumulation_steps
 
-        # Disable standard checkpoint
+        # Disable standard checkpoint & early stopping
         self.checkpoint_callback = False
+        self.early_stop_callback = None
+        self.enable_early_stop = False
 
         # Required for saving the model
         self.optimizers, self.schedulers = [], [],
@@ -150,7 +153,9 @@ def lr_find(self,
         model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
 
         # Fit, lr & loss logged in callback
-        self.fit(model, train_dataloader=train_dataloader)
+        self.fit(model,
+                 train_dataloader=train_dataloader,
+                 val_dataloaders=val_dataloaders)
 
         # Prompt if we stopped early
         if self.global_step != num_training:
@@ -165,13 +170,13 @@ def lr_find(self,
             os.remove(save_path)
 
         # Finish by resetting variables so trainer is ready to fit model
-        self._restore_params(model)
+        self._lr_finder_restore_params(model)
         if self.progress_bar_callback:
             self.progress_bar_callback.enable()
 
         return lr_finder
 
-    def _dump_params(self, model):
+    def _lr_finder_dump_params(self, model):
         # Prevent going into infinite loop
         self._params = {
             'auto_lr_find': self.auto_lr_find,
@@ -181,11 +186,13 @@ def _dump_params(self, model):
             'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
             'accumulate_grad_batches': self.accumulate_grad_batches,
             'checkpoint_callback': self.checkpoint_callback,
+            'early_stop_callback': self.early_stop_callback,
+            'enable_early_stop': self.enable_early_stop,
             'progress_bar_callback': self.progress_bar_callback,
             'configure_optimizers': model.configure_optimizers,
         }
 
-    def _restore_params(self, model):
+    def _lr_finder_restore_params(self, model):
         self.auto_lr_find = self._params['auto_lr_find']
         self.logger = self._params['logger']
         self.callbacks = self._params['callbacks']
@@ -193,6 +200,8 @@ def _restore_params(self, model):
         self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate']
         self.accumulate_grad_batches = self._params['accumulate_grad_batches']
         self.checkpoint_callback = self._params['checkpoint_callback']
+        self.early_stop_callback = self._params['early_stop_callback']
+        self.enable_early_stop = self._params['enable_early_stop']
         self.progress_bar_callback = self._params['progress_bar_callback']
         model.configure_optimizers = self._params['configure_optimizers']
 
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index ba6e9c336b130..ea2eca3d712ad 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -82,8 +82,8 @@ class CurrentTestModel(
     )
 
     changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
-                          'progress_bar_refresh_rate',
-                          'accumulate_grad_batches',
+                          'progress_bar_refresh_rate', 'early_stop_callback',
+                          'accumulate_grad_batches', 'enable_early_stop',
                           'checkpoint_callback']
    attributes_before = {}
    for ca in changed_attributes:
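
Usage note (not part of the patch): below is a minimal sketch of how the updated `lr_find` might be exercised with the new `val_dataloaders` argument while early stopping is configured, which is the combination this PR fixes. It assumes pytorch-lightning 0.7.x-era APIs; the `DemoModel` class, `make_loader()` helper, data shapes, and the `model.lr` attribute are illustrative assumptions, not code from this PR.

```python
# Illustrative sketch only -- not part of the patch. Assumes pytorch-lightning
# 0.7.x-era APIs; DemoModel, make_loader() and the `lr` attribute are made up
# for this example.
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class DemoModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3):
        super().__init__()
        self.lr = lr
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': F.cross_entropy(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {'val_loss': F.cross_entropy(self(x), y)}

    def validation_epoch_end(self, outputs):
        # 'val_loss' is the metric the default early stopping callback monitors
        return {'val_loss': torch.stack([o['val_loss'] for o in outputs]).mean()}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


def make_loader():
    # Random data, just enough to drive the lr sweep and a short fit
    data = TensorDataset(torch.randn(256, 32), torch.randint(0, 2, (256,)))
    return DataLoader(data, batch_size=32)


model = DemoModel()
trainer = pl.Trainer(early_stop_callback=True, max_epochs=5)

# The sweep runs with early stopping and checkpointing temporarily disabled;
# _lr_finder_restore_params() puts the trainer state back afterwards.
lr_finder = trainer.lr_find(model,
                            train_dataloader=make_loader(),
                            val_dataloaders=make_loader(),
                            num_training=50)

model.lr = lr_finder.suggestion()  # adopt the suggested learning rate
trainer.fit(model, train_dataloader=make_loader(), val_dataloaders=make_loader())
```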