Optimizer Frequencies logic, and new configure_optimizers #1269

Merged
16 commits merged on Mar 31, 2020
29 changes: 25 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -914,10 +914,19 @@ def configure_optimizers(self) -> Union[

If you don't define this method, Lightning will automatically use Adam(lr=1e-3)

Return: any of these 3 options:
- Single optimizer
- List or Tuple - List of optimizers
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers
Return: any of these 5 options:
- Single optimizer.
- List or Tuple - List of optimizers.
- Two lists - The first list has multiple optimizers, the second a list of LR schedulers.
- Dictionary, with an `optimizer` key and (optionally) a `lr_scheduler` key.
- Tuple of dictionaries as described, with an optional `frequency` key.

.. note:: The `frequency` value is an int corresponding to the number of sequential batches
optimized with the specific optimizer. It should be given either to none of the optimizers or to all of them.
There is a difference between passing multiple optimizers in a list
and passing multiple optimizers in dictionaries with a frequency of 1:
in the former case, all optimizers operate on the given batch in each optimization step;
in the latter, only one optimizer operates on the given batch at every step.
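As an illustration of that distinction, here is a minimal sketch (not part of this diff; `opt_a` and `opt_b` stand for any two torch optimizers already built on the module's parameters, and the function names are only labels for the contrast, the real hook must be called `configure_optimizers`):

# Sketch only: contrasting the two return forms described in the note above.
def configure_optimizers_as_list(self):
    # list form: both optimizers step on every batch
    return [self.opt_a, self.opt_b]

def configure_optimizers_as_dicts(self):
    # dictionary form with frequency 1: the optimizers alternate, one per batch
    return (
        {'optimizer': self.opt_a, 'frequency': 1},
        {'optimizer': self.opt_b, 'frequency': 1},
    )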

Examples:
.. code-block:: python
@@ -949,6 +958,18 @@ def configure_optimizers(self):
dis_sched = CosineAnnealing(discriminator_opt, T_max=10) # called every epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]

# example with optimizer frequencies
# see training procedure in `Improved Training of Wasserstein GANs`, Algorithm 1
# https://arxiv.org/abs/1704.00028
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
n_critic = 5
return (
{'optimizer': dis_opt, 'frequency': n_critic},
{'optimizer': gen_opt, 'frequency': 1}
)

Contributor (review comment): shouldn't this also have an example for scheduler?

{'optimizer': dis_opt, 'frequency': n_critic, 'lr_scheduler': Scheduler()}

Contributor (review comment): @Borda @asafmanor? Also, amazing job :)
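Following up on the review question above, a minimal sketch of a configuration that combines the `lr_scheduler` and `frequency` keys (not part of this diff; the choice of `CosineAnnealingLR` and the model attributes are assumptions for illustration only):

# Hypothetical variant of the WGAN example, with a scheduler added to each dictionary.
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

def configure_optimizers(self):
    gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
    dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
    n_critic = 5
    return (
        {'optimizer': dis_opt, 'lr_scheduler': CosineAnnealingLR(dis_opt, T_max=10), 'frequency': n_critic},
        {'optimizer': gen_opt, 'lr_scheduler': CosineAnnealingLR(gen_opt, T_max=10), 'frequency': 1},
    )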

Note:

Some things to know:
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/distrib_data_parallel.py
@@ -305,7 +305,8 @@ def ddp_train(self, gpu_idx, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
opts = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = opts

# MODEL
# copy model to each gpu
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -459,7 +459,8 @@ def single_gpu_train(self, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
opts = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = opts

if self.use_amp:
# An example
@@ -485,7 +486,8 @@ def tpu_train(self, tpu_core_idx, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
opts = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = opts

# init 16 bit for TPU
if self.precision == 16:
@@ -504,7 +506,8 @@ def dp_train(self, model):

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
opts = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = opts

model.cuda(self.root_gpu)

59 changes: 46 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -3,7 +3,7 @@
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any, Sequence
import distutils

import torch
@@ -358,6 +358,7 @@ def __init__(
self.disable_validation = False
self.lr_schedulers = []
self.optimizers = None
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.total_batches = 0
@@ -714,7 +715,8 @@ def fit(

# CHOOSE OPTIMIZER
# allow for lr schedulers as well
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
opts = self.init_optimizers(model.configure_optimizers())
self.optimizers, self.lr_schedulers, self.optimizer_frequencies = opts

self.run_pretrain_routine(model)

@@ -760,31 +762,61 @@ def __attach_dataloaders(self, model, train_dataloader, val_dataloaders, test_da

def init_optimizers(
self,
optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
) -> Tuple[List, List]:
optim_conf: Union[Optimizer, Sequence[Optimizer], Dict, Sequence[Dict], Tuple[List, List]]
) -> Tuple[List, List, List]:

# single output, single optimizer
if isinstance(optimizers, Optimizer):
return [optimizers], []
if isinstance(optim_conf, Optimizer):
return [optim_conf], [], []

# two lists, optimizer + lr schedulers
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
elif isinstance(optim_conf, (list, tuple)) and len(optim_conf) == 2 and isinstance(optim_conf[0], list):
optimizers, lr_schedulers = optim_conf
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers
return optimizers, lr_schedulers, []

# single dictionary
elif isinstance(optim_conf, dict):
optimizer = optim_conf["optimizer"]
lr_scheduler = optim_conf.get("lr_scheduler", [])
if lr_scheduler:
lr_schedulers = self.configure_schedulers([lr_scheduler])
return [optimizer], lr_schedulers, []

# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and isinstance(optim_conf[0], dict):
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
for optimizer_dict in optim_conf:
optimizers.append(optimizer_dict["optimizer"])
lr_schedulers.append(optimizer_dict.get("lr_scheduler", None))
optimizer_frequencies.append(optimizer_dict.get("frequency", None))

# clean scheduler list
lr_schedulers = [x for x in lr_schedulers if x is not None]
if lr_schedulers:
lr_schedulers = self.configure_schedulers(lr_schedulers)
# assert that if frequencies are present, they are given for all optimizers
optimizer_frequencies = [x for x in optimizer_frequencies if x is not None]
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
return optimizers, lr_schedulers, optimizer_frequencies

# single list or tuple, multiple optimizer
elif isinstance(optimizers, (list, tuple)):
return optimizers, []
elif isinstance(optim_conf, (list, tuple)):
return list(optim_conf), [], []

# unknown configuration
else:
raise ValueError('Unknown configuration for model optimizers. Output'
'from model.configure_optimizers() should either be:'
'* single output, single torch.optim.Optimizer'
'* single output, list of torch.optim.Optimizer'
'* two outputs, first being a list of torch.optim.Optimizer',
'second being a list of torch.optim.lr_scheduler')
'* single output, a dictionary with `optimizer` key (torch.optim.Optimizer) '
'and an optional `lr_scheduler` key (torch.optim.lr_scheduler) '
'* two outputs, first being a list of torch.optim.Optimizer, '
'second being a list of torch.optim.lr_scheduler '
'* multiple outputs, dictionaries as described '
'with an optional `frequency` key (int)')
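The "multiple dictionaries" branch above can be exercised in isolation. The following standalone sketch (the helper name `parse_optim_dicts` is hypothetical, not part of the PR) mirrors that parsing and the frequency-validation rule:

# Hypothetical standalone mirror of the "multiple dictionaries" branch, for illustration only.
def parse_optim_dicts(optim_conf):
    optimizers = [d["optimizer"] for d in optim_conf]
    lr_schedulers = [d.get("lr_scheduler") for d in optim_conf]
    frequencies = [d.get("frequency") for d in optim_conf]
    # drop the entries that were not provided
    lr_schedulers = [s for s in lr_schedulers if s is not None]
    frequencies = [f for f in frequencies if f is not None]
    # frequencies must be given either for every optimizer or for none of them
    if frequencies and len(frequencies) != len(optimizers):
        raise ValueError("A frequency must be given to each optimizer.")
    return optimizers, lr_schedulers, frequencies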

def configure_schedulers(self, schedulers: list):
# Convert each scheduler into dict structure with relevant information
@@ -971,6 +1003,7 @@ class _PatchDataLoader(object):
dataloader: Dataloader object to return when called.

"""

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

17 changes: 16 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -196,6 +196,7 @@ class TrainerTrainLoopMixin(ABC):
total_batches: int
truncated_bptt_steps: ...
optimizers: ...
optimizer_frequencies: ...
accumulate_grad_batches: int
use_amp: bool
track_grad_norm: ...
@@ -515,8 +516,22 @@ def run_training_batch(self, batch, batch_idx):
for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx

def get_optimizers_iterable():
if not self.optimizer_frequencies:
return enumerate(self.optimizers)

optimizer_freq_cumsum = np.cumsum(self.optimizer_frequencies)
optimizers_loop_length = optimizer_freq_cumsum[-1]
current_place_in_loop = self.total_batch_idx % optimizers_loop_length

# find optimizer index by looking for the first {item > current_place} in the cumsum list
for opt_idx, v in enumerate(optimizer_freq_cumsum):
if v > current_place_in_loop:
# return an iterable list of one tuple
return [(opt_idx, self.optimizers[opt_idx])]

# call training_step once per optimizer
for opt_idx, optimizer in enumerate(self.optimizers):
for opt_idx, optimizer in get_optimizers_iterable():
# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if len(self.optimizers) > 1:
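The selection logic in get_optimizers_iterable can be illustrated on its own. This sketch (not part of the diff) assumes the WGAN-style frequencies [5, 1] from the docstring example and prints which optimizer each batch index would get:

import numpy as np

# Standalone illustration of the cumulative-sum selection above,
# using frequencies [5, 1] (5 critic updates per generator update).
frequencies = [5, 1]
freq_cumsum = np.cumsum(frequencies)   # array([5, 6])
loop_length = freq_cumsum[-1]          # 6

for total_batch_idx in range(12):
    place_in_loop = total_batch_idx % loop_length
    # the first optimizer whose cumulative frequency exceeds the current position
    opt_idx = int(np.argmax(freq_cumsum > place_in_loop))
    print(total_batch_idx, '-> optimizer', opt_idx)
# batches 0-4 use optimizer 0, batch 5 uses optimizer 1, and the pattern repeats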
3 changes: 2 additions & 1 deletion tests/base/utils.py
@@ -82,7 +82,8 @@ def run_model_test(trainer_options, model, on_gpu=True):
if trainer.use_ddp or trainer.use_ddp2:
# on hpc this would work fine... but need to hack it for the purpose of the test
trainer.model = pretrained_model
trainer.optimizers, trainer.lr_schedulers = trainer.init_optimizers(pretrained_model.configure_optimizers())
opts = trainer.init_optimizers(pretrained_model.configure_optimizers())
trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = opts

# test HPC loading / saving
trainer.hpc_save(save_dir, logger)
49 changes: 37 additions & 12 deletions tests/models/test_gpu.py
@@ -96,30 +96,55 @@ def test_optimizer_return_options():
# single optimizer
opt_a = torch.optim.Adam(model.parameters(), lr=0.002)
opt_b = torch.optim.SGD(model.parameters(), lr=0.002)
optim, lr_sched = trainer.init_optimizers(opt_a)
assert len(optim) == 1 and len(lr_sched) == 0
scheduler_a = torch.optim.lr_scheduler.StepLR(opt_a, 10)
scheduler_b = torch.optim.lr_scheduler.StepLR(opt_b, 10)

# single optimizer
optim, lr_sched, freq = trainer.init_optimizers(opt_a)
assert len(optim) == 1 and len(lr_sched) == 0 and len(freq) == 0

# opt tuple
opts = (opt_a, opt_b)
optim, lr_sched = trainer.init_optimizers(opts)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
assert len(lr_sched) == 0
assert len(lr_sched) == 0 and len(freq) == 0

# opt list
opts = [opt_a, opt_b]
optim, lr_sched = trainer.init_optimizers(opts)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and optim[0] == opts[0] and optim[1] == opts[1]
assert len(lr_sched) == 0
assert len(lr_sched) == 0 and len(freq) == 0

# opt tuple of lists
scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
opts = ([opt_a], [scheduler])
optim, lr_sched = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1
# opt tuple of 2 lists
opts = ([opt_a], [scheduler_a])
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opts[0][0] and \
lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False,
monitor='val_loss')

# opt single dictionary
opts = {"optimizer": opt_a, "lr_scheduler": scheduler_a}
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1 and len(freq) == 0
assert optim[0] == opt_a and \
lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False,
monitor='val_loss')

# opt multiple dictionaries with frequencies
opts = (
{"optimizer": opt_a, "lr_scheduler": scheduler_a, "frequency": 1},
{"optimizer": opt_b, "lr_scheduler": scheduler_b, "frequency": 5},
)
optim, lr_sched, freq = trainer.init_optimizers(opts)
assert len(optim) == 2 and len(lr_sched) == 2 and len(freq) == 2
assert optim[0] == opt_a and \
lr_sched[0] == dict(scheduler=scheduler_a, interval='epoch',
frequency=1, reduce_on_plateau=False,
monitor='val_loss')
assert freq == [1, 5]
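A natural companion check, sketched here as a hypothetical addition (not part of this diff), would assert the validation error raised when frequencies are only partially specified:

# Hypothetical extra assertion, not in the PR: partial frequencies should raise.
import pytest

opts = (
    {"optimizer": opt_a, "frequency": 1},
    {"optimizer": opt_b},  # no frequency given for this optimizer
)
with pytest.raises(ValueError, match="A frequency must be given to each optimizer."):
    trainer.init_optimizers(opts)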


def test_cpu_slurm_save_load(tmpdir):