diff --git a/CHANGELOG.md b/CHANGELOG.md index 3add6a9f77ea0..4d8d2a96ca10d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130)) - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211)) +- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279)) + +### Changed + +- Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279)) - Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269)) - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283)) - Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 99b01865d9a60..f23ab04523766 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -4,13 +4,12 @@ import warnings from abc import ABC, abstractmethod from argparse import Namespace -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence import torch import torch.distributed as torch_distrib from torch import Tensor from torch.nn.parallel import DistributedDataParallel -from torch.optim import Adam from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader @@ -905,21 +904,20 @@ def configure_apex(self, amp, model, optimizers, amp_level): return model, optimizers - def configure_optimizers(self) -> Union[ - Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], List] - ]: + def configure_optimizers(self) -> Optional[Union[ + Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List] + ]]: r""" Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. - If you don't define this method Lightning will automatically use Adam(lr=1e-3) - - Return: any of these 5 options: + Return: any of these 6 options: - Single optimizer. - List or Tuple - List of optimizers. - Two lists - The first list has multiple optimizers, the second a list of LR schedulers. - Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key. - Tuple of dictionaries as described, with an optional `frequency` key. + - None - Fit will run without any optimizer. Note: The `frequency` value is an int corresponding to the number of sequential batches @@ -932,7 +930,7 @@ def configure_optimizers(self) -> Union[ Examples: .. code-block:: python - # most cases (default if not defined) + # most cases def configure_optimizers(self): opt = Adam(self.parameters(), lr=1e-3) return opt @@ -1005,7 +1003,6 @@ def configure_optimizers(self): } """ - return Adam(self.parameters(), lr=1e-3) def optimizer_step( self, diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index e592308a9eadf..dc3625e356865 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -304,8 +304,7 @@ def ddp_train(self, gpu_idx, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ - self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) # MODEL # copy model to each gpu diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index f48b919bc2675..43e4df038d79c 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -462,8 +462,7 @@ def single_gpu_train(self, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ - self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) if self.use_amp: # An example @@ -489,8 +488,7 @@ def tpu_train(self, tpu_core_idx, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ - self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) # init 16 bit for TPU if self.precision == 16: @@ -508,8 +506,7 @@ def dp_train(self, model): # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ - self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) model.cuda(self.root_gpu) diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py new file mode 100644 index 0000000000000..ae074e907377c --- /dev/null +++ b/pytorch_lightning/trainer/optimizers.py @@ -0,0 +1,135 @@ +import warnings +from abc import ABC +from typing import List, Tuple + +import torch +from torch import optim +from torch.optim.optimizer import Optimizer + +from pytorch_lightning.core.lightning import LightningModule + + +class TrainerOptimizersMixin(ABC): + + def init_optimizers( + self, + model: LightningModule + ) -> Tuple[List, List, List]: + optim_conf = model.configure_optimizers() + + if optim_conf is None: + warnings.warn('`LightningModule.configure_optimizers` returned `None`, ' + 'this fit will run with no optimizer', UserWarning) + optim_conf = _MockOptimizer() + + # single output, single optimizer + if isinstance(optim_conf, Optimizer): + return [optim_conf], [], [] + + # two lists, optimizer + lr schedulers + elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \ + and isinstance(optim_conf[0], list): + optimizers, lr_schedulers = optim_conf + lr_schedulers = self.configure_schedulers(lr_schedulers) + return optimizers, lr_schedulers, [] + + # single dictionary + elif isinstance(optim_conf, dict): + optimizer = optim_conf["optimizer"] + lr_scheduler = optim_conf.get("lr_scheduler", []) + if lr_scheduler: + lr_schedulers = self.configure_schedulers([lr_scheduler]) + return [optimizer], lr_schedulers, [] + + # multiple dictionaries + elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): + optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] + # take only lr wif exists and ot they are defined - not None + lr_schedulers = [ + opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler") + ] + # take only freq wif exists and ot they are defined - not None + optimizer_frequencies = [ + opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency") + ] + + # clean scheduler list + if lr_schedulers: + lr_schedulers = self.configure_schedulers(lr_schedulers) + # assert that if frequencies are present, they are given for all optimizers + if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): + raise ValueError("A frequency must be given to each optimizer.") + return optimizers, lr_schedulers, optimizer_frequencies + + # single list or tuple, multiple optimizer + elif isinstance(optim_conf, (list, tuple)): + return list(optim_conf), [], [] + + # unknown configuration + else: + raise ValueError( + 'Unknown configuration for model optimizers.' + ' Output from `model.configure_optimizers()` should either be:' + ' * single output, single `torch.optim.Optimizer`' + ' * single output, list of `torch.optim.Optimizer`' + ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)' + ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)' + ' * two outputs, first being a list of `torch.optim.Optimizer` second being' + ' a list of `torch.optim.lr_scheduler`' + ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)') + + def configure_schedulers(self, schedulers: list): + # Convert each scheduler into dict sturcture with relevant information + lr_schedulers = [] + default_config = {'interval': 'epoch', # default every epoch + 'frequency': 1, # default every epoch/batch + 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler + 'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau + for scheduler in schedulers: + if isinstance(scheduler, dict): + if 'scheduler' not in scheduler: + raise ValueError(f'Lr scheduler should have key `scheduler`', + ' with item being a lr scheduler') + scheduler['reduce_on_plateau'] = isinstance( + scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) + + lr_schedulers.append({**default_config, **scheduler}) + + elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): + lr_schedulers.append({**default_config, 'scheduler': scheduler, + 'reduce_on_plateau': True}) + + elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): + lr_schedulers.append({**default_config, 'scheduler': scheduler}) + else: + raise ValueError(f'Input {scheduler} to lr schedulers ' + 'is a invalid input.') + return lr_schedulers + + +class _MockOptimizer(Optimizer): + """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` + is returned from `configure_optimizers`. + """ + + def __init__(self): + super().__init__([torch.zeros(1)], {}) + + def add_param_group(self, param_group): + pass # Do Nothing + + def load_state_dict(self, state_dict): + pass # Do Nothing + + def state_dict(self): + return {} # Return Empty + + def step(self, closure=None): + if closure is not None: + closure() + + def zero_grad(self): + pass # Do Nothing + + def __repr__(self): + return 'No Optimizer' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7bd724341f0ef..7dbcfac467b9e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1,16 +1,14 @@ +import distutils import inspect import os import sys import warnings from argparse import ArgumentParser -from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence -import distutils +from typing import Union, Optional, List, Dict, Tuple, Iterable, Any import torch import torch.distributed as torch_distrib import torch.multiprocessing as mp -from torch import optim -from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from tqdm.auto import tqdm @@ -29,11 +27,12 @@ from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin +from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin +from pytorch_lightning.trainer.supporters import TensorRunningMean from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.trainer.supporters import TensorRunningMean try: from apex import amp @@ -54,6 +53,7 @@ class Trainer( TrainerIOMixin, + TrainerOptimizersMixin, TrainerDPMixin, TrainerDDPMixin, TrainerLoggingMixin, @@ -713,8 +713,7 @@ def fit( # CHOOSE OPTIMIZER # allow for lr schedulers as well - self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \ - self.init_optimizers(model.configure_optimizers()) + self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model) self.run_pretrain_routine(model) @@ -758,90 +757,6 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da model.test_dataloader = _PatchDataLoader(test_dataloaders) - def init_optimizers( - self, - optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]] - ) -> Tuple[List, List, List]: - - # single output, single optimizer - if isinstance(optim_conf, Optimizer): - return [optim_conf], [], [] - - # two lists, optimizer + lr schedulers - elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list): - optimizers, lr_schedulers = optim_conf - lr_schedulers = self.configure_schedulers(lr_schedulers) - return optimizers, lr_schedulers, [] - - # single dictionary - elif isinstance(optim_conf, dict): - optimizer = optim_conf["optimizer"] - lr_scheduler = optim_conf.get("lr_scheduler", []) - if lr_scheduler: - lr_schedulers = self.configure_schedulers([lr_scheduler]) - return [optimizer], lr_schedulers, [] - - # multiple dictionaries - elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict): - optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf] - # take only lr wif exists and ot they are defined - not None - lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")] - # take only freq wif exists and ot they are defined - not None - optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")] - - # clean scheduler list - if lr_schedulers: - lr_schedulers = self.configure_schedulers(lr_schedulers) - # assert that if frequencies are present, they are given for all optimizers - if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers): - raise ValueError("A frequency must be given to each optimizer.") - return optimizers, lr_schedulers, optimizer_frequencies - - # single list or tuple, multiple optimizer - elif isinstance(optim_conf, (list, tuple)): - return list(optim_conf), [], [] - - # unknown configuration - else: - raise ValueError( - 'Unknown configuration for model optimizers.' - ' Output from `model.configure_optimizers()` should either be:' - ' * single output, single `torch.optim.Optimizer`' - ' * single output, list of `torch.optim.Optimizer`' - ' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)' - ' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)' - ' * two outputs, first being a list of `torch.optim.Optimizer` second being' - ' a list of `torch.optim.lr_scheduler`' - ' * multiple outputs, dictionaries as described with an optional `frequency` key (int)') - - def configure_schedulers(self, schedulers: list): - # Convert each scheduler into dict sturcture with relevant information - lr_schedulers = [] - default_config = {'interval': 'epoch', # default every epoch - 'frequency': 1, # default every epoch/batch - 'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler - 'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau - for scheduler in schedulers: - if isinstance(scheduler, dict): - if 'scheduler' not in scheduler: - raise ValueError(f'Lr scheduler should have key `scheduler`', - ' with item being a lr scheduler') - scheduler['reduce_on_plateau'] = isinstance( - scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau) - - lr_schedulers.append({**default_config, **scheduler}) - - elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau): - lr_schedulers.append({**default_config, 'scheduler': scheduler, - 'reduce_on_plateau': True}) - - elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): - lr_schedulers.append({**default_config, 'scheduler': scheduler}) - else: - raise ValueError(f'Input {scheduler} to lr schedulers ' - 'is a invalid input.') - return lr_schedulers - def run_pretrain_routine(self, model: LightningModule): """Sanity check a few things before starting actual training. diff --git a/tests/base/__init__.py b/tests/base/__init__.py index 4d266c45880a7..a5728c0f77d85 100644 --- a/tests/base/__init__.py +++ b/tests/base/__init__.py @@ -26,6 +26,7 @@ LightTestMultipleOptimizersWithSchedulingMixin, LightTestOptimizersWithMixedSchedulingMixin, LightTestReduceLROnPlateauMixin, + LightTestNoneOptimizerMixin, LightZeroLenDataloader ) diff --git a/tests/base/mixins.py b/tests/base/mixins.py index 38eb2fe630f2b..02a9c16cfaa6b 100644 --- a/tests/base/mixins.py +++ b/tests/base/mixins.py @@ -695,6 +695,11 @@ def configure_optimizers(self): return [optimizer], [lr_scheduler] +class LightTestNoneOptimizerMixin: + def configure_optimizers(self): + return None + + def _get_output_metric(output, name): if isinstance(output, dict): val = output[name] diff --git a/tests/base/utils.py b/tests/base/utils.py index f1d26f0b83964..9eb42a5c5ee6b 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -83,7 +83,7 @@ def run_model_test(trainer_options, model, on_gpu=True): # on hpc this would work fine... but need to hack it for the purpose of the test trainer.model = pretrained_model trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \ - trainer.init_optimizers(pretrained_model.configure_optimizers()) + trainer.init_optimizers(pretrained_model) # test HPC loading / saving trainer.hpc_save(save_dir, logger) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index f783a2d32a519..0b821c74a6fd5 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -84,63 +84,6 @@ def test_ddp_all_dataloaders_passed_to_fit(tmpdir): assert result == 1, "DDP doesn't work with dataloaders passed to fit()." -def test_optimizer_return_options(): - tutils.reset_seed() - - trainer = Trainer() - model, hparams = tutils.get_default_model() - - # single optimizer - opt_a = torch.optim.Adam(model.parameters(), lr=0.002) - opt_b = torch.optim.SGD(model.parameters(), lr=0.002) - scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10) - scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10) - - # single optimizer - optim, lr_sched, freq = trainer.init_optimizers(opt_a) - assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0 - - # opt tuple - opts = (opt_a, opt_b) - optim, lr_sched, freq = trainer.init_optimizers(opts) - assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1] - assert len(lr_sched) == 0 and len(freq) == 0 - - # opt list - opts = [opt_a, opt_b] - optim, lr_sched, freq = trainer.init_optimizers(opts) - assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1] - assert len(lr_sched) == 0 and len(freq) == 0 - - # opt tuple of 2 lists - opts = ([opt_a], [scheduler_a]) - optim, lr_sched, freq = trainer.init_optimizers(opts) - assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 - assert optim[0] == opts[0][0] - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False, monitor='val_loss') - - # opt single dictionary - opts = {"optimizer": opt_a, "lr_scheduler": scheduler_a} - optim, lr_sched, freq = trainer.init_optimizers(opts) - assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 - assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False, monitor='val_loss') - - # opt multiple dictionaries with frequencies - opts = ( - {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1}, - {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, - ) - optim, lr_sched, freq = trainer.init_optimizers(opts) - assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2 - assert optim[0] == opt_a - assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', - frequency=1, reduce_on_plateau=False, monitor='val_loss') - assert freq == [1, 5] - - def test_cpu_slurm_save_load(tmpdir): """Verify model save/load/checkpoint on CPU.""" tutils.reset_seed() diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 4de3580eba22f..213529a9e4d00 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -1,3 +1,6 @@ +import pytest +import torch + import tests.base.utils as tutils from pytorch_lightning import Trainer @@ -9,7 +12,8 @@ LightTestOptimizerWithSchedulingMixin, LightTestMultipleOptimizersWithSchedulingMixin, LightTestOptimizersWithMixedSchedulingMixin, - LightTestReduceLROnPlateauMixin + LightTestReduceLROnPlateauMixin, + LightTestNoneOptimizerMixin ) @@ -173,3 +177,101 @@ class CurrentTestModel( dict(scheduler=trainer.lr_schedulers[0]['scheduler'], monitor='val_loss', interval='epoch', frequency=1, reduce_on_plateau=True), \ 'lr schduler was not correctly converted to dict' + + +def test_optimizer_return_options(): + tutils.reset_seed() + + trainer = Trainer() + model, hparams = tutils.get_default_model() + + # single optimizer + opt_a = torch.optim.Adam(model.parameters(), lr=0.002) + opt_b = torch.optim.SGD(model.parameters(), lr=0.002) + scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10) + scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10) + + # single optimizer + model.configure_optimizers = lambda: opt_a + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0 + + # opt tuple + model.configure_optimizers = lambda: (opt_a, opt_b) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b + assert len(lr_sched) == 0 and len(freq) == 0 + + # opt list + model.configure_optimizers = lambda: [opt_a, opt_b] + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 2 and optim[0] == opt_a and optim[1] == opt_b + assert len(lr_sched) == 0 and len(freq) == 0 + + # opt tuple of 2 lists + model.configure_optimizers = lambda: ([opt_a], [scheduler_a]) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', + frequency=1, reduce_on_plateau=False, monitor='val_loss') + + # opt single dictionary + model.configure_optimizers = lambda: {"optimizer": opt_a, "lr_scheduler": scheduler_a} + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0 + assert optim[0] == opt_a + assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', + frequency=1, reduce_on_plateau=False, monitor='val_loss') + + # opt multiple dictionaries with frequencies + model.configure_optimizers = lambda: ( + {"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1}, + {"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5}, + ) + optim, lr_sched, freq = trainer.init_optimizers(model) + assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2 + assert optim[0] == opt_a + assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch', + frequency=1, reduce_on_plateau=False, monitor='val_loss') + assert freq == [1, 5] + + +def test_none_optimizer_warning(): + tutils.reset_seed() + + trainer = Trainer() + model, hparams = tutils.get_default_model() + + model.configure_optimizers = lambda: None + + with pytest.warns(UserWarning, match='will run with no optimizer'): + _, __, ___ = trainer.init_optimizers(model) + + +def test_none_optimizer(tmpdir): + tutils.reset_seed() + + class CurrentTestModel( + LightTestNoneOptimizerMixin, + LightTrainDataloader, + TestModelBase): + pass + + hparams = tutils.get_default_hparams() + model = CurrentTestModel(hparams) + + # logger file to get meta + trainer_options = dict( + default_save_path=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2 + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # verify training completed + assert result == 1