Add Trainer(gradient_clip_algorithm='value'|'norm') #6123

Merged (59 commits, Apr 6, 2021)

Commits
09ea112
add changelog
dhkim0225 Feb 22, 2021
c0e8064
add clip by value
dhkim0225 Feb 22, 2021
ca8e6fd
fix bug in training tricks.rst
dhkim0225 Feb 22, 2021
87f12c1
fix bug in trainer.rst
dhkim0225 Feb 22, 2021
8e43b8a
Update trainer.rst
dhkim0225 Feb 22, 2021
2bb9924
Update trainer.rst
dhkim0225 Feb 22, 2021
5b83f0d
Update CHANGELOG.md
dhkim0225 Feb 23, 2021
caafdf2
Update pytorch_lightning/plugins/precision/deepspeed_precision.py
dhkim0225 Feb 23, 2021
ca774b6
Update pytorch_lightning/utilities/enums.py
dhkim0225 Feb 23, 2021
5a741e2
yapf formatting
dhkim0225 Feb 23, 2021
0568e3a
update training tricks
dhkim0225 Feb 23, 2021
1a4e79e
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Feb 26, 2021
4a813c1
Merge branch 'master' into feat/clip_grad_by_value
tchaton Feb 26, 2021
2f5cb3e
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 2, 2021
f4275a2
update based on comment
dhkim0225 Mar 2, 2021
e92ec69
update based on comment
dhkim0225 Mar 2, 2021
ac701ce
Update pytorch_lightning/trainer/trainer.py
dhkim0225 Mar 2, 2021
bc20fa4
update based on comment
dhkim0225 Mar 2, 2021
b842210
Merge branch 'feat/clip_grad_by_value' of https://github.com/dhkim022…
dhkim0225 Mar 2, 2021
5ec2ebd
pep8
dhkim0225 Mar 2, 2021
d37fbbc
mypy
dhkim0225 Mar 2, 2021
952c778
mypy
dhkim0225 Mar 2, 2021
b8fdbe1
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 2, 2021
c4cccf0
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 3, 2021
6bd4793
Update docs/source/advanced/training_tricks.rst
dhkim0225 Mar 4, 2021
3aeba85
Update sharded_native_amp.py
dhkim0225 Mar 4, 2021
902a33c
Update test_sharded_parity.py
dhkim0225 Mar 4, 2021
7467616
update test codes
dhkim0225 Mar 4, 2021
5463830
Update test_tpu.py
dhkim0225 Mar 4, 2021
2e933d4
Update pytorch_lightning/trainer/connectors/training_trick_connector.py
dhkim0225 Mar 4, 2021
b1e26e6
Update test_trainer.py
dhkim0225 Mar 4, 2021
cedf5f6
Update enums.py
dhkim0225 Mar 4, 2021
f5bb45d
Update enums.py
dhkim0225 Mar 4, 2021
42fc5f6
Merge branch 'master' into feat/clip_grad_by_value
Borda Mar 4, 2021
e55b90c
Merge branch 'master' into feat/clip_grad_by_value
dhkim0225 Mar 5, 2021
308ce38
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 23, 2021
903f2e2
add super-class initialization to precision plugins.
dhkim0225 Mar 25, 2021
28c948a
add clip_grad horovod cpu test
dhkim0225 Mar 25, 2021
177a1c9
add clip_grad horovod cpu test
dhkim0225 Mar 25, 2021
fc23845
use subprocess check_call
dhkim0225 Mar 25, 2021
d99a650
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
dhkim0225 Mar 25, 2021
f80aa8d
change order of horovod tests
dhkim0225 Mar 25, 2021
fb895b6
set max_epochs 2 in horovod test
dhkim0225 Mar 25, 2021
caa0bbf
remove clip_grad_val test from horovod-cpu
dhkim0225 Mar 25, 2021
f1f9015
remove "type: ignore"
dhkim0225 Mar 25, 2021
5dfe5ef
divide clip grad val test in horovod
dhkim0225 Mar 25, 2021
50a6c74
update based on comments
dhkim0225 Mar 25, 2021
c337b12
add super-class initialization to precision plugins.
dhkim0225 Mar 25, 2021
f7a4fda
bugfix
dhkim0225 Mar 25, 2021
48c3dd8
bugfix
dhkim0225 Mar 25, 2021
e7e3b47
revert some changes
dhkim0225 Mar 26, 2021
2997536
revert some changes
dhkim0225 Mar 26, 2021
fb34e84
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
dhkim0225 Mar 26, 2021
8e665ec
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 27, 2021
9575774
Update tests/models/test_horovod.py
carmocca Mar 27, 2021
7c16f6a
Merge branch 'master' into feat/clip_grad_by_value
carmocca Mar 29, 2021
fec189a
merge master
dhkim0225 Apr 6, 2021
1e80304
merge master
dhkim0225 Apr 6, 2021
4d5e05f
Delete signature test
carmocca Apr 6, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))


- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)).


- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))


10 changes: 8 additions & 2 deletions docs/source/advanced/training_tricks.rst
@@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN.

Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient
norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
If the ``gradient_clip_algorithm`` option (default: ``norm``) is set to ``value``, this will
`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -39,6 +41,10 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)

# clip gradients with value above 0.5
# gradient_clip_algorithm types => :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType`
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')

----------

Stochastic Weight Averaging
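
As a supplement to the gradient-clipping docs above (not part of the PR diff), here is a minimal standalone sketch of what the two algorithms do at the PyTorch level, using only the two torch.nn.utils functions the docs link to; the toy gradient values are illustrative only.

import torch

# One toy parameter with a deliberately large gradient.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([3.0, -4.0, 0.1])

# gradient_clip_algorithm='norm' (the default): rescale the whole gradient
# vector so that its 2-norm is at most the threshold.
torch.nn.utils.clip_grad_norm_([p], max_norm=0.5)
print(p.grad.norm())  # ~0.5, direction preserved

# gradient_clip_algorithm='value': clamp each gradient element into [-0.5, 0.5].
p.grad = torch.tensor([3.0, -4.0, 0.1])
torch.nn.utils.clip_grad_value_([p], clip_value=0.5)
print(p.grad)  # tensor([ 0.5000, -0.5000,  0.1000])
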
12 changes: 8 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -24,7 +24,7 @@
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum

if TYPE_CHECKING:
from torch.cuda.amp import GradScaler
@@ -315,10 +315,14 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt
model_ref = self.lightning_module
model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[int, float],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""clips all the optimizer parameters to the given value"""

self.precision_plugin.clip_gradients(self.model, optimizer, clip_val)
self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)

def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None:
"""Hook to do something on the end of an training epoch
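
For orientation, the accelerator does not clip anything itself; it forwards the new argument to the precision plugin. A schematic stub of that call chain (class names and bodies below are simplified placeholders, not the actual Lightning classes):

class PrecisionPluginStub:
    def clip_gradients(self, model, optimizer, clip_val, gradient_clip_algorithm='norm'):
        print(f"clipping with algorithm={gradient_clip_algorithm}, clip_val={clip_val}")


class AcceleratorStub:
    def __init__(self):
        self.model = object()
        self.precision_plugin = PrecisionPluginStub()

    def clip_gradients(self, optimizer, clip_val, gradient_clip_algorithm='norm'):
        # Pure delegation: the precision plugin owns the actual clipping logic.
        self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm)


AcceleratorStub().clip_gradients(optimizer=None, clip_val=0.5, gradient_clip_algorithm='value')
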
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -30,6 +30,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin):
"""Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)"""

def __init__(self, amp_level: str = "O2") -> None:
super().__init__()
self.backend = AMPType.APEX
self.amp_level = amp_level

pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -16,6 +16,7 @@
import torch

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

@@ -80,7 +81,7 @@ def clip_gradients(
model: 'LightningModule',
optimizer: 'Optimizer',
clip_val: Union[int, float],
norm_type: float = 2.0
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""
DeepSpeed handles clipping gradients via the training type plugin.
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/double.py
@@ -67,6 +67,7 @@ class DoublePrecisionPlugin(PrecisionPlugin):
precision: int = 64

def __init__(self) -> None:
super().__init__()
self.patches: List[_DoublePrecisionPatch] = []

def connect(
1 change: 1 addition & 0 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -30,6 +30,7 @@
class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):

def __init__(self) -> None:
super().__init__()
if not _NATIVE_AMP_AVAILABLE:
raise MisconfigurationException(
"You have asked for native AMP but your PyTorch version does not support it."
29 changes: 23 additions & 6 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -17,6 +17,7 @@
import torch

from pytorch_lightning.plugins.base_plugin import Plugin
from pytorch_lightning.utilities import GradClipAlgorithmType

if TYPE_CHECKING:
from torch.nn import Module
@@ -33,6 +34,13 @@ class PrecisionPlugin(Plugin):
EPSILON: float = 1e-6
precision: Union[str, int] = 32

def __init__(self) -> None:
super().__init__()
self.clip_grad_funcs = {
GradClipAlgorithmType.VALUE: self.clip_grad_by_value,
GradClipAlgorithmType.NORM: self.clip_grad_by_norm,
}

def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]:
"""The master params of the model. Returns the plain model params here.
Maybe different in other precision plugins.
@@ -103,20 +111,29 @@ def clip_gradients(
model: 'LightningModule',
optimizer: 'Optimizer',
clip_val: Union[int, float],
norm_type: float = 2.0
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
"""Clips the gradients to a specific value"""
"""Clips the gradients"""
if clip_val is None:
return

grad_clip_val = float(clip_val)

if grad_clip_val <= 0:
clip_val = float(clip_val)
if clip_val <= 0:
return

clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm]
clip_grad_func(optimizer, clip_val) # type: ignore

def clip_grad_by_value(self, optimizer: 'Optimizer', clip_val: Union[int, float]) -> None:
[Review thread]
Contributor: Should we make them private ?
ananthsub (Mar 3, 2021): @tchaton what's expected when other precision plugins override them, as the sharded native amp one does?
Member: or protected :]
"""Clip gradients by value"""
parameters = list(self.master_params(optimizer))
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

max_norm = grad_clip_val
def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
"""Clip gradients by norm"""
# TODO: separate TPU case from here
parameters = list(self.master_params(optimizer))
max_norm = clip_val

if isinstance(parameters, torch.Tensor):
parameters = [parameters]
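
To make the control flow above easier to follow, here is a condensed, self-contained sketch of the dispatch this hunk adds: an enum-keyed dict built in __init__ that routes clip_gradients to the matching helper. The class below is a simplified stand-in (it takes a parameter list directly rather than going through master_params and the optimizer), not the actual Lightning plugin.

from enum import Enum
from typing import Iterable, Optional, Union

import torch


class GradClipAlgorithmType(str, Enum):
    VALUE = 'value'
    NORM = 'norm'


class PrecisionPluginSketch:
    def __init__(self) -> None:
        # Map each algorithm to the bound method that implements it.
        self.clip_grad_funcs = {
            GradClipAlgorithmType.VALUE: self.clip_grad_by_value,
            GradClipAlgorithmType.NORM: self.clip_grad_by_norm,
        }

    def clip_gradients(
        self,
        parameters: Iterable[torch.nn.Parameter],
        clip_val: Optional[Union[int, float]],
        algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        if clip_val is None or float(clip_val) <= 0:
            return  # clipping disabled
        self.clip_grad_funcs[algorithm](parameters, float(clip_val))

    def clip_grad_by_value(self, parameters, clip_val: float) -> None:
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

    def clip_grad_by_norm(self, parameters, clip_val: float, norm_type: float = 2.0) -> None:
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val, norm_type=norm_type)

A call would then look like PrecisionPluginSketch().clip_gradients(model.parameters(), 0.5, GradClipAlgorithmType.VALUE).
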
8 changes: 1 addition & 7 deletions pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -23,8 +23,6 @@
if TYPE_CHECKING:
from torch.optim import Optimizer

from pytorch_lightning.core import LightningModule


class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Mixed Precision for Sharded Training
@@ -34,15 +32,11 @@ def __init__(self) -> None:
super().__init__()
self.scaler = ShardedGradScaler()

def clip_gradients(
def clip_grad_by_norm(
self,
model: 'LightningModule',
optimizer: 'Optimizer',
clip_val: Union[int, float],
norm_type: float = 2.0
) -> None:
if clip_val <= 0:
return

optimizer = cast(OSS, optimizer)
optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -23,6 +24,7 @@ def __init__(self, trainer):
def on_trainer_init(
self,
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
@@ -32,7 +34,12 @@ def on_trainer_init(
self.trainer.terminate_on_nan = terminate_on_nan

# gradient clipping
if gradient_clip_algorithm not in list(GradClipAlgorithmType):
raise MisconfigurationException(
f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}"
)
self.trainer.gradient_clip_val = gradient_clip_val
self.trainer.gradient_clip_algorithm = gradient_clip_algorithm

# gradient norm tracking
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -91,6 +91,7 @@ def __init__(
callbacks: Optional[Union[List[Callback], Callback]] = None,
default_root_dir: Optional[str] = None,
gradient_clip_val: float = 0,
gradient_clip_algorithm: str = 'norm',
process_position: int = 0,
num_nodes: int = 1,
num_processes: int = 1,
@@ -197,6 +198,8 @@ def __init__(

gradient_clip_val: 0 means don't clip.

gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'

limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches)

limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches)
@@ -347,7 +350,12 @@ def __init__(

# init training tricks
self.training_tricks_connector.on_trainer_init(
gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan
gradient_clip_val,
gradient_clip_algorithm,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan,
)
self.train_loop.on_trainer_init(
max_epochs,
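
From the user's side, the new option is one more keyword argument on the Trainer constructor. A hypothetical minimal usage sketch (the thresholds are placeholders, and the rejection of unknown names relies on the connector validation shown above):

from pytorch_lightning import Trainer

# Clamp each gradient element into [-0.5, 0.5] instead of rescaling by norm.
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')

# Default behaviour is unchanged: clip by the global 2-norm.
trainer = Trainer(gradient_clip_val=0.5)  # equivalent to gradient_clip_algorithm='norm'

# Unknown algorithm names are rejected when the Trainer is constructed.
try:
    Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='vallue')
except Exception as err:  # MisconfigurationException from the training-tricks connector
    print(err)
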
8 changes: 7 additions & 1 deletion pytorch_lightning/utilities/__init__.py
@@ -23,7 +23,13 @@
rank_zero_only,
rank_zero_warn,
)
from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum # noqa: F401
from pytorch_lightning.utilities.enums import ( # noqa: F401
AMPType,
DeviceType,
DistributedType,
GradClipAlgorithmType,
LightningEnum,
)
from pytorch_lightning.utilities.imports import ( # noqa: F401
_APEX_AVAILABLE,
_BOLTS_AVAILABLE,
13 changes: 13 additions & 0 deletions pytorch_lightning/utilities/enums.py
@@ -94,3 +94,16 @@ class DeviceType(LightningEnum):
CPU = 'CPU'
GPU = 'GPU'
TPU = 'TPU'


class GradClipAlgorithmType(LightningEnum):
""" Define gradient_clip_algorithm types - training-tricks.
NORM type means "clipping gradients by norm". This computed over all model parameters together.
VALUE tpye means "clipping gradients by value". This will clip the gradient value for each parameter.

References:
clip_by_norm: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm
clip_by_value: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value
"""
VALUE = 'value'
NORM = 'norm'
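
GradClipAlgorithmType extends LightningEnum, which is string-backed, so bare strings such as 'value' compare equal to the members; that is what the `gradient_clip_algorithm not in list(GradClipAlgorithmType)` check in the training-tricks connector relies on. A rough standalone sketch of that behaviour, using a plain str/Enum mixin rather than LightningEnum itself (case-handling details may differ):

from enum import Enum


class GradClipAlgorithmType(str, Enum):
    """Stand-in for the Lightning enum: members compare equal to their string values."""
    VALUE = 'value'
    NORM = 'norm'


assert GradClipAlgorithmType.NORM == 'norm'
# Same shape as the connector's validation check.
assert 'value' in list(GradClipAlgorithmType)
assert 'vallue' not in list(GradClipAlgorithmType)
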
39 changes: 38 additions & 1 deletion tests/models/test_horovod.py
@@ -50,7 +50,7 @@ def _run_horovod(trainer_options, on_gpu=False):
trainer_options.update(gpus=1 if on_gpu else None)
tutils.reset_seed()
# todo: Find why coverage breaks CI.
# append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265
# append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else ''
# str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265
cmdline = [
'horovodrun', '-np',
@@ -80,6 +80,24 @@ def test_horovod_cpu(tmpdir):
_run_horovod(trainer_options)


@RunIf(skip_windows=True, horovod=True)
def test_horovod_cpu_clip_grad_by_value(tmpdir):
"""Test Horovod running multi-process on CPU."""
trainer_options = dict(
default_root_dir=str(tmpdir),
weights_save_path=str(tmpdir),
gradient_clip_val=1.0,
gradient_clip_algorithm='value',
progress_bar_refresh_rate=0,
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
accelerator='horovod',
deterministic=True,
)
_run_horovod(trainer_options)


@RunIf(skip_windows=True, horovod=True)
def test_horovod_cpu_implicit(tmpdir):
"""Test Horovod without specifying a backend, inferring from env set by `horovodrun`."""
@@ -114,6 +132,25 @@ def test_horovod_multi_gpu(tmpdir):
_run_horovod(trainer_options, on_gpu=True)


@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True)
def test_horovod_multi_gpu_grad_by_value(tmpdir):
"""Test Horovod with multi-GPU support."""
trainer_options = dict(
default_root_dir=str(tmpdir),
weights_save_path=str(tmpdir),
gradient_clip_val=1.0,
gradient_clip_algorithm='value',
progress_bar_refresh_rate=0,
max_epochs=1,
limit_train_batches=0.4,
limit_val_batches=0.2,
gpus=2,
deterministic=True,
accelerator='horovod',
)
_run_horovod(trainer_options, on_gpu=True)


# https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994
# Check with (tgaddair) on Horovod issues if this feature is needed
@pytest.mark.skip(reason="Horovod currently doesn't work with Apex") # todo
20 changes: 20 additions & 0 deletions tests/models/test_tpu.py
@@ -219,6 +219,26 @@ def test_tpu_grad_norm(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_clip_grad_by_value(tmpdir):
"""Test if clip_gradients by value works on TPU."""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=4,
tpu_cores=1,
limit_train_batches=4,
limit_val_batches=4,
gradient_clip_val=0.5,
gradient_clip_algorithm='value'
)

model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
@pl_multi_process_test
def test_dataloaders_passed_to_fit(tmpdir):
26 changes: 0 additions & 26 deletions tests/plugins/test_precision_plugin.py

This file was deleted.
