
Commit

Deprecate TrainerOptimizersMixin and move functionality to `core/optimizer.py` (#11155)

daniellepintz authored Dec 23, 2021
1 parent 81301db commit a6a28e0
Showing 13 changed files with 321 additions and 271 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -213,6 +213,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `Trainer.should_rank_save_checkpoint` Trainer property ([#11068](https://github.com/PyTorchLightning/pytorch-lightning/pull/11068))


- Deprecated `TrainerOptimizersMixin` and moved functionality to `core/optimizer.py`([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


- Deprecated `TrainerCallbackHookMixin` ([#11148](https://github.com/PyTorchLightning/pytorch-lightning/pull/11148))

### Removed
@@ -351,6 +354,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))


- Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))


## [1.5.7] - 2021-12-21

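As a quick orientation for the file-level changes that follow: the optimizer-setup helpers referenced throughout this commit are now importable from `pytorch_lightning.core.optimizer`. A minimal, import-only sketch (these are private, underscore-prefixed APIs as of this commit):

```python
# Names defined in pytorch_lightning/core/optimizer.py by this commit;
# they are internal helpers, so downstream code should treat them as private.
from pytorch_lightning.core.optimizer import (
    LightningOptimizer,
    _convert_to_lightning_optimizers,
    _get_default_scheduler_config,
    _init_optimizers_and_lr_schedulers,
)
```
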
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -24,7 +24,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.core.optimizer import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

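The SWA callback only needed its import path updated: `_get_default_scheduler_config` now lives in `core/optimizer.py` (see its definition further down in this diff). A small sketch of what it returns, assuming the post-#11155 import path:

```python
from pytorch_lightning.core.optimizer import _get_default_scheduler_config

config = _get_default_scheduler_config()
# Defaults per the definition in this commit: step the scheduler once per epoch,
# with no monitor and no ReduceLROnPlateau handling.
print(config["interval"], config["frequency"], config["reduce_on_plateau"])  # epoch 1 False
```
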
234 changes: 231 additions & 3 deletions pytorch_lightning/core/optimizer.py
@@ -11,14 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import weakref
from contextlib import contextmanager
from typing import Any, Callable, Generator, Optional
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy

import torch
from torch import optim
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -54,7 +57,8 @@ def optimizer(self) -> Optimizer:
return self._optimizer

def _on_trainer_init(self, trainer: "pl.Trainer") -> None:
self._trainer = proxy(trainer)
# check if trainer is already of type weakproxy since we can't call proxy on a weakproxy
self._trainer = trainer if isinstance(trainer, weakref.ProxyType) else proxy(trainer)
for opt_idx, opt in enumerate(trainer.optimizers):
if opt == self._optimizer:
self._optimizer_idx = opt_idx
Expand Down Expand Up @@ -162,3 +166,227 @@ def closure_dis():
assert trainer is not None
with trainer.profiler.profile(profiler_action):
trainer.strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(model: "pl.LightningModule") -> Tuple[List, List, List]:
"""Calls `LightningModule.configure_optimizers` and parses and validates the output."""
model.trainer._lightning_optimizers = None
optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

if optim_conf is None:
rank_zero_warn(
"`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
)
optim_conf = _MockOptimizer()

optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
lr_schedulers = _configure_schedulers(lr_schedulers, monitor, not model.automatic_optimization)
_validate_scheduler_optimizer(optimizers, lr_schedulers)
return optimizers, lr_schedulers, optimizer_frequencies


def _configure_optimizers(
optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
) -> Tuple[List, List, List, Optional[str]]:
optimizers, lr_schedulers, optimizer_frequencies = [], [], []
monitor = None

# single output, single optimizer
if isinstance(optim_conf, Optimizer):
optimizers = [optim_conf]
# two lists, optimizer + lr schedulers
elif (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
opt, sch = optim_conf
optimizers = opt
lr_schedulers = sch if isinstance(sch, list) else [sch]
# single dictionary
elif isinstance(optim_conf, dict):
_validate_optim_conf(optim_conf)
optimizers = [optim_conf["optimizer"]]
monitor = optim_conf.get("monitor", None)
lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
# multiple dictionaries
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
for opt_dict in optim_conf:
_validate_optim_conf(opt_dict)
optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
scheduler_dict = (
lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
if isinstance(scheduler, dict)
else {"scheduler": scheduler, "opt_idx": opt_idx}
)

lr_schedulers = [
scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
for opt_idx, opt_dict in enumerate(optim_conf)
if "lr_scheduler" in opt_dict
]
optimizer_frequencies = [
opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
]
# assert that if frequencies are present, they are given for all optimizers
if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
raise ValueError("A frequency must be given to each optimizer.")
# single list or tuple, multiple optimizer
elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
optimizers = list(optim_conf)
# unknown configuration
else:
raise MisconfigurationException(
"Unknown configuration for model optimizers."
" Output from `model.configure_optimizers()` should be one of:\n"
" * `Optimizer`\n"
" * [`Optimizer`]\n"
" * ([`Optimizer`], [`_LRScheduler`])\n"
' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
' * A list of the previously described dict format, with an optional "frequency" key (int)'
)
return optimizers, lr_schedulers, optimizer_frequencies, monitor


def _configure_schedulers(
schedulers: list, monitor: Optional[str], is_manual_optimization: bool
) -> List[Dict[str, Any]]:
"""Convert each scheduler into dict structure with relevant information."""
lr_schedulers = []
default_config = _get_default_scheduler_config()
# TODO: move is_manual_optimization check out of for loop
for scheduler in schedulers:
if is_manual_optimization:
if isinstance(scheduler, dict):
invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

if keys_to_warn:
rank_zero_warn(
f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
" You need to call `lr_scheduler.step()` manually in manual optimization.",
category=RuntimeWarning,
)

scheduler = {key: scheduler[key] for key in scheduler if key not in invalid_keys}
lr_schedulers.append({**default_config, **scheduler})
else:
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
if isinstance(scheduler, dict):
# check provided keys
extra_keys = [k for k in scheduler.keys() if k not in default_config.keys()]
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the lr scheduler dict: {extra_keys}", category=RuntimeWarning
)
if "scheduler" not in scheduler:
raise MisconfigurationException(
'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
)
if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
raise MisconfigurationException(
'The "interval" key in lr scheduler dict must be "step" or "epoch"'
f' but is "{scheduler["interval"]}"'
)
scheduler["reduce_on_plateau"] = isinstance(
scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau
)
if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
raise MisconfigurationException(
"The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
' For example: {"optimizer": optimizer, "lr_scheduler":'
' {"scheduler": scheduler, "monitor": "your_loss"}}'
)
is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
rank_zero_warn(
"A `OneCycleLR` scheduler is using 'interval': 'epoch'."
" Are you sure you didn't mean 'interval': 'step'?",
category=RuntimeWarning,
)
lr_schedulers.append({**default_config, **scheduler})
elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
if monitor is None:
raise MisconfigurationException(
"`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
" scheduler is used. For example:"
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, "scheduler": scheduler, "reduce_on_plateau": True, "monitor": monitor}
)
elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, "scheduler": scheduler})
else:
raise ValueError(f'The provided lr scheduler "{scheduler}" is invalid')
return lr_schedulers


def _get_default_scheduler_config() -> Dict[str, Any]:
return {
"scheduler": None,
"name": None, # no custom name
"interval": "epoch", # after epoch is over
"frequency": 1, # every epoch/batch
"reduce_on_plateau": False, # most often not ReduceLROnPlateau scheduler
"monitor": None, # value to monitor for ReduceLROnPlateau
"strict": True, # enforce that the monitor exists for ReduceLROnPlateau
"opt_idx": None, # necessary to store opt_idx when optimizer frequencies are specified
}


def _validate_scheduler_optimizer(optimizers: List[Any], lr_schedulers: List[Any]) -> None:
if any(sch["scheduler"].optimizer not in optimizers for sch in lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
extra_keys = optim_conf.keys() - valid_keys
if extra_keys:
rank_zero_warn(
f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
)


def _convert_to_lightning_optimizers(trainer: "pl.Trainer") -> None:
def _convert_to_lightning_optimizer(optimizer: Optimizer) -> LightningOptimizer:
if not isinstance(optimizer, LightningOptimizer):
optimizer = LightningOptimizer(optimizer) # type: ignore [assignment]
optimizer._on_trainer_init(trainer)
return optimizer # type: ignore [return-value]

trainer._lightning_optimizers = { # type: ignore [assignment]
opt_idx: _convert_to_lightning_optimizer(opt) for opt_idx, opt in enumerate(trainer.optimizers)
}


class _MockOptimizer(Optimizer):
"""The `_MockOptimizer` will be used inplace of an optimizer in the event that `None` is returned from
`configure_optimizers`."""

def __init__(self) -> None:
super().__init__([torch.zeros(1)], {})

def add_param_group(self, param_group: Dict[Any, Any]) -> None:
pass # Do Nothing

def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
pass # Do Nothing

def state_dict(self) -> Dict[str, Any]:
return {} # Return Empty

def step(self, closure: Callable = None) -> None:
if closure is not None:
closure()

def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
pass # Do Nothing

def __repr__(self) -> str:
return "No Optimizer"
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/ddp.py
@@ -31,7 +31,7 @@
from torch.nn.parallel.distributed import DistributedDataParallel

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
@@ -347,7 +347,7 @@ def _reinit_optimizers_with_post_localSGD(self, warmup_steps: int):
del optimizer
trainer = self.lightning_module.trainer
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(trainer)

def configure_ddp(self) -> None:
self.pre_configure_ddp()
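`_convert_to_lightning_optimizers(trainer)` replaces the old `trainer.convert_to_lightning_optimizers()` call: it wraps each entry of `trainer.optimizers` in a `LightningOptimizer` and registers the trainer on it. The wrapping step itself can be sketched in isolation (hypothetical optimizer, no trainer attached):

```python
from torch import nn, optim
from pytorch_lightning.core.optimizer import LightningOptimizer

optimizer = optim.Adam(nn.Linear(2, 2).parameters(), lr=1e-3)
lightning_optimizer = LightningOptimizer(optimizer)

# The wrapped torch optimizer stays accessible; inside the trainer,
# _convert_to_lightning_optimizers additionally calls _on_trainer_init(trainer)
# so the wrapper can locate its optimizer index.
assert lightning_optimizer.optimizer is optimizer
```
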
6 changes: 2 additions & 4 deletions pytorch_lightning/strategies/deepspeed.py
@@ -27,11 +27,11 @@
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _get_default_scheduler_config, _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -446,9 +446,7 @@ def init_deepspeed(self):
self._initialize_deepspeed_inference(model)

def _init_optimizers(self) -> Tuple[Optimizer, Optional[Union[LRSchedulerTypeTuple]], Optional[int]]:
optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
optimizers, schedulers, optimizer_frequencies = _init_optimizers_and_lr_schedulers(self.lightning_module)
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."
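`_init_optimizers_and_lr_schedulers` returns an (optimizers, schedulers, frequencies) triple, which is why the DeepSpeed strategy can check `len(optimizers) > 1` directly. The frequencies come from the multi-dict return format; a small parsing-only sketch (hypothetical optimizers, run through the private `_configure_optimizers` helper shown earlier):

```python
from torch import nn, optim
from pytorch_lightning.core.optimizer import _configure_optimizers

params = list(nn.Linear(2, 2).parameters())
optim_conf = (
    {"optimizer": optim.SGD(params, lr=0.1), "frequency": 5},
    {"optimizer": optim.Adam(params, lr=1e-3), "frequency": 10},
)

optimizers, lr_schedulers, frequencies, monitor = _configure_optimizers(optim_conf)
print(frequencies)  # [5, 10]; giving "frequency" to only some optimizers raises ValueError
```
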
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/sharded.py
@@ -19,7 +19,7 @@
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.optimizer import _convert_to_lightning_optimizers, LightningOptimizer
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, rank_zero_only
@@ -50,7 +50,7 @@ def configure_ddp(self) -> None:
optimizers=trainer.optimizers,
)
trainer.optimizers = optimizers
trainer.convert_to_lightning_optimizers()
_convert_to_lightning_optimizers(trainer)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Wraps the model and optimizers with fairscale components.
3 changes: 2 additions & 1 deletion pytorch_lightning/strategies/training_type_plugin.py
@@ -22,6 +22,7 @@
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins import TorchCheckpointIO
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
@@ -377,7 +378,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
return dataloader

def init_optimizers(self, trainer: "pl.Trainer", model: "pl.LightningModule"):
return trainer.init_optimizers(model)
return _init_optimizers_and_lr_schedulers(model)

@property
def restore_checkpoint_after_setup(self) -> bool:
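The plugin-level `init_optimizers` hook now just forwards to `_init_optimizers_and_lr_schedulers(model)`, which, per the code earlier in this diff, warns and falls back to `_MockOptimizer` when `configure_optimizers` returns `None`. A minimal sketch of that fallback object's behaviour:

```python
from pytorch_lightning.core.optimizer import _MockOptimizer

mock = _MockOptimizer()
mock.step(closure=lambda: print("closure is still executed"))  # prints the message
print(mock.state_dict())  # {} (nothing to checkpoint)
print(mock)               # "No Optimizer"
```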