Skip to content

Commit

Permalink
Remove default optimizer, add None optimizer option (Lightning-AI#1279)
Browse files Browse the repository at this point in the history
* Add warning when using default optimizer

* Refactor optimizer tests to test_optimizers

* Remove default optimizer, add option to use no optimizer

* Update CHANGELOG.md

* Update pytorch_lightning/trainer/optimizers.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Fix style

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
2 people authored and akarnachev committed Apr 3, 2020
1 parent c0199e5 commit fc59fca
Show file tree
Hide file tree
Showing 11 changed files with 267 additions and 168 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for non-primitive types in `hparams` for `TensorboardLogger` ([#1130](https://github.com/PyTorchLightning/pytorch-lightning/pull/1130))
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added option to run without an optimizer by returning `None` from `configure_optimizers`. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))

### Changed

- Changed default behaviour of `configure_optimizers` to use no optimizer rather than Adam. ([#1279](https://github.com/PyTorchLightning/pytorch-lightning/pull/1279))
- Added support for optimizer frequencies through `LightningModule.configure_optimizers()` ([#1269](https://github.com/PyTorchLightning/pytorch-lightning/pull/1269))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
Expand Down
17 changes: 7 additions & 10 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence

import torch
import torch.distributed as torch_distrib
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

Expand Down Expand Up @@ -905,21 +904,20 @@ def configure_apex(self, amp, model, optimizers, amp_level):

return model, optimizers

def configure_optimizers(self) -> Union[
Optimizer, List[Optimizer], Tuple[Optimizer, ...], Tuple[List[Optimizer], List]
]:
def configure_optimizers(self) -> Optional[Union[
Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]
]]:
r"""
Choose what optimizers and learning-rate schedulers to use in your optimization.
Normally you'd need one. But in the case of GANs or similar you might have multiple.
If you don't define this method Lightning will automatically use Adam(lr=1e-3)
Return: any of these 5 options:
Return: any of these 6 options:
- Single optimizer.
- List or Tuple - List of optimizers.
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
- Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
- Tuple of dictionaries as described, with an optional `frequency` key.
- None - Fit will run without any optimizer.
Note:
The `frequency` value is an int corresponding to the number of sequential batches
Expand All @@ -932,7 +930,7 @@ def configure_optimizers(self) -> Union[
Examples:
.. code-block:: python
# most cases (default if not defined)
# most cases
def configure_optimizers(self):
opt = Adam(self.parameters(), lr=1e-3)
return opt
Expand Down Expand Up @@ -1005,7 +1003,6 @@ def configure_optimizers(self):
}
"""
return Adam(self.parameters(), lr=1e-3)

def optimizer_step(
self,
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,7 @@ def ddp_train(self, gpu_idx, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

# MODEL
# copy model to each gpu
Expand Down
9 changes: 3 additions & 6 deletions pytorch_lightning/trainer/distrib_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,8 +462,7 @@ def single_gpu_train(self, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

if self.use_amp:
# An example
Expand All @@ -489,8 +488,7 @@ def tpu_train(self, tpu_core_idx, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

# init 16 bit for TPU
if self.precision == 16:
Expand All @@ -508,8 +506,7 @@ def dp_train(self, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

model.cuda(self.root_gpu)

Expand Down
135 changes: 135 additions & 0 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import warnings
from abc import ABC
from typing import List, Tuple

import torch
from torch import optim
from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule


class TrainerOptimizersMixin(ABC):

def init_optimizers(
self,
model: LightningModule
) -> Tuple[List, List, List]:
optim_conf = model.configure_optimizers()

if optim_conf is None:
warnings.warn('`LightningModule.configure_optimizers` returned `None`, '
'this fit will run with no optimizer', UserWarning)
optim_conf = _MockOptimizer()

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []

# two lists, optimizer + lr schedulers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 \
and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers, []

# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
return [optimizer], lr_schedulers, []

# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
# take only lr wif exists and ot they are defined - not None
lr_schedulers = [
opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")
]
# take only freq wif exists and ot they are defined - not None
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")
]

# clean scheduler list
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies

# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []

# unknown configuration
else:
raise ValueError(
'Unknown configuration for model optimizers.'
' Output from `model.configure_optimizers()` should either be:'
' * single output, single `torch.optim.Optimizer`'
' * single output, list of `torch.optim.Optimizer`'
' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
' * two outputs, first being a list of `torch.optim.Optimizer` second being'
' a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')

def configure_schedulers(self, schedulers: list):
# Convert each scheduler into dict sturcture with relevant information
lr_schedulers = []
default_config = {'interval': 'epoch', # default every epoch
'frequency': 1, # default every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau
for scheduler in schedulers:
if isinstance(scheduler, dict):
if 'scheduler' not in scheduler:
raise ValueError(f'Lr scheduler should have key `scheduler`',
' with item being a lr scheduler')
scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

lr_schedulers.append({**default_config, **scheduler})

elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
lr_schedulers.append({**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True})

elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:
raise ValueError(f'Input {scheduler} to lr schedulers '
'is a invalid input.')
return lr_schedulers


class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
is returned from `configure_optimizers`.
"""

def __init__(self):
super().__init__([torch.zeros(1)], {})

def add_param_group(self, param_group):
pass # Do Nothing

def load_state_dict(self, state_dict):
pass # Do Nothing

def state_dict(self):
return {} # Return Empty

def step(self, closure=None):
if closure is not None:
closure()

def zero_grad(self):
pass # Do Nothing

def __repr__(self):
return 'No Optimizer'
97 changes: 6 additions & 91 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
import distutils
import inspect
import os
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence
import distutils
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any

import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp
from torch import optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

Expand All @@ -29,11 +27,12 @@
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.supporters import TensorRunningMean
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean

try:
from apex import amp
Expand All @@ -54,6 +53,7 @@

class Trainer(
TrainerIOMixin,
TrainerOptimizersMixin,
TrainerDPMixin,
TrainerDDPMixin,
TrainerLoggingMixin,
Expand Down Expand Up @@ -713,8 +713,7 @@ def fit(

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)

self.run_pretrain_routine(model)

Expand Down Expand Up @@ -758,90 +757,6 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da

model.test_dataloader = _PatchDataLoader(test_dataloaders)

def init_optimizers(
self,
optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
) -> Tuple[List, List, List]:

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []

# two lists, optimizer + lr schedulers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers, []

# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
return [optimizer], lr_schedulers, []

# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
# take only lr wif exists and ot they are defined - not None
lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")]
# take only freq wif exists and ot they are defined - not None
optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")]

# clean scheduler list
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies

# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []

# unknown configuration
else:
raise ValueError(
'Unknown configuration for model optimizers.'
' Output from `model.configure_optimizers()` should either be:'
' * single output, single `torch.optim.Optimizer`'
' * single output, list of `torch.optim.Optimizer`'
' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
' * two outputs, first being a list of `torch.optim.Optimizer` second being'
' a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')

def configure_schedulers(self, schedulers: list):
# Convert each scheduler into dict sturcture with relevant information
lr_schedulers = []
default_config = {'interval': 'epoch', # default every epoch
'frequency': 1, # default every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': 'val_loss'} # default value to monitor for ReduceLROnPlateau
for scheduler in schedulers:
if isinstance(scheduler, dict):
if 'scheduler' not in scheduler:
raise ValueError(f'Lr scheduler should have key `scheduler`',
' with item being a lr scheduler')
scheduler['reduce_on_plateau'] = isinstance(
scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

lr_schedulers.append({**default_config, **scheduler})

elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
lr_schedulers.append({**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True})

elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:
raise ValueError(f'Input {scheduler} to lr schedulers '
'is a invalid input.')
return lr_schedulers

def run_pretrain_routine(self, model: LightningModule):
"""Sanity check a few things before starting actual training.
Expand Down
1 change: 1 addition & 0 deletions tests/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
LightTestMultipleOptimizersWithSchedulingMixin,
LightTestOptimizersWithMixedSchedulingMixin,
LightTestReduceLROnPlateauMixin,
LightTestNoneOptimizerMixin,
LightZeroLenDataloader
)

Expand Down
Loading

0 comments on commit fc59fca

Please sign in to comment.