diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1c35dcdca5034..1ad8752763794 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
 
+- Removed non-finite values from loss in `LRFinder` ([#1862](https://github.com/PyTorchLightning/pytorch-lightning/pull/1862))
+
 ### Deprecated
 
 ### Removed
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index 7554a670bb988..6a76523f17fb4 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -321,8 +321,9 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
 
         """
         try:
-            loss = self.results["loss"][skip_begin:-skip_end]
-            min_grad = (np.gradient(np.array(loss))).argmin()
+            loss = np.array(self.results["loss"][skip_begin:-skip_end])
+            loss = loss[np.isfinite(loss)]
+            min_grad = np.gradient(loss).argmin()
             self._optimal_idx = min_grad + skip_begin
             return self.results["lr"][self._optimal_idx]
         except Exception:
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index b0f6b6d3a0829..fe4894c3e49de 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -124,7 +124,7 @@ def test_call_to_trainer_method(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=5,
+        max_epochs=5
     )
 
     lrfinder = trainer.lr_find(model, mode='linear')
@@ -170,7 +170,7 @@ def test_suggestion_parameters_work(tmpdir):
     # logger file to get meta
     trainer = Trainer(
         default_save_path=tmpdir,
-        max_epochs=10,
+        max_epochs=10
     )
 
     lrfinder = trainer.lr_find(model)
@@ -179,3 +179,24 @@ def test_suggestion_parameters_work(tmpdir):
 
     assert lr1 != lr2, \
         'Skipping parameter did not influence learning rate'
+
+
+def test_suggestion_with_non_finite_values(tmpdir):
+    """ Test that non-finite values do not alter results """
+
+    hparams = EvalModelTemplate.get_default_hparams()
+    model = EvalModelTemplate(hparams)
+
+    # logger file to get meta
+    trainer = Trainer(
+        default_save_path=tmpdir,
+        max_epochs=10
+    )
+
+    lrfinder = trainer.lr_find(model)
+    before_lr = lrfinder.suggestion()
+    lrfinder.results['loss'][-1] = float('nan')
+    after_lr = lrfinder.suggestion()
+
+    assert before_lr == after_lr, \
+        'Learning rate was altered because of non-finite loss values'
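
The `lr_finder.py` change above drops non-finite entries from the recorded losses before computing the gradient used for the learning-rate suggestion. A minimal standalone sketch of why that filtering matters is shown below; the loss values are made up for illustration and this is not the library's code, only a demonstration of the same NumPy operations:

```python
import numpy as np

# Hypothetical recorded losses: the trailing NaN could come from a diverged step.
loss = np.array([2.0, 1.5, 1.0, 0.4, 0.6, float('nan')])

# Without filtering, NaN propagates into np.gradient, so argmin() lands on a
# NaN entry rather than the point of steepest descent.
raw_idx = np.gradient(loss).argmin()

# With the filtering added in the patch, only finite values are considered.
finite = loss[np.isfinite(loss)]
filtered_idx = np.gradient(finite).argmin()

print(raw_idx, filtered_idx)  # the filtered index points at the steepest finite slope
```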