Optimizer Frequencies logic, and new configure_optimizers #1269

Merged (16 commits) on Mar 31, 2020
30 changes: 26 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -914,10 +914,20 @@ def configure_optimizers(self) -> Union[

If you don't define this method Lightning will automatically use Adam(lr=1e-3)

Return: any of these 3 options:
- Single optimizer
- List or Tuple - List of optimizers
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers
Return: any of these 5 options:
- Single optimizer.
- List or Tuple - List of optimizers.
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
- Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
- Tuple of dictionaries as described, with an optional `frequency` key.

Note:
The `frequency` value is an int corresponding to the number of sequential batches
optimized with the specific optimizer. It should be given to none or to all of the optimizers.
There is a difference between passing multiple optimizers in a list,
and passing multiple optimizers in dictionaries with a frequency of 1:
In the former case, all optimizers will operate on the given batch in each optimization step.
In the latter, only one optimizer will operate on the given batch at every step.

Examples:
.. code-block:: python
@@ -949,6 +959,18 @@ def configure_optimizers(self):
dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)

Review comment (Contributor): shouldn't this also have an example for a scheduler?

    {'optimizer': dis_opt, 'frequency': n_critic, 'lr_scheduler': Scheduler()}

Review comment (Contributor): @Borda @asafmanor? Also, amazing job :)

Note:

Some things to know:
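Building on the review comment above, a single dictionary may carry both an `lr_scheduler` and a `frequency` key. The following is an illustrative sketch, not part of the diff, written in the same shorthand as the docstring examples (`Adam`, `CosineAnnealing`, `self.model_gen`, and `self.model_disc` are placeholders taken from those examples):

# hedged sketch: combines the scheduler example and the frequency example above
def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    gen_sched = CosineAnnealing(gen_opt, T_max=10)  # stepped every epoch by default
    dis_sched = CosineAnnealing(dis_opt, T_max=10)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'lr_scheduler': dis_sched, 'frequency': n_critic},
        {'optimizer': gen_opt, 'lr_scheduler': gen_sched, 'frequency': 1},
    )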
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -305,7 +305,8 @@ def ddp_train(self, gpu_idx, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())

# MODEL
# copy model to each gpu
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -459,7 +459,8 @@ def single_gpu_train(self, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())

if self.use_amp:
# An example
@@ -485,7 +486,8 @@ def tpu_train(self, tpu_core_idx, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())

# init 16 bit for TPU
if self.precision == 16:
@@ -504,7 +506,8 @@ def dp_train(self, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())

model.cuda(self.root_gpu)

63 changes: 46 additions & 17 deletions pytorch_lightning/trainer/trainer.py
@@ -3,7 +3,7 @@
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence
import distutils

import torch
@@ -358,6 +358,7 @@ def __init__(
self.disable_validation = False
self.lr_schedulers = []
self.optimizers = None
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0
@@ -714,7 +715,8 @@ def fit(

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = \
self.init_optimizers(model.configure_optimizers())

self.run_pretrain_routine(model)

@@ -760,31 +762,57 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da

def init_optimizers(
self,
optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
) -> Tuple[List, List]:
optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
) -> Tuple[List, List, List]:

# single output, single optimizer
if isinstance(optimizers, Optimizer):
return [optimizers], []
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []

# two lists, optimizer + lr schedulers
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers
return optimizers, lr_schedulers, []

# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
return [optimizer], lr_schedulers, []

# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
# take the lr scheduler only if it exists and is defined (i.e. not None)
lr_schedulers = [opt_dict["lr_scheduler"] for opt_dict in optim_conf if opt_dict.get("lr_scheduler")]
# take the frequency only if it exists and is defined (i.e. not None)
optimizer_frequencies = [opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency")]

# clean scheduler list
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies

# single list or tuple, multiple optimizer
elif isinstance(optimizers, (list, tuple)):
return optimizers, []
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []

# unknown configuration
else:
raise ValueError('Unknown configuration for model optimizers. Output'
'from model.configure_optimizers() should either be:'
'* single output, single torch.optim.Optimizer'
'* single output, list of torch.optim.Optimizer'
'* two outputs, first being a list of torch.optim.Optimizer',
'second being a list of torch.optim.lr_scheduler')
raise ValueError(
'Unknown configuration for model optimizers. Output from `model.configure_optimizers()` should either be:'
' * single output, single `torch.optim.Optimizer`'
' * single output, list of `torch.optim.Optimizer`'
' * single output, a dictionary with `optimizer` key (`torch.optim.Optimizer`)'
' and an optional `lr_scheduler` key (`torch.optim.lr_scheduler`)'
' * two outputs, first being a list of `torch.optim.Optimizer` second being a list of `torch.optim.lr_scheduler`'
' * multiple outputs, dictionaries as described with an optional `frequency` key (int)')

def configure_schedulers(self, schedulers: list):
# Convert each scheduler into a dict structure with relevant information
@@ -971,6 +999,7 @@ class _PatchDataLoader(object):
dataloader: Dataloader object to return when called.

"""

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

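For reference, a sketch (not part of the diff) of the `(optimizers, lr_schedulers, optimizer_frequencies)` triples produced by the branches of `init_optimizers` above; `opt_a`, `opt_b`, and `sched_a` are placeholder objects:

import torch

model = torch.nn.Linear(4, 2)
opt_a = torch.optim.Adam(model.parameters(), lr=1e-3)
opt_b = torch.optim.SGD(model.parameters(), lr=1e-2)
sched_a = torch.optim.lr_scheduler.StepLR(opt_a, step_size=10)

# single optimizer                        -> ([opt_a], [], [])
# list or tuple of optimizers             -> ([opt_a, opt_b], [], [])
# two lists (optimizers, schedulers)      -> ([opt_a], [configured sched_a], [])
# single dictionary                       -> ([opt_a], [configured sched_a], [])
# tuple of dictionaries with frequencies  -> ([opt_a, opt_b], [], [5, 1])
configurations = [
    opt_a,
    [opt_a, opt_b],
    ([opt_a], [sched_a]),
    {'optimizer': opt_a, 'lr_scheduler': sched_a},
    (
        {'optimizer': opt_a, 'frequency': 5},
        {'optimizer': opt_b, 'frequency': 1},
    ),
]

Here "configured sched_a" refers to the dictionary produced by `configure_schedulers` (scheduler, interval, frequency, reduce_on_plateau, monitor), as asserted in the tests further below.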
17 changes: 15 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -196,6 +196,7 @@ class TrainerTrainLoopMixin(ABC):
total_batches: int
truncated_bptt_steps: ...
optimizers: ...
optimizer_frequencies: ...
accumulate_grad_batches: int
use_amp: bool
track_grad_norm: ...
@@ -515,8 +516,7 @@ def run_training_batch(self, batch, batch_idx):
for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx

# call training_step once per optimizer
for opt_idx, optimizer in enumerate(self.optimizers):
for opt_idx, optimizer in self._get_optimizers_iterable():
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if len(self.optimizers) > 1:
@@ -617,6 +617,19 @@ def optimizer_closure():
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})

return 0, grad_norm_dic, all_log_metrics

def _get_optimizers_iterable(self):
if not self.optimizer_frequencies:
# call training_step once per optimizer
return list(enumerate(self.optimizers))

optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
optimizers_loop_length = optimizer_freq_cumsum[-1]
current_place_in_loop = self.total_batch_idx % optimizers_loop_length

# find the optimizer index by looking for the first {item > current_place} in the cumsum list
opt_idx = np.argmax(optimizer_freq_cumsum > current_place_in_loop)
return [(opt_idx, self.optimizers[opt_idx])]

def run_training_teardown(self):
self.main_progress_bar.close()
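A short worked example (not part of the diff) of the cumulative-sum selection in `_get_optimizers_iterable`, using the frequencies 5 and 1 from the WGAN example:

import numpy as np

frequencies = [5, 1]             # e.g. discriminator for 5 batches, generator for 1 batch
cumsum = np.cumsum(frequencies)  # array([5, 6]); one full cycle is 6 batches
for total_batch_idx in range(7):
    current_place = total_batch_idx % cumsum[-1]
    opt_idx = int(np.argmax(cumsum > current_place))  # first entry greater than current_place
    print(total_batch_idx, opt_idx)
# prints 0 for batches 0-4 (first optimizer), 1 for batch 5 (second optimizer),
# and 0 again for batch 6 as the cycle restarts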
3 changes: 2 additions & 1 deletion tests/base/utils.py
@@ -82,7 +82,8 @@ def run_model_test(trainer_options, model, on_gpu=True):
if trainer.use_ddp or trainer.use_ddp2:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers())
trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
trainer.init_optimizers(pretrained_model.configure_optimizers())

# test HPC loading / saving
trainer.hpc_save(save_dir, logger)
54 changes: 38 additions & 16 deletions tests/models/test_gpu.py
@@ -96,30 +96,52 @@ def test_optimizer_return_options():
# single optimizer
opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
optim, lr_sched = trainer.init_optimizers(opt_a)
assert len(optim) == 1 and len(lr_sched) == 0
scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)

# single optimizer
optim, lr_sched, freq = trainer.init_optimizers(opt_a)
assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

# opt tuple
opts = (opt_a, opt_b)
optim, lr_sched = trainer.init_optimizers(opts)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
assert len(lr_sched) == 0
assert len(lr_sched) == 0 and len(freq) == 0

# opt list
opts = [opt_a, opt_b]
optim, lr_sched = trainer.init_optimizers(opts)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
assert len(lr_sched) == 0

# opt tuple of lists
scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
opts = ([opt_a], [scheduler])
optim, lr_sched = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1
assert optim[0] == opts[0][0] and \
lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
frequency=1, reduce_on_plateau=False,
monitor='val_loss')
assert len(lr_sched) == 0 and len(freq) == 0

# opt tuple of 2 lists
opts = ([opt_a], [scheduler_a])
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opts[0][0]
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False, monitor='val_loss')

# opt single dictionary
opts = {"optimizer": opt_a, "lr_scheduler": scheduler_a}
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False, monitor='val_loss')

# opt multiple dictionaries with frequencies
opts = (
{"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
{"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
assert optim[0] == opt_a
assert lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False, monitor='val_loss')
assert freq == [1, 5]


def test_cpu_slurm_save_load(tmpdir):