Add Trainer(gradient_clip_algorithm='value'|'norm') #6123

Merged: 59 commits, Apr 6, 2021

Changes from 33 commits

Commits (59)
09ea112
add changelog
dhkim0225 Feb 22, 2021
c0e8064
add clip by value
dhkim0225 Feb 22, 2021
ca8e6fd
fix bug in training tricks.rst
dhkim0225 Feb 22, 2021
87f12c1
fix bug in trainer.rst
dhkim0225 Feb 22, 2021
8e43b8a
Update trainer.rst
dhkim0225 Feb 22, 2021
2bb9924
Update trainer.rst
dhkim0225 Feb 22, 2021
5b83f0d
Update CHANGELOG.md
dhkim0225 Feb 23, 2021
caafdf2
Update pytorch_lightning/plugins/precision/deepspeed_precision.py
dhkim0225 Feb 23, 2021
ca774b6
Update pytorch_lightning/utilities/enums.py
dhkim0225 Feb 23, 2021
5a741e2
yapf formatting
dhkim0225 Feb 23, 2021
0568e3a
update training tricks
dhkim0225 Feb 23, 2021
1a4e79e
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Feb 26, 2021
4a813c1
Merge branch 'master' into feat/clip_grad_by_value
tchaton Feb 26, 2021
2f5cb3e
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 2, 2021
f4275a2
update based on comment
dhkim0225 Mar 2, 2021
e92ec69
update based on comment
dhkim0225 Mar 2, 2021
ac701ce
Update pytorch_lightning/trainer/trainer.py
dhkim0225 Mar 2, 2021
bc20fa4
update based on comment
dhkim0225 Mar 2, 2021
b842210
Merge branch 'feat/clip_grad_by_value' of https://github.com/dhkim022…
dhkim0225 Mar 2, 2021
5ec2ebd
pep8
dhkim0225 Mar 2, 2021
d37fbbc
mypy
dhkim0225 Mar 2, 2021
952c778
mypy
dhkim0225 Mar 2, 2021
b8fdbe1
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 2, 2021
c4cccf0
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 3, 2021
6bd4793
Update docs/source/advanced/training_tricks.rst
dhkim0225 Mar 4, 2021
3aeba85
Update sharded_native_amp.py
dhkim0225 Mar 4, 2021
902a33c
Update test_sharded_parity.py
dhkim0225 Mar 4, 2021
7467616
update test codes
dhkim0225 Mar 4, 2021
5463830
Update test_tpu.py
dhkim0225 Mar 4, 2021
2e933d4
Update pytorch_lightning/trainer/connectors/training_trick_connector.py
dhkim0225 Mar 4, 2021
b1e26e6
Update test_trainer.py
dhkim0225 Mar 4, 2021
cedf5f6
Update enums.py
dhkim0225 Mar 4, 2021
f5bb45d
Update enums.py
dhkim0225 Mar 4, 2021
42fc5f6
Merge branch 'master' into feat/clip_grad_by_value
Borda Mar 4, 2021
e55b90c
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 5, 2021
308ce38
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 23, 2021
903f2e2
add super-class initialization to precision plugins.
dhkim0225 Mar 25, 2021
28c948a
add clip_grad horovod cpu test
dhkim0225 Mar 25, 2021
177a1c9
add clip_grad horovod cpu test
dhkim0225 Mar 25, 2021
fc23845
use subprocess check_call
dhkim0225 Mar 25, 2021
d99a650
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
dhkim0225 Mar 25, 2021
f80aa8d
change order of horovod tests
dhkim0225 Mar 25, 2021
fb895b6
set max_epochs 2 in horovod test
dhkim0225 Mar 25, 2021
caa0bbf
remove clip_grad_val test from horovod-cpu
dhkim0225 Mar 25, 2021
f1f9015
remove "type: ignore"
dhkim0225 Mar 25, 2021
5dfe5ef
divide clip grad val test in horovod
dhkim0225 Mar 25, 2021
50a6c74
update based on comments
dhkim0225 Mar 25, 2021
c337b12
add super-class initialization to precision plugins.
dhkim0225 Mar 25, 2021
f7a4fda
bugfix
dhkim0225 Mar 25, 2021
48c3dd8
bugfix
dhkim0225 Mar 25, 2021
e7e3b47
revert some changes
dhkim0225 Mar 26, 2021
2997536
revert some changes
dhkim0225 Mar 26, 2021
fb34e84
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
dhkim0225 Mar 26, 2021
8e665ec
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 27, 2021
9575774
Update tests/models/test_horovod.py
carmocca Mar 27, 2021
7c16f6a
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 29, 2021
fec189a
merge master
dhkim0225 Apr 6, 2021
1e80304
merge master
dhkim0225 Apr 6, 2021
4d5e05f
Delete signature test
carmocca Apr 6, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)).


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


25 changes: 25 additions & 0 deletions benchmarks/test_sharded_parity.py
@@ -25,6 +25,7 @@
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from tests.accelerators import DDPLauncher
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@@ -115,6 +116,24 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
)


@RunIf(special=True, fairscale=True, min_gpus=2)
@DDPLauncher.run("--accelerator ddp --gpus 2 --precision 16")
def test_ddp_sharded_plugin_clip_gradients(tmpdir, args=None):
plugin_parity_test(
gpus=args.gpus,
precision=args.precision,
model_cls=SeedTrainLoaderModel,
gradient_clip_val=0.001,
)
plugin_parity_test(
gpus=args.gpus,
precision=args.precision,
model_cls=SeedTrainLoaderModel,
gradient_clip_val=0.001,
gradient_clip_algorithm='value',
)


@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
@@ -245,6 +264,8 @@ def plugin_parity_test(
gpus: int = 0,
precision: int = 32,
max_percent_speed_diff: float = 0.1,
gradient_clip_val: float = 0,
gradient_clip_algorithm: str = 'norm',
):
"""
Ensures that the trained model is identical to the standard DDP implementation.
@@ -271,6 +292,8 @@
gpus=gpus,
precision=precision,
accelerator='ddp_spawn',
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
)

max_memory_ddp, ddp_time = record_ddp_fit_model_stats(trainer=trainer, model=ddp_model, use_cuda=use_cuda)
@@ -285,6 +308,8 @@
gpus=gpus,
precision=precision,
accelerator='ddp_sharded_spawn',
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
)
assert isinstance(trainer.training_type_plugin, DDPSpawnShardedPlugin)

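For readers unfamiliar with the parity helper above: it trains the same seeded model with and without the sharded plugin and checks that the results match. A heavily simplified, single-process sketch of that kind of seeded-training parity check (toy model, no DDP or FairScale, thresholds and sizes are placeholders, not part of this PR):

```python
import torch
from torch import nn


def train_once(seed: int, clip_value: float) -> nn.Module:
    """Deterministically train a tiny model so two runs can be compared."""
    torch.manual_seed(seed)
    model = nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    data, target = torch.randn(32, 8), torch.randn(32, 1)
    for _ in range(5):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(data), target)
        loss.backward()
        # mirrors gradient_clip_algorithm='value' with gradient_clip_val=clip_value
        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        optimizer.step()
    return model


# Two identically seeded runs must end with identical parameters.
model_a = train_once(seed=42, clip_value=0.001)
model_b = train_once(seed=42, clip_value=0.001)
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(p_a, p_b)
```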
10 changes: 8 additions & 2 deletions docs/source/advanced/training_tricks.rst
@@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN.

Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient
norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
If the ``gradient_clip_algorithm`` option is set to ``value`` (the default is ``norm``), this will
`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -39,6 +41,10 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)

# clip gradients with value above 0.5
# gradient_clip_algorithm types => :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType`
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')
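For reference only (not part of this diff): the two algorithm choices correspond to PyTorch's built-in clipping utilities. A minimal sketch, with a toy model and an illustrative threshold of 0.5:

```python
import torch
from torch import nn

# Toy model and one backward pass, purely for illustration.
model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# gradient_clip_algorithm='norm' (the default): rescale all gradients together
# so that their combined 2-norm is at most 0.5.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

# gradient_clip_algorithm='value': clamp every gradient element into [-0.5, 0.5].
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```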

----------

Stochastic Weight Averaging
11 changes: 8 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
@@ -22,7 +22,7 @@
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

if TYPE_CHECKING:
from torch.cuda.amp import GradScaler
@@ -299,10 +299,15 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
model_ref = self.lightning_module
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""clips all the optimizer parameters to the given value"""

self.precision_plugin.clip_gradients(optimizer, clip_val)
self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm)

def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
"""Hook to do something on the end of an training epoch
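To make the control flow of this change easier to follow, here is a stripped-down sketch using toy classes (not the real Lightning internals): the accelerator clips nothing itself and simply forwards the request, including the chosen algorithm, to its precision plugin.

```python
from enum import Enum
from typing import Union


class GradClipAlgorithmType(Enum):
    # Mirrors the enum added in pytorch_lightning/utilities/enums.py below.
    VALUE = "value"
    NORM = "norm"


class ToyPrecisionPlugin:
    def clip_gradients(self, optimizer, clip_val, gradient_clip_algorithm) -> None:
        print(f"clipping by {gradient_clip_algorithm.value} with threshold {clip_val}")


class ToyAccelerator:
    def __init__(self, precision_plugin: ToyPrecisionPlugin) -> None:
        self.precision_plugin = precision_plugin

    def clip_gradients(
        self,
        optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        # Pure delegation, same shape as the method in the diff above.
        self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm)


ToyAccelerator(ToyPrecisionPlugin()).clip_gradients(
    optimizer=None, clip_val=0.5, gradient_clip_algorithm=GradClipAlgorithmType.VALUE
)
```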
pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -16,6 +16,7 @@
import torch

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

@@ -75,7 +76,12 @@ def backward(

return closure_loss

def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
def clip_gradients(
self,
optimizer: 'Optimizer',
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""
29 changes: 25 additions & 4 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -17,6 +17,7 @@
import torch

from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import GradClipAlgorithmType

if TYPE_CHECKING:
from torch.nn import Module
@@ -33,6 +34,13 @@ class PrecisionPlugin(Plugin):
EPSILON: float = 1e-6
precision: Union[str, int] = 32

def __init__(self) -> None:
super().__init__()
self.clip_grad_funcs = {
GradClipAlgorithmType.VALUE: self.clip_grad_by_value,
GradClipAlgorithmType.NORM: self.clip_grad_by_norm,
}

def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
"""The master params of the model. Returns the plain model params here.
Maybe different in other precision plugins.
@@ -98,19 +106,32 @@ def pre_optimizer_step(
def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None:
"""Hook to do something after each optimizer step."""

def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
"""Clips the gradients to a specific value"""
def clip_gradients(
self,
optimizer: 'Optimizer',
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""Clips the gradients"""
clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm]
clip_grad_func(optimizer, clip_val)

def clip_grad_by_value(self, optimizer: 'Optimizer', clip_val: Union[int, float]) -> None:
[Inline review thread on these new methods]
Contributor: Should we make them private?
ananthsub (Contributor, Mar 3, 2021): @tchaton what's expected when other precision plugins override them, as the sharded native amp one does?
Member: or protected :]
""" clip gradient by value """
parameters = list(self.master_params(optimizer))
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
""" clip gradient by norm """
# TODO: separate TPU case from here
if clip_val is None:
return

grad_clip_val = float(clip_val)

if grad_clip_val <= 0:
return

parameters = list(self.master_params(optimizer))

max_norm = grad_clip_val

if isinstance(parameters, torch.Tensor):
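The `clip_grad_funcs` dict built in `__init__` above is a small dispatch table from algorithm to clipping routine. A self-contained sketch of the same pattern with plain PyTorch in place of the plugin machinery (toy model and threshold are illustrative):

```python
from enum import Enum

import torch
from torch import nn


class GradClipAlgorithmType(Enum):
    # Mirrors the enum introduced by this PR.
    VALUE = "value"
    NORM = "norm"


def clip_by_value(parameters, clip_val):
    torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)


def clip_by_norm(parameters, clip_val):
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val)


# Algorithm -> function mapping, analogous to self.clip_grad_funcs above.
CLIP_GRAD_FUNCS = {
    GradClipAlgorithmType.VALUE: clip_by_value,
    GradClipAlgorithmType.NORM: clip_by_norm,
}

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
CLIP_GRAD_FUNCS[GradClipAlgorithmType.VALUE](model.parameters(), 0.5)
```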
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -32,6 +32,7 @@ def __init__(self) -> None:
super().__init__()
self.scaler = ShardedGradScaler()

def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
""" Overrided function. Clip gradients by norm. """
optimizer = cast(OSS, optimizer)
optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,6 +24,7 @@ def __init__(self, trainer):
def on_trainer_init(
self,
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
@@ -32,7 +34,12 @@ def on_trainer_init(
self.trainer.terminate_on_nan = terminate_on_nan

# gradient clipping
if gradient_clip_algorithm not in list(GradClipAlgorithmType):
raise MisconfigurationException(
f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
)
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = gradient_clip_algorithm

# gradient norm tracking
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
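The connector now rejects unknown algorithm names before training starts. A hedged sketch of that validation step, with a plain ValueError standing in for Lightning's MisconfigurationException and a simple str-valued Enum standing in for LightningEnum (assumed here to allow comparison with plain strings, which is what the membership check relies on):

```python
from enum import Enum


class GradClipAlgorithmType(str, Enum):
    # The str mixin makes members compare equal to plain strings such as 'norm'.
    VALUE = "value"
    NORM = "norm"


def validate_gradient_clip_algorithm(gradient_clip_algorithm: str) -> None:
    if gradient_clip_algorithm not in list(GradClipAlgorithmType):
        raise ValueError(
            f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
        )


validate_gradient_clip_algorithm("norm")   # accepted
validate_gradient_clip_algorithm("value")  # accepted
try:
    validate_gradient_clip_algorithm("clamp")
except ValueError as err:
    print(err)  # rejected before training starts
```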
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -91,6 +91,7 @@
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: float = 0,
gradient_clip_algorithm: str = 'norm',
process_position: int = 0,
num_nodes: int = 1,
num_processes: int = 1,
@@ -197,6 +198,8 @@

gradient_clip_val: 0 means don't clip.

gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'

limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)

limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
@@ -341,7 +344,12 @@

# init training tricks
self.training_tricks_connector.on_trainer_init(
gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan,
)
self.train_loop.on_trainer_init(
max_epochs,
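End to end, the new flag is used exactly like the existing `gradient_clip_val`. A minimal, illustrative usage sketch (the module, data, and the 0.5 threshold are placeholders, not taken from this PR):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)

# Clamp each gradient element to [-0.5, 0.5] instead of rescaling by total norm.
trainer = pl.Trainer(max_epochs=1, gradient_clip_val=0.5, gradient_clip_algorithm='value')
trainer.fit(TinyModule(), train_loader)
```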
8 changes: 7 additions & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -21,7 +21,13 @@
rank_zero_only,
rank_zero_warn,
)
from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
AMPType,
DeviceType,
DistributedType,
GradClipAlgorithmType,
LightningEnum,
)
from pytorch_lightning.utilities.imports import ( # noqa: F401
_APEX_AVAILABLE,
_BOLTS_AVAILABLE,
13 changes: 13 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -94,3 +94,16 @@ class DeviceType(LightningEnum):
CPU = 'CPU'
GPU = 'GPU'
TPU = 'TPU'


class GradClipAlgorithmType(LightningEnum):
""" Define gradient_clip_algorithm types - training-tricks.
NORM type means "clipping gradients by norm". This is computed over all model parameters together.
VALUE type means "clipping gradients by value". This will clip the gradient value for each parameter.

References:
clip_by_norm: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm
clip_by_value: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value
"""
VALUE = 'value'
NORM = 'norm'
13 changes: 13 additions & 0 deletions tests/models/test_horovod.py
@@ -16,6 +16,7 @@
import shlex
import subprocess
import sys
from copy import deepcopy

import numpy as np
import pytest
@@ -58,6 +59,13 @@ def _run_horovod(trainer_options, on_gpu=False):
assert exit_code == 0


def _run_horovod_clip_grad_by_value(trainer_options, on_gpu=False):
# clip_grad_by_value test
trainer_options_clip_grad_val = deepcopy(trainer_options)
trainer_options_clip_grad_val.update({'gradient_clip_algorithm': 'value'})
_run_horovod(trainer_options_clip_grad_val, on_gpu)


@RunIf(skip_windows=True)
def test_horovod_cpu(tmpdir):
"""Test Horovod running multi-process on CPU."""
@@ -73,6 +81,7 @@ def test_horovod_cpu(tmpdir):
deterministic=True,
)
_run_horovod(trainer_options)
_run_horovod_clip_grad_by_value(trainer_options)


@RunIf(skip_windows=True)
@@ -89,6 +98,7 @@ def test_horovod_cpu_implicit(tmpdir):
deterministic=True,
)
_run_horovod(trainer_options)
_run_horovod_clip_grad_by_value(trainer_options)


@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
@@ -107,6 +117,7 @@ def test_horovod_multi_gpu(tmpdir):
accelerator='horovod',
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod_clip_grad_by_value(trainer_options, on_gpu=True)


@pytest.mark.skip(reason="Horovod has a problem with broadcast when using apex?") # todo
@@ -128,6 +139,7 @@ def test_horovod_apex(tmpdir):
precision=16,
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod_clip_grad_by_value(trainer_options, on_gpu=True)


@pytest.mark.skip(reason="Skip till Horovod fixes integration with Native torch.cuda.amp") # todo
@@ -149,6 +161,7 @@ def test_horovod_amp(tmpdir):
precision=16,
)
_run_horovod(trainer_options, on_gpu=True)
_run_horovod_clip_grad_by_value(trainer_options, on_gpu=True)


@RunIf(min_gpus=1, skip_windows=True, horovod_nccl=True)
20 changes: 20 additions & 0 deletions tests/models/test_tpu.py
@@ -219,6 +219,26 @@ def test_tpu_grad_norm(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_clip_grad_by_value(tmpdir):
"""Test if clip_gradients by value works on TPU."""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=1,
limit_train_batches=4,
limit_val_batches=4,
gradient_clip_val=0.5,
gradient_clip_algorithm='value'
)

model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
@pl_multi_process_test
def test_dataloaders_passed_to_fit(tmpdir):