Commit

Use torch.nn.utils.clip_grad_norm_ and add clip_grad_by_value support for TPU (#7025)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
3 people authored May 7, 2021
1 parent 9ba76ce commit 8208c33
Showing 9 changed files with 34 additions and 65 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -10,9 +10,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


### Changed


- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


### Deprecated


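To illustrate the two changelog entries above (not part of the diff): a minimal sketch of enabling the new behaviour from the Trainer, assuming a TPU runtime and the Lightning 1.x Trainer flags.

```python
# Illustrative only: assumes a TPU runtime and the Lightning 1.x Trainer API.
from pytorch_lightning import Trainer

# norm-based clipping (the default algorithm) now delegates to torch.nn.utils.clip_grad_norm_
trainer = Trainer(tpu_cores=8, gradient_clip_val=0.5)

# value-based clipping is newly supported on TPUs by this PR
trainer = Trainer(tpu_cores=8, gradient_clip_val=0.5, gradient_clip_algorithm="value")
```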
19 changes: 11 additions & 8 deletions docs/source/advanced/training_tricks.rst
@@ -26,10 +26,14 @@ The effect is a large effective batch size of size KxN.

Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
If ``gradient_clip_algorithm`` option is set to ``value``, which is ``norm`` by default, this will
`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.
Gradient clipping may be enabled to avoid exploding gradients. By default, this will clip the gradient norm
computed over all model parameters together by calling :func:`torch.nn.utils.clip_grad_norm_`.
If the Trainer's ``gradient_clip_algorithm`` is set to ``'value'`` (``'norm'`` by default), this will instead use
:func:`torch.nn.utils.clip_grad_value_` to clip each parameter's gradient.

.. note::
If using mixed precision, the ``gradient_clip_val`` does not need to be changed as the gradients are unscaled
before applying the clipping function.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -38,11 +42,10 @@ If ``gradient_clip_algorithm`` option is set to ``value``, which is ``norm`` by
# DEFAULT (ie: don't clip)
trainer = Trainer(gradient_clip_val=0)

# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)
# clip gradients' global norm to <=0.5
trainer = Trainer(gradient_clip_val=0.5) # gradient_clip_algorithm='norm' by default

# clip gradients with value above 0.5
# gradient_clip_algorithm types => :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType`
# clip gradients' maximum magnitude to <=0.5
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')
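For reference, a small standalone sketch (plain PyTorch, outside the Trainer) of what the two algorithms do to the gradients; the 0.5 threshold mirrors the examples above.

```python
# Standalone illustration of the two clipping modes referenced above.
import torch

model = torch.nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()
params = [p for p in model.parameters() if p.grad is not None]

# 'norm': rescale all gradients together so their global 2-norm is at most 0.5;
# returns the total norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=0.5)

# 'value': clamp every gradient element into [-0.5, 0.5] independently.
torch.nn.utils.clip_grad_value_(params, clip_value=0.5)
```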

----------
32 changes: 7 additions & 25 deletions pytorch_lightning/accelerators/tpu.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Callable, Union
from typing import Any, Callable

from torch.optim import Optimizer

@@ -21,15 +21,16 @@
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_5, _XLA_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
from torch_xla._patched_functions import clip_grad_norm_

# rename to mock in a test
xla_clip_grad_norm_ = clip_grad_norm_
# the patch is not required after 1.5.0
if _TORCH_GREATER_EQUAL_1_5:
from torch_xla._patched_functions import _apply_patches
_apply_patches() # patches `torch.nn.utils.clip_grad_norm_`


class TPUAccelerator(Accelerator):
@@ -43,8 +44,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException(
"amp + tpu is not supported. "
"Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
"amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
)

if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
@@ -59,21 +59,3 @@ def run_optimizer_step(
self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
) -> None:
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})

def clip_gradients(
self,
optimizer: Optimizer,
clip_val: Union[float, int],
gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
) -> None:
assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \
"Only NORM gradient clipping is supported on TPU for now"

grad_clip_val = float(clip_val)
if grad_clip_val <= 0:
return

parameters = self.model.parameters()
norm_type = 2.0

xla_clip_grad_norm_(parameters, grad_clip_val, norm_type)
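A hedged sketch of the guard added above: on torch >= 1.5, torch_xla's `_apply_patches()` replaces `torch.nn.utils.clip_grad_norm_` with an XLA-aware version, which is why the accelerator-level `clip_gradients` override removed in this hunk is no longer needed and clipping can go through the shared precision-plugin path shown further down. The layout below is simplified, not the module's exact structure.

```python
# Simplified sketch of the version-gated patching; not the exact module layout.
import torch

try:
    import torch_xla.core.xla_model as xm  # noqa: F401
    from torch_xla._patched_functions import _apply_patches
    _XLA_AVAILABLE = True
except ImportError:
    _XLA_AVAILABLE = False

# rough stand-in for the _TORCH_GREATER_EQUAL_1_5 flag added in utilities/imports.py
_TORCH_GREATER_EQUAL_1_5 = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) >= (1, 5)

if _XLA_AVAILABLE and _TORCH_GREATER_EQUAL_1_5:
    # patches torch.nn.utils.clip_grad_norm_ so the stock call works with XLA tensors
    _apply_patches()
```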
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/precision/mixed.py
@@ -22,6 +22,5 @@
class MixedPrecisionPlugin(PrecisionPlugin):
"""Base Class for mixed precision"""

EPSILON: float = 1e-5
backend: 'AMPType'
precision: Union[str, int] = "mixed"
30 changes: 5 additions & 25 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
@@ -28,10 +27,9 @@
class PrecisionPlugin(Plugin):
"""
Base class for all plugins handling the precision-specific parts of the training.
The static classattributes EPSILON and precision must be overwritten in child-classes and their
default values reflect fp32 training.
The class attribute precision must be overwritten in child classes.
The default value reflects fp32 training.
"""
EPSILON: float = 1e-6
precision: Union[str, int] = 32

def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
@@ -118,32 +116,14 @@ def clip_gradients(
self.clip_grad_by_value(optimizer, clip_val)
elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
# TODO: there should be a mechanism to set `norm_type`
self.clip_grad_by_norm(optimizer, clip_val, eps=self.EPSILON)
self.clip_grad_by_norm(optimizer, clip_val)

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by value"""
parameters = self.master_params(optimizer)
torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

def clip_grad_by_norm(
self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6
) -> None:
def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
"""Clip gradients by norm"""
parameters = self.master_params(optimizer)

# TODO: replace this with torch.nn.clip_grad_norm_
parameters = list(filter(lambda p: p.grad is not None, parameters))
device = parameters[0].device

if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
out = torch.empty(len(parameters), device=device)
for i, p in enumerate(parameters):
torch.norm(p.grad.data.to(device), norm_type, out=out[i])
total_norm = torch.norm(out, norm_type)

clip_coef = torch.tensor(clip_val, device=device) / (total_norm + eps)
clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
for p in parameters:
p.grad.data.mul_(clip_coef.to(p.grad.data.device))
torch.nn.utils.clip_grad_norm_(parameters, clip_val)
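A standalone sanity check on the simplification above (hedged, not the plugin API): for the default 2-norm, `torch.nn.utils.clip_grad_norm_` computes the same global norm the deleted hand-rolled loop did and clips in place, so delegating to it preserves the behaviour.

```python
# Standalone check that clip_grad_norm_ matches the removed manual implementation
# for the default 2-norm case.
import torch

model = torch.nn.Linear(3, 3)
model(torch.randn(5, 3)).sum().backward()
params = [p for p in model.parameters() if p.grad is not None]

# global 2-norm over all gradients, as the deleted loop computed it
manual_norm = torch.norm(torch.stack([p.grad.detach().norm(2) for p in params]), 2)

# clip_grad_norm_ clips in place and returns the pre-clipping total norm
returned_norm = torch.nn.utils.clip_grad_norm_(params, max_norm=0.5, norm_type=2.0)

assert torch.allclose(manual_norm, returned_norm)
```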
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/training_tricks.py
@@ -22,8 +22,6 @@
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

EPSILON = 1e-6
EPSILON_FP16 = 1e-5
log = logging.getLogger(__name__)


1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -47,6 +47,7 @@
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_RPC_AVAILABLE,
_TORCH_GREATER_EQUAL_1_5,
_TORCH_GREATER_EQUAL_1_6,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -64,6 +64,7 @@ def _compare_version(package: str, op, version) -> bool:
_IS_WINDOWS = platform.system() == "Windows"
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
_TORCH_GREATER_EQUAL_1_5 = _compare_version("torch", operator.ge, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
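For context, a rough illustration of what such a version flag boils down to; `_compare_version_sketch` is a hypothetical stand-in, and the real `_compare_version` helper also guards against missing or oddly versioned packages.

```python
# Illustrative reimplementation of the version-flag pattern; not the real helper.
import operator

import torch
from packaging.version import Version


def _compare_version_sketch(pkg_version: str, op, version: str) -> bool:
    return op(Version(pkg_version), Version(version))


_TORCH_GREATER_EQUAL_1_5 = _compare_version_sketch(torch.__version__, operator.ge, "1.5.0")
```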
7 changes: 3 additions & 4 deletions tests/models/test_tpu.py
@@ -223,7 +223,7 @@ def test_tpu_grad_norm(tmpdir):
@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_clip_grad_by_value(tmpdir):
"""Test if clip_gradients by value works on TPU. (It should not.)"""
"""Test if clip_gradients by value works on TPU"""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
@@ -237,8 +237,7 @@ def test_tpu_clip_grad_by_value(tmpdir):
)

model = BoringModel()
with pytest.raises(AssertionError):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
@@ -382,7 +381,7 @@ def test_reduce(rank):
@RunIf(tpu=True)
@pl_multi_process_test
@pytest.mark.parametrize("clip_val", [10])
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
@mock.patch("torch.nn.utils.clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
"""
Ensure that clip gradients is only called if the value is greater than 0.
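Since the test now patches the global `torch.nn.utils.clip_grad_norm_` instead of the old module-level alias, here is a simplified sketch of that mocking pattern with a toy training step standing in for the Lightning run (hypothetical names, not the actual test harness):

```python
# Simplified mocking sketch; the toy step below stands in for a Lightning run.
from unittest import mock

import torch


def run_one_clipping_step(clip_val: float) -> None:
    """Toy training step that clips gradients only when clip_val > 0."""
    model = torch.nn.Linear(2, 2)
    model(torch.randn(4, 2)).sum().backward()
    if clip_val > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)


@mock.patch("torch.nn.utils.clip_grad_norm_")
def test_clip_grad_norm_called(mock_clip_grad_norm):
    run_one_clipping_step(clip_val=10.0)
    mock_clip_grad_norm.assert_called_once()


@mock.patch("torch.nn.utils.clip_grad_norm_")
def test_clip_grad_norm_not_called(mock_clip_grad_norm):
    run_one_clipping_step(clip_val=0.0)
    mock_clip_grad_norm.assert_not_called()
```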