Deprecate TrainerOptimizersMixin and move functionality to core/optimizer.py #11155

Merged: 38 commits, Dec 23, 2021
Commits (38)
68d60fa
refactor init_optimizers
daniellepintz Dec 18, 2021
e7e3b59
fix tests
daniellepintz Dec 18, 2021
628e4a0
fix tests
daniellepintz Dec 18, 2021
68caa57
fix another test
daniellepintz Dec 19, 2021
0396959
attempt to fix deepspeed test
daniellepintz Dec 19, 2021
b645134
addr comments
daniellepintz Dec 19, 2021
47bdae1
deprecate optimizer mixin
daniellepintz Dec 19, 2021
aff9de3
addr comments
daniellepintz Dec 20, 2021
74a3033
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 20, 2021
3d1faaa
move _convert_to_lightning_optimizer to core/optimizer.py
daniellepintz Dec 20, 2021
8c1267f
fix typing and docstring
daniellepintz Dec 21, 2021
4a1ee1a
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 21, 2021
108c4a2
remove strategy refactor
daniellepintz Dec 21, 2021
a768823
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
b27e2e1
small fix
daniellepintz Dec 21, 2021
e8d6cf6
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 21, 2021
7d72f56
fix mypy
daniellepintz Dec 21, 2021
4fda88f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
be56fcd
addr comment
daniellepintz Dec 21, 2021
0da03c8
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 21, 2021
779c0a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 21, 2021
362a1b7
fix test
daniellepintz Dec 21, 2021
81c6788
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 21, 2021
493a582
addr comments
daniellepintz Dec 22, 2021
b20d6f4
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
1c30527
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
8b80304
fix merge conflicts from strategy
daniellepintz Dec 22, 2021
5397d2f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2021
5da0649
addr comments
daniellepintz Dec 22, 2021
f320402
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 22, 2021
119bc58
fix whitespace
daniellepintz Dec 22, 2021
ecb3c60
fix weakref test
daniellepintz Dec 22, 2021
fd1b14d
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
daniellepintz Dec 22, 2021
8ba6f37
add comment
daniellepintz Dec 22, 2021
d2c6b6e
addr comments
daniellepintz Dec 22, 2021
bb6ba78
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2021
41bbfb8
fix mypy
daniellepintz Dec 22, 2021
126533d
Merge branch 'optimizers_mixin' of github.com:daniellepintz/pytorch-l…
daniellepintz Dec 22, 2021
1 change: 0 additions & 1 deletion pyproject.toml
@@ -93,7 +93,6 @@ module = [
"pytorch_lightning.trainer.connectors.checkpoint_connector",
"pytorch_lightning.trainer.connectors.data_connector",
"pytorch_lightning.trainer.data_loading",
"pytorch_lightning.trainer.optimizers",
"pytorch_lightning.trainer.supporters",
"pytorch_lightning.trainer.trainer",
"pytorch_lightning.tuner.batch_size_scaling",
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -24,7 +24,6 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -182,7 +181,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
)
default_scheduler_cfg = _get_default_scheduler_config()
default_scheduler_cfg = pl.LightningModule._get_default_scheduler_config()
assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
default_scheduler_cfg["scheduler"] = self._swa_scheduler

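For orientation, a minimal sketch (assuming this PR's branch of pytorch_lightning is installed; `_get_default_scheduler_config` is private and may change) of the default scheduler config that the SWA callback now reads from `pl.LightningModule`:

import pytorch_lightning as pl

# Staticmethod moved onto LightningModule by this PR (previously
# pytorch_lightning.trainer.optimizers._get_default_scheduler_config).
default_cfg = pl.LightningModule._get_default_scheduler_config()
assert default_cfg["interval"] == "epoch" and default_cfg["frequency"] == 1
# Per the diff below, scheduler/name/monitor/opt_idx default to None,
# reduce_on_plateau to False, and strict to True.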
184 changes: 182 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -24,7 +24,7 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union

import torch
from torch import ScriptModule, Tensor
from torch import optim, ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
@@ -34,7 +34,7 @@
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _MockOptimizer, LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import (
@@ -1969,3 +1969,183 @@ def _register_sharded_tensor_state_dict_hooks_if_available(self) -> None:

self._register_state_dict_hook(state_dict_hook)
self._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)

def init_optimizers_and_lr_schedulers(self) -> Tuple[List, List, List]:
optim_conf = self.configure_optimizers()
if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
)
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, optimizer_frequencies, monitor = self._configure_optimizers(optim_conf)
lr_schedulers = self._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)
return optimizers, lr_schedulers, optimizer_frequencies

@staticmethod
def _configure_optimizers(
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
) -> Tuple[List, List, List, Optional[str]]:
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
# single dictionary
elif isinstance(optim_conf, dict):
_validate_optim_conf(optim_conf)
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get("monitor", None)
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
for opt_dict in optim_conf:
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = (
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
if isinstance(scheduler, dict)
else {"scheduler": scheduler, "opt_idx": opt_idx}
)

lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
for opt_idx, opt_dict in enumerate(optim_conf)
if "lr_scheduler" in opt_dict
]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
optimizers = list(optim_conf)
# unknown configuration
else:
raise MisconfigurationException(
"Unknown configuration for model optimizers."
" Output from `model.configure_optimizers()` should either be:\n"
" * `torch.optim.Optimizer`\n"
" * [`torch.optim.Optimizer`]\n"
" * ([`torch.optim.Optimizer`], [`torch.optim.lr_scheduler`])\n"
' * {"optimizer": `torch.optim.Optimizer`, (optional) "lr_scheduler": `torch.optim.lr_scheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor

@staticmethod
def _configure_schedulers(
schedulers: list, monitor: Optional[str], is_manual_optimization: bool
) -> List[Dict[str, Any]]:
"""Convert each scheduler into dict structure with relevant information."""
lr_schedulers = []
default_config = LightningModule._get_default_scheduler_config()
for scheduler in schedulers:
if is_manual_optimization:
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

if keys_to_warn:
rank_zero_warn(
f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
" You need to call `lr_scheduler.step()` manually in manual optimization.",
category=RuntimeWarning,
)

scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
)
if "scheduler" not in scheduler:
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
raise MisconfigurationException(
'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
scheduler["reduce_on_plateau"] = isinstance(
scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
raise MisconfigurationException(
"The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
rank_zero_warn(
"A `OneCycleLR` scheduler is using 'interval': 'epoch'."
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers

@staticmethod
def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
}


def _validate_scheduler_optimizer(optimizers, lr_schedulers):
if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
extra_keys = optim_conf.keys() - valid_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
)
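A short usage sketch of the new `init_optimizers_and_lr_schedulers` entry point and the single-dictionary `configure_optimizers` format parsed above. The module, layer sizes, and metric name are illustrative, not from the PR:

import torch
from torch import nn
import pytorch_lightning as pl


class SketchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # Single-dictionary format; "monitor" is required because ReduceLROnPlateau is used.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }


model = SketchModule()
# New entry point added by this PR: returns (optimizers, scheduler dicts, frequencies).
optimizers, lr_schedulers, frequencies = model.init_optimizers_and_lr_schedulers()
assert lr_schedulers[0]["reduce_on_plateau"] and lr_schedulers[0]["monitor"] == "train_loss"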
30 changes: 29 additions & 1 deletion pytorch_lightning/core/optimizer.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional
from typing import Any, Callable, Dict, Generator, Optional
from weakref import proxy

import torch
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -162,3 +163,30 @@ def closure_dis():
assert trainer is not None
with trainer.profiler.profile(profiler_action):
trainer.training_type_plugin.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
`configure_optimizers`."""

def __init__(self) -> None:
super().__init__([torch.zeros(1)], {})

def add_param_group(self, param_group: Dict[Any, Any]) -> None:
pass # Do Nothing

def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
pass # Do Nothing

def state_dict(self) -> Dict[Any, Any]:
return {} # Return Empty

def step(self, closure: Callable = None) -> None:
if closure is not None:
closure()

def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
pass # Do Nothing

def __repr__(self) -> str:
return "No Optimizer"
15 changes: 3 additions & 12 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -32,7 +32,6 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -444,16 +443,14 @@ def init_deepspeed(self):
self._initialize_deepspeed_inference(model)

def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
optimizers, schedulers, optimizer_frequencies = self.lightning_module.init_optimizers_and_lr_schedulers()
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."
)
return (
optimizers[0],
schedulers[0] if schedulers else _get_default_scheduler_config(),
schedulers[0] if schedulers else pl.LightningModule._get_default_scheduler_config(),
optimizer_frequencies[0] if optimizer_frequencies else None,
)

@@ -463,7 +460,7 @@ def zero_stage_3(self) -> bool:

def _initialize_deepspeed_train(self, model):
if "optimizer" in self.config:
optimizer, lr_scheduler = None, _get_default_scheduler_config()
optimizer, lr_scheduler = None, pl.LightningModule._get_default_scheduler_config()
else:
rank_zero_info(
"You have not specified an optimizer or scheduler within the DeepSpeed config."
@@ -562,12 +559,6 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule") -> Tuple[List, List, List]:
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled
# via `_initialize_deepspeed_train`
return [], [], [] # empty optimizers, schedulers and frequencies

@property
def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
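A hedged sketch of the constraint enforced in `_init_optimizers` above: under the DeepSpeed plugin, `configure_optimizers` may return at most one optimizer and one optional scheduler, so a multi-optimizer module like this illustrative one would be rejected with a `MisconfigurationException` during setup:

import torch
from torch import nn
import pytorch_lightning as pl


class TwoOptimizerModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(2, 2)
        self.disc = nn.Linear(2, 2)

    def configure_optimizers(self):
        # Two optimizers are accepted by Lightning in general, but
        # DeepSpeedPlugin._init_optimizers rejects more than one optimizer or scheduler.
        return (
            torch.optim.Adam(self.gen.parameters(), lr=1e-3),
            torch.optim.Adam(self.disc.parameters(), lr=1e-3),
        )


model = TwoOptimizerModule()
opts, scheds, freqs = model.init_optimizers_and_lr_schedulers()
assert len(opts) == 2  # this is what DeepSpeedPlugin._init_optimizers would reject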
pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -105,9 +105,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
"""
if trainer.state.fn not in (TrainerFn.FITTING, TrainerFn.TUNING):
return
optimizers, lr_schedulers, optimizer_frequencies = self.init_optimizers(
trainer=trainer, model=self.lightning_module
)
optimizers, lr_schedulers, optimizer_frequencies = self.lightning_module.init_optimizers_and_lr_schedulers()
self.optimizers = optimizers
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies
@@ -377,9 +375,6 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
"""
return dataloader

def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)

@property
def restore_checkpoint_after_pre_dispatch(self) -> bool:
"""Override to delay restoring from checkpoint till after pre-dispatch. This is useful when the plugin
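A minimal sketch (assuming this PR's branch; `_configure_schedulers` is private) of the scheduler normalization behind `setup_optimizers` above: a bare scheduler returned from `configure_optimizers` is wrapped in the default config dict before being stored on the plugin as part of `lr_schedulers`:

import torch
import pytorch_lightning as pl

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

# Staticmethod added to LightningModule by this PR.
normalized = pl.LightningModule._configure_schedulers(
    [scheduler], monitor=None, is_manual_optimization=False
)
assert normalized[0]["scheduler"] is scheduler
assert normalized[0]["interval"] == "epoch" and normalized[0]["frequency"] == 1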