From 832f4e12de8f36adddfda04eda86dce7fea95491 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sun, 29 Mar 2020 11:47:13 +0100 Subject: [PATCH 1/6] Add warning when using default optimizer --- pytorch_lightning/core/lightning.py | 2 -- .../trainer/distrib_data_parallel.py | 2 +- pytorch_lightning/trainer/distrib_parts.py | 6 ++-- pytorch_lightning/trainer/trainer.py | 11 ++++-- tests/base/utils.py | 2 +- tests/models/test_gpu.py | 34 +++++++++++++------ 6 files changed, 38 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 4fe23ef7fcb7d..100f1bc789088 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -10,7 +10,6 @@ import torch.distributed as torch_distrib from torch import Tensor from torch.nn.parallel import DistributedDataParallel -from torch.optim import Adam from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -983,7 +982,6 @@ def configure_optimizers(self): } """ - return Adam(self.parameters(), lr=1e-3) def optimizer_step( self, diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index b90d089f2a80d..7fadf4570c717 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -305,7 +305,7 @@ def ddp_train(self, gpu_idx, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers = self.init_optimizers(model) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 3eee333065833..90fe9e29eec5a 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -459,7 +459,7 @@ def single_gpu_train(self, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers = self.init_optimizers(model) if self.use_amp: # An example @@ -485,7 +485,7 @@ def tpu_train(self, tpu_core_idx, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers = self.init_optimizers(model) # init 16 bit for TPU if self.precision == 16: @@ -504,7 +504,7 @@ def dp_train(self, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers = self.init_optimizers(model) model.cuda(self.root_gpu) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8a4d4dc6cd3e0..c36985655b54e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -10,6 +10,7 @@ import torch.distributed as torch_distrib import torch.multiprocessing as mp from torch import optim +from torch.optim import Adam from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -714,7 +715,7 @@ def fit( # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers = self.init_optimizers(model) 
self.run_pretrain_routine(model) @@ -760,8 +761,14 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da def init_optimizers( self, - optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]] + model: LightningModule ) -> Tuple[List, List]: + optimizers = model.configure_optimizers() + + if optimizers is None: + warnings.warn('`LightningModule.configure_optimizers` is not overriden or returned `None`,' + 'defaulting to Adam optimizer with `lr=1e-3`', UserWarning) + optimizers = Adam(model.parameters(), lr=1e-3) # single output, single optimizer if isinstance(optimizers, Optimizer): diff --git a/tests/base/utils.py b/tests/base/utils.py index c6b8e3ceaf67a..b4e89b040d559 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -82,7 +82,7 @@ def run_model_test(trainer_options, model, on_gpu=True): if trainer.use_ddp or trainer.use_ddp2: # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model - trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers()) + trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model) # test HPC loading / saving trainer.hpc_save(save_dir, logger) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 9c684ca6bfbf0..923b97fbeb427 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -96,32 +96,46 @@ def test_optimizer_return_options(): # single optimizer opt_a = torch.optim.Adam(model.parameters(), lr=0.002) opt_b = torch.optim.SGD(model.parameters(), lr=0.002) - optim, lr_sched = trainer.init_optimizers(opt_a) + + model.configure_optimizers = lambda: opt_a + optim, lr_sched = trainer.init_optimizers(model) assert len(optim) == 1 and len(lr_sched) == 0 # opt tuple - opts = (opt_a, opt_b) - optim, lr_sched = trainer.init_optimizers(opts) - assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1] + model.configure_optimizers = lambda: (opt_a, opt_b) + optim, lr_sched = trainer.init_optimizers(model) + assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b assert len(lr_sched) == 0 # opt list - opts = [opt_a, opt_b] - optim, lr_sched = trainer.init_optimizers(opts) - assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1] + model.configure_optimizers = lambda: [opt_a, opt_b] + optim, lr_sched = trainer.init_optimizers(model) + assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b assert len(lr_sched) == 0 # opt tuple of lists scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10) - opts = ([opt_a], [scheduler]) - optim, lr_sched = trainer.init_optimizers(opts) + model.configure_optimizers = lambda: ([opt_a], [scheduler]) + optim, lr_sched = trainer.init_optimizers(model) assert len(optim) == 1 and len(lr_sched) == 1 - assert optim[0] == opts[0][0] and \ + assert optim[0] == opt_a and \ lr_sched[0] == dict(scheduler=scheduler, interval='epoch', frequency=1, reduce_on_plateau=False, monitor='val_loss') +def test_default_optimizer_warning(): + tutils.reset_seed() + + trainer = Trainer() + model, hparams = tutils.get_default_model() + + model.configure_optimizers = lambda: None + + with pytest.warns(UserWarning, match='Adam optimizer with `lr=1e-3`'): + _, __ = trainer.init_optimizers(model) + + def test_cpu_slurm_save_load(tmpdir): """Verify model save/load/checkpoint on CPU.""" tutils.reset_seed() From 9547557ed746fc102ac273c669a30ff47e888552 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: 
Sun, 29 Mar 2020 11:50:59 +0100 Subject: [PATCH 2/6] Refactor optimizer tests to test_optimizers --- tests/models/test_gpu.py | 49 ------------------------------ tests/trainer/test_optimizers.py | 52 ++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 923b97fbeb427..6d6d344eb70ce 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -87,55 +87,6 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): assert result == 1, "DDP doesn't work with dataloaders passed to fit()." -def test_optimizer_return_options(): - tutils.reset_seed() - - trainer = Trainer() - model, hparams = tutils.get_default_model() - - # single optimizer - opt_a = torch.optim.Adam(model.parameters(), lr=0.002) - opt_b = torch.optim.SGD(model.parameters(), lr=0.002) - - model.configure_optimizers = lambda: opt_a - optim, lr_sched = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 0 - - # opt tuple - model.configure_optimizers = lambda: (opt_a, opt_b) - optim, lr_sched = trainer.init_optimizers(model) - assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b - assert len(lr_sched) == 0 - - # opt list - model.configure_optimizers = lambda: [opt_a, opt_b] - optim, lr_sched = trainer.init_optimizers(model) - assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b - assert len(lr_sched) == 0 - - # opt tuple of lists - scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10) - model.configure_optimizers = lambda: ([opt_a], [scheduler]) - optim, lr_sched = trainer.init_optimizers(model) - assert len(optim) == 1 and len(lr_sched) == 1 - assert optim[0] == opt_a and \ - lr_sched[0] == dict(scheduler=scheduler, interval='epoch', - frequency=1, reduce_on_plateau=False, - monitor='val_loss') - - -def test_default_optimizer_warning(): - tutils.reset_seed() - - trainer = Trainer() - model, hparams = tutils.get_default_model() - - model.configure_optimizers = lambda: None - - with pytest.warns(UserWarning, match='Adam optimizer with `lr=1e-3`'): - _, __ = trainer.init_optimizers(model) - - def test_cpu_slurm_save_load(tmpdir): """Verify model save/load/checkpoint on CPU.""" tutils.reset_seed() diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 4de3580eba22f..4f87a1bcc0460 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -1,3 +1,6 @@ +import pytest +import torch + import tests.base.utils as tutils from pytorch_lightning import Trainer @@ -173,3 +176,52 @@ class CurrentTestModel( dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss', interval='epoch', frequency=1, reduce_on_plateau=True), \ 'lr schduler was not correctly converted to dict' + + +def test_optimizer_return_options(): + tutils.reset_seed() + + trainer = Trainer() + model, hparams = tutils.get_default_model() + + # single optimizer + opt_a = torch.optim.Adam(model.parameters(), lr=0.002) + opt_b = torch.optim.SGD(model.parameters(), lr=0.002) + + model.configure_optimizers = lambda: opt_a + optim, lr_sched = trainer.init_optimizers(model) + assert len(optim) == 1 and len(lr_sched) == 0 + + # opt tuple + model.configure_optimizers = lambda: (opt_a, opt_b) + optim, lr_sched = trainer.init_optimizers(model) + assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b + assert len(lr_sched) == 0 + + # opt list + model.configure_optimizers = lambda: [opt_a, opt_b] + optim, lr_sched = trainer.init_optimizers(model) + 
assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b + assert len(lr_sched) == 0 + + # opt tuple of lists + scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10) + model.configure_optimizers = lambda: ([opt_a], [scheduler]) + optim, lr_sched = trainer.init_optimizers(model) + assert len(optim) == 1 and len(lr_sched) == 1 + assert optim[0] == opt_a and \ + lr_sched[0] == dict(scheduler=scheduler, interval='epoch', + frequency=1, reduce_on_plateau=False, + monitor='val_loss') + + +def test_default_optimizer_warning(): + tutils.reset_seed() + + trainer = Trainer() + model, hparams = tutils.get_default_model() + + model.configure_optimizers = lambda: None + + with pytest.warns(UserWarning, match='Adam optimizer with `lr=1e-3`'): + _, __ = trainer.init_optimizers(model) From ceff1d935a4aee103d3fc3e857a8fd1341d53379 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sun, 29 Mar 2020 12:49:33 +0100 Subject: [PATCH 3/6] Remove default optimizer, add option to use no optimizer --- pytorch_lightning/core/lightning.py | 11 +-- pytorch_lightning/trainer/optimizers.py | 102 ++++++++++++++++++++++++ pytorch_lightning/trainer/trainer.py | 67 +--------------- tests/base/__init__.py | 3 +- tests/base/mixins.py | 5 ++ tests/trainer/test_optimizers.py | 35 +++++++- 6 files changed, 149 insertions(+), 74 deletions(-) create mode 100644 pytorch_lightning/trainer/optimizers.py diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 100f1bc789088..cdd093d393b42 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -904,24 +904,25 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers - def configure_optimizers(self) -> Union[ + def configure_optimizers(self) -> Optional[Union[ Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], List] - ]: + ]]: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. - If you don't define this method Lightning will automatically use Adam(lr=1e-3) + If you don't define this method Lightning will run **without any optimizer**. - Return: any of these 3 options: + Return: any of these 4 options: - Single optimizer - List or Tuple - List of optimizers - Two lists - The first list has multiple optimizers, the second a list of LR schedulers + - None - Fit will run without any optimizer Examples: .. 
code-block:: python - # most cases (default if not defined) + # most cases def configure_optimizers(self): opt = Adam(self.parameters(), lr=1e-3) return opt diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py new file mode 100644 index 0000000000000..045ae3acb6757 --- /dev/null +++ b/pytorch_lightning/trainer/optimizers.py @@ -0,0 +1,102 @@ +import warnings +from abc import ABC +from typing import List, Tuple + +import torch +from torch import optim +from torch.optim.optimizer import Optimizer + +from pytorch_lightning.core.lightning import LightningModule + + +class TrainerOptimizersMixin(ABC): + + def init_optimizers( + self, + model: LightningModule + ) -> Tuple[List, List]: + optimizers = model.configure_optimizers() + + if optimizers is None: + warnings.warn('`LightningModule.configure_optimizers` is not overriden or returned `None`,' + 'this fit will run with no optimizer', UserWarning) + optimizers = _MockOptimizer() + + # single output, single optimizer + if isinstance(optimizers, Optimizer): + return [optimizers], [] + + # two lists, optimizer + lr schedulers + elif len(optimizers) == 2 and isinstance(optimizers[0], list): + optimizers, lr_schedulers = optimizers + lr_schedulers = self.configure_schedulers(lr_schedulers) + return optimizers, lr_schedulers + + # single list or tuple, multiple optimizer + elif isinstance(optimizers, (list, tuple)): + return optimizers, [] + + # unknown configuration + else: + raise ValueError('Unknown configuration for model optimizers. Output' + 'from model.configure_optimizers() should either be:' + '* single output, single torch.optim.Optimizer' + '* single output, list of torch.optim.Optimizer' + '* two outputs, first being a list of torch.optim.Optimizer', + 'second being a list of torch.optim.lr_scheduler') + + def configure_schedulers(self, schedulers: list): + # Convert each scheduler into dict sturcture with relevant information + lr_schedulers = [] + default_config = {'interval': 'epoch', # default every epoch + 'frequency': 1, # default every epoch/batch + 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler + 'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau + for scheduler in schedulers: + if isinstance(scheduler, dict): + if 'scheduler' not in scheduler: + raise ValueError(f'Lr scheduler should have key `scheduler`', + ' with item being a lr scheduler') + scheduler['reduce_on_plateau'] = isinstance( + scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) + + lr_schedulers.append({**default_config, **scheduler}) + + elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): + lr_schedulers.append({**default_config, 'scheduler': scheduler, + 'reduce_on_plateau': True}) + + elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): + lr_schedulers.append({**default_config, 'scheduler': scheduler}) + else: + raise ValueError(f'Input {scheduler} to lr schedulers ' + 'is a invalid input.') + return lr_schedulers + + +class _MockOptimizer(Optimizer): + """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` + is returned from `configure_optimizers`. 
+ """ + + def __init__(self): + super().__init__([torch.zeros(1)], {}) + + def add_param_group(self, param_group): + pass # Do Nothing + + def load_state_dict(self, state_dict): + pass # Do Nothing + + def state_dict(self): + return {} # Return Empty + + def step(self, closure=None): + if closure is not None: + closure() + + def zero_grad(self): + pass # Do Nothing + + def __repr__(self): + return 'No Optimizer' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c36985655b54e..fbf149c01ece8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -9,9 +9,6 @@ import torch import torch.distributed as torch_distrib import torch.multiprocessing as mp -from torch import optim -from torch.optim import Adam -from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -31,6 +28,7 @@ from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin +from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin @@ -55,6 +53,7 @@ class Trainer( TrainerIOMixin, + TrainerOptimizersMixin, TrainerDPMixin, TrainerDDPMixin, TrainerLoggingMixin, @@ -759,68 +758,6 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da model.test_dataloader = _PatchDataLoader(test_dataloaders) - def init_optimizers( - self, - model: LightningModule - ) -> Tuple[List, List]: - optimizers = model.configure_optimizers() - - if optimizers is None: - warnings.warn('`LightningModule.configure_optimizers` is not overriden or returned `None`,' - 'defaulting to Adam optimizer with `lr=1e-3`', UserWarning) - optimizers = Adam(model.parameters(), lr=1e-3) - - # single output, single optimizer - if isinstance(optimizers, Optimizer): - return [optimizers], [] - - # two lists, optimizer + lr schedulers - elif len(optimizers) == 2 and isinstance(optimizers[0], list): - optimizers, lr_schedulers = optimizers - lr_schedulers = self.configure_schedulers(lr_schedulers) - return optimizers, lr_schedulers - - # single list or tuple, multiple optimizer - elif isinstance(optimizers, (list, tuple)): - return optimizers, [] - - # unknown configuration - else: - raise ValueError('Unknown configuration for model optimizers. 
Output' - 'from model.configure_optimizers() should either be:' - '* single output, single torch.optim.Optimizer' - '* single output, list of torch.optim.Optimizer' - '* two outputs, first being a list of torch.optim.Optimizer', - 'second being a list of torch.optim.lr_scheduler') - - def configure_schedulers(self, schedulers: list): - # Convert each scheduler into dict sturcture with relevant information - lr_schedulers = [] - default_config = {'interval': 'epoch', # default every epoch - 'frequency': 1, # default every epoch/batch - 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler - 'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau - for scheduler in schedulers: - if isinstance(scheduler, dict): - if 'scheduler' not in scheduler: - raise ValueError(f'Lr scheduler should have key `scheduler`', - ' with item being a lr scheduler') - scheduler['reduce_on_plateau'] = isinstance( - scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) - - lr_schedulers.append({**default_config, **scheduler}) - - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): - lr_schedulers.append({**default_config, 'scheduler': scheduler, - 'reduce_on_plateau': True}) - - elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): - lr_schedulers.append({**default_config, 'scheduler': scheduler}) - else: - raise ValueError(f'Input {scheduler} to lr schedulers ' - 'is a invalid input.') - return lr_schedulers - def run_pretrain_routine(self, model: LightningModule): """Sanity check a few things before starting actual training. diff --git a/tests/base/__init__.py b/tests/base/__init__.py index 1e68469871d25..dc6f1c4f3b2d4 100644 --- a/tests/base/__init__.py +++ b/tests/base/__init__.py @@ -25,7 +25,8 @@ LightTestOptimizerWithSchedulingMixin, LightTestMultipleOptimizersWithSchedulingMixin, LightTestOptimizersWithMixedSchedulingMixin, - LightTestReduceLROnPlateauMixin + LightTestReduceLROnPlateauMixin, + LightTestNoneOptimizerMixin ) diff --git a/tests/base/mixins.py b/tests/base/mixins.py index 1a05049f44f5f..67caac1896f7a 100644 --- a/tests/base/mixins.py +++ b/tests/base/mixins.py @@ -688,6 +688,11 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] +class LightTestNoneOptimizerMixin: + def configure_optimizers(self): + return None + + def _get_output_metric(output, name): if isinstance(output, dict): val = output[name] diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 4f87a1bcc0460..42c81707b65e1 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -12,7 +12,8 @@ LightTestOptimizerWithSchedulingMixin, LightTestMultipleOptimizersWithSchedulingMixin, LightTestOptimizersWithMixedSchedulingMixin, - LightTestReduceLROnPlateauMixin + LightTestReduceLROnPlateauMixin, + LightTestNoneOptimizerMixin ) @@ -215,7 +216,7 @@ def test_optimizer_return_options(): monitor='val_loss') -def test_default_optimizer_warning(): +def test_none_optimizer_warning(): tutils.reset_seed() trainer = Trainer() @@ -223,5 +224,33 @@ def test_default_optimizer_warning(): model.configure_optimizers = lambda: None - with pytest.warns(UserWarning, match='Adam optimizer with `lr=1e-3`'): + with pytest.warns(UserWarning, match='will run with no optimizer'): _, __ = trainer.init_optimizers(model) + + +def test_none_optimizer(tmpdir): + tutils.reset_seed() + + class CurrentTestModel( + LightTestNoneOptimizerMixin, + LightTrainDataloader, + TestModelBase): + pass + + hparams = tutils.get_default_hparams() 
+ model = CurrentTestModel(hparams) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # verify training completed + assert result == 1 From 962cf8e894eca83554ece1b6fe6636772c24d4dd Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sun, 29 Mar 2020 12:57:41 +0100 Subject: [PATCH 4/6] Update CHANGELOG.md --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d4a848bd4d822..0cece68140909 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,10 +16,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211)) +- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279)) ### Changed -- +- Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279)) ### Deprecated From 993c6484f8d27869d9bfca84dbf7445c30517638 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Sun, 29 Mar 2020 13:43:07 +0100 Subject: [PATCH 5/6] Update pytorch_lightning/trainer/optimizers.py Co-Authored-By: Jirka Borovec --- pytorch_lightning/trainer/optimizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 045ae3acb6757..737e76ee079c7 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -19,7 +19,7 @@ def init_optimizers( if optimizers is None: warnings.warn('`LightningModule.configure_optimizers` is not overriden or returned `None`,' - 'this fit will run with no optimizer', UserWarning) + ' this fit will run with no optimizer', UserWarning) optimizers = _MockOptimizer() # single output, single optimizer From d188cd607455a8e0962fec90301cb53418a2fc3f Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 2 Apr 2020 13:37:58 +0100 Subject: [PATCH 6/6] Fix style --- pytorch_lightning/trainer/optimizers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 854f71f22d49f..ae074e907377c 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -27,7 +27,8 @@ def init_optimizers( return [optim_conf], [], [] # two lists, optimizer + lr schedulers - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ + and isinstance(optim_conf[0], list): optimizers, lr_schedulers = optim_conf lr_schedulers = self.configure_schedulers(lr_schedulers) return optimizers, lr_schedulers, [] @@ -44,9 +45,13 @@ def init_optimizers( elif isinstance(optim_conf, (list, tuple)) and 
isinstance(optim_conf[0], dict): optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] # take only lr wif exists and ot they are defined - not None - lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")] + lr_schedulers = [ + opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler") + ] # take only freq wif exists and ot they are defined - not None - optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")] + optimizer_frequencies = [ + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") + ] # clean scheduler list if lr_schedulers:
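
Taken together, the series above removes the silent `Adam(lr=1e-3)` fallback: `configure_optimizers` may now return `None`, in which case `init_optimizers` warns and substitutes the no-op `_MockOptimizer`, so `fit()` still runs but no weights are updated. Below is a minimal sketch of the resulting user-facing behaviour; it is not part of the patches — the example module, dataset, and Trainer arguments are illustrative and assume the pytorch-lightning API of this era, while the `None` return value, the warning text, and the `_MockOptimizer` repr come from the series itself.

import pytest
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class NoOptimizerModel(LightningModule):
    """Hypothetical module: trains end-to-end while opting out of optimization."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': nn.functional.cross_entropy(self(x), y)}

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(data, batch_size=8)

    def configure_optimizers(self):
        return None  # previously this silently fell back to Adam(lr=1e-3)


model = NoOptimizerModel()
trainer = Trainer(max_epochs=1)

# init_optimizers (called inside fit) warns and swaps in the no-op _MockOptimizer,
# so backward() still runs but optimizer.step() and zero_grad() do nothing.
with pytest.warns(UserWarning, match='will run with no optimizer'):
    trainer.fit(model)

print(trainer.optimizers[0])  # -> "No Optimizer" (the _MockOptimizer repr)

Making the fallback explicit rather than silent is the design point of the change: a user who forgets to override `configure_optimizers` now gets a visible warning instead of an unrequested Adam optimizer quietly training the model.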