[feat] Add Trainer(stochastic_weight_avg=True/False) (#6038)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
4 people authored Feb 17, 2021
1 parent 8d7ac8f commit c9622ba
Showing 7 changed files with 105 additions and 27 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -111,6 +111,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `PL_TORCH_DISTRIBUTED_BACKEND` env variable to select backend ([#5981](https://github.com/PyTorchLightning/pytorch-lightning/pull/5981))


- Added `Trainer` flag to activate Stochastic Weight Averaging (SWA) `Trainer(stochastic_weight_avg=True)` ([#6038](https://github.com/PyTorchLightning/pytorch-lightning/pull/6038))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
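For context, a minimal usage sketch of the flag added in the CHANGELOG entry above (assuming a PyTorch Lightning build that contains this commit and PyTorch >= 1.6; `max_epochs` is illustrative):

from pytorch_lightning import Trainer

# The new flag injects a StochasticWeightAveraging callback with default settings.
trainer = Trainer(max_epochs=10, stochastic_weight_avg=True)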
1 change: 0 additions & 1 deletion pytorch_lightning/callbacks/__init__.py
@@ -37,7 +37,6 @@
'ModelPruning',
'ProgressBar',
'ProgressBarBase',
'ModelPruning',
'QuantizationAwareTraining',
'StochasticWeightAveraging',
]
40 changes: 32 additions & 8 deletions pytorch_lightning/callbacks/swa.py
@@ -23,6 +23,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -96,8 +97,10 @@ def __init__(
raise MisconfigurationException(err_msg)

if (
not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
swa_lrs is not None and (
not isinstance(swa_lrs, (float, list)) or isinstance(swa_lrs, float) and swa_lrs <= 0
or isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
)
):
raise MisconfigurationException("The `swa_lrs` should be a positive float or a list of positive float.")

@@ -131,11 +134,13 @@ def pl_module_contains_batch_norm(pl_module: 'pl.LightningModule'):
def on_before_accelerator_backend_setup(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
# copy the model before moving it to accelerator device.
self._average_model = deepcopy(pl_module)

def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers

if len(optimizers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `optimizer`.")
if len(optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

if len(lr_schedulers) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")
@@ -156,18 +161,37 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule'):
self._average_model = self._average_model.to(self._device or pl_module.device)

optimizers = trainer.optimizers
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]

for param_group in optimizers[0].param_groups:
if self._swa_lrs is None:
initial_lr = param_group["lr"]

elif isinstance(self._swa_lrs, float):
initial_lr = self._swa_lrs

else:
initial_lr = self._swa_lrs[0]

param_group["initial_lr"] = initial_lr

self._swa_lrs = initial_lr

self._swa_scheduler = SWALR(
optimizers[0],
swa_lr=self._swa_lrs,
swa_lr=initial_lr,
anneal_epochs=self._annealing_epochs,
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1
)

rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0]["scheduler"] = self._swa_scheduler
else:
_scheduler_config = _get_default_scheduler_config()
_scheduler_config["scheduler"] = self._swa_scheduler
trainer.lr_schedulers.append(_scheduler_config)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

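As a rough, standalone sketch of the scheduler handoff introduced in this file: at `swa_epoch_start` the callback builds a `torch.optim.swa_utils.SWALR` over the single optimizer and either swaps it in for the user's scheduler or, when none was configured, registers it through the default scheduler config. The model, optimizer, and learning rates below are illustrative, and torch's `AveragedModel` stands in for the deep-copied `_average_model` that the callback maintains by hand.

import torch
from torch.optim.swa_utils import SWALR, AveragedModel

model = torch.nn.Linear(4, 2)                      # illustrative module
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
average_model = AveragedModel(model)               # rough analogue of the callback's `_average_model`

swa_scheduler = SWALR(
    optimizer,
    swa_lr=0.05,               # plays the role of `swa_lrs` / the param group's initial lr
    anneal_epochs=10,          # `_annealing_epochs`
    anneal_strategy="cos",     # `_annealing_strategy`
)

for _ in range(3):             # stand-in for the SWA epochs
    optimizer.step()
    swa_scheduler.step()
    average_model.update_parameters(model)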
21 changes: 20 additions & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
@@ -14,7 +14,12 @@
import os
from typing import List, Union

from pytorch_lightning.callbacks import Callback, ModelCheckpoint, ProgressBar, ProgressBarBase
from pytorch_lightning.callbacks import (
Callback,
ModelCheckpoint,
ProgressBar,
ProgressBarBase,
)
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -34,12 +39,14 @@ def on_trainer_init(
default_root_dir,
weights_save_path,
resume_from_checkpoint,
stochastic_weight_avg,
):
self.trainer.resume_from_checkpoint = resume_from_checkpoint

# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
self.trainer._stochastic_weight_avg = stochastic_weight_avg

# init callbacks
if isinstance(callbacks, Callback):
@@ -50,6 +57,9 @@ def on_trainer_init(
# pass through the required args to figure out defaults
self.configure_checkpoint_callbacks(checkpoint_callback)

# configure swa callback
self._configure_swa_callbacks()

# init progress bar
self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

@@ -76,6 +86,15 @@ def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpo
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
self.trainer.callbacks.append(ModelCheckpoint(dirpath=None, filename=None, mode='min'))

def _configure_swa_callbacks(self):
if not self.trainer._stochastic_weight_avg:
return

from pytorch_lightning.callbacks.swa import StochasticWeightAveraging
existing_swa = [cb for cb in self.trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]
if not existing_swa:
self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks

def configure_progress_bar(self, refresh_rate=None, process_position=0):
if os.getenv('COLAB_GPU') and refresh_rate is None:
# smaller refresh rate on colab causes crashes, choose a higher value
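The injection rule in `_configure_swa_callbacks` boils down to a prepend-if-absent check. A hedged standalone sketch follows; the helper name `configure_swa` is hypothetical, but the logic mirrors the connector code above.

from typing import List

from pytorch_lightning.callbacks import Callback, StochasticWeightAveraging


def configure_swa(callbacks: List[Callback], stochastic_weight_avg: bool) -> List[Callback]:
    """Prepend a default SWA callback only when the flag is set and the user did not pass one."""
    if not stochastic_weight_avg:
        return callbacks
    if any(isinstance(cb, StochasticWeightAveraging) for cb in callbacks):
        return callbacks
    return [StochasticWeightAveraging()] + callbacks

Note that the connector prepends rather than appends (`[StochasticWeightAveraging()] + self.trainer.callbacks`), so the injected callback runs ahead of user callbacks.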
24 changes: 14 additions & 10 deletions pytorch_lightning/trainer/optimizers.py
@@ -13,7 +13,7 @@
# limitations under the License.

from abc import ABC
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict, Any

import torch
from torch import optim
@@ -98,15 +98,7 @@ def _convert_to_lightning_optimizer(trainer, optimizer):
def configure_schedulers(self, schedulers: list, monitor: Optional[str] = None):
# Convert each scheduler into dict structure with relevant information
lr_schedulers = []
default_config = {
'scheduler': None,
'name': None, # no custom name
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': monitor, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}
default_config = _get_default_scheduler_config()
for scheduler in schedulers:
if isinstance(scheduler, dict):
# check provided keys
@@ -185,3 +177,15 @@ def _validate_scheduler_optimizer(optimizers, lr_schedulers):
raise MisconfigurationException(
"Some schedulers are attatched with an optimizer that wasn't returned from `configure_optimizers`."
)


def _get_default_scheduler_config() -> Dict[str, Any]:
return {
'scheduler': None,
'name': None, # no custom name
'interval': 'epoch', # after epoch is over
'frequency': 1, # every epoch/batch
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
'monitor': None, # value to monitor for ReduceLROnPlateau
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
}
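The point of this refactor is that the scheduler-config skeleton now lives in a reusable helper, so the SWA callback can register a scheduler when the user configured none. A minimal sketch of what callers get back, with the key values taken from the function above:

from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config

config = _get_default_scheduler_config()
assert config["scheduler"] is None          # callers fill this in, e.g. with an SWALR instance
assert config["interval"] == "epoch" and config["frequency"] == 1
assert config["monitor"] is None and config["strict"] is True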
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -139,6 +139,7 @@ def __init__(
move_metrics_to_cpu: bool = False,
enable_pl_optimizer: bool = None, # todo: remove in v1.3
multiple_trainloader_mode: str = 'max_size_cycle',
stochastic_weight_avg: bool = False
):
r"""
Customize every aspect of training via flags
@@ -297,6 +298,10 @@ def __init__(
In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
reload when reaching the minimum length of datasets.
stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
<https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>_`
"""
super().__init__()
self._running_stage = None
@@ -333,13 +338,8 @@ def __init__(
# init callbacks
# Declare attributes to be set in callback_connector on_trainer_init
self.callback_connector.on_trainer_init(
callbacks,
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint,
callbacks, checkpoint_callback, progress_bar_refresh_rate, process_position, default_root_dir,
weights_save_path, resume_from_checkpoint, stochastic_weight_avg
)

# hook
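Complementing the flag-only example after the CHANGELOG hunk, a hedged sketch of the explicit-callback form, which is the route for customizing SWA hyperparameters. The values are illustrative, and as the connector change above shows, a user-supplied callback suppresses the automatic injection, so combining it with `stochastic_weight_avg=True` does not create a duplicate.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging

# Explicit callback: choose the SWA learning rate and start epoch yourself (values illustrative).
trainer = Trainer(
    max_epochs=10,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-3, swa_epoch_start=5)],
)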
29 changes: 29 additions & 0 deletions tests/callbacks/test_swa.py
@@ -157,3 +157,32 @@ def test_swa_raises():
StochasticWeightAveraging(swa_epoch_start=-1, swa_lrs=0.1)
with pytest.raises(MisconfigurationException, match="positive float or a list of positive float"):
StochasticWeightAveraging(swa_epoch_start=5, swa_lrs=[0.2, 1])


@pytest.mark.parametrize('stochastic_weight_avg', [False, True])
@pytest.mark.parametrize('use_callbacks', [False, True])
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
def test_trainer_and_stochastic_weight_avg(tmpdir, use_callbacks, stochastic_weight_avg):
"""Test to ensure SWA Callback is injected when `stochastic_weight_avg` is provided to the Trainer"""

class TestModel(BoringModel):

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return optimizer

model = TestModel()
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=StochasticWeightAveraging(swa_lrs=1e-3) if use_callbacks else None,
stochastic_weight_avg=stochastic_weight_avg,
limit_train_batches=4,
limit_val_batches=4,
max_epochs=2,
)
trainer.fit(model)
if use_callbacks or stochastic_weight_avg:
assert len([cb for cb in trainer.callbacks if isinstance(cb, StochasticWeightAveraging)]) == 1
assert trainer.callbacks[0]._swa_lrs == (1e-3 if use_callbacks else 0.1)
else:
assert all(not isinstance(cb, StochasticWeightAveraging) for cb in trainer.callbacks)
