
Use torch.nn.utils.clip_grad_norm_ and add clip_grad_by_value support for TPU #7025

Merged
merged 19 commits on May 7, 2021
Changes from 12 commits
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -105,6 +105,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))


- Added `clip_grad_by_value` support for TPUs ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


### Changed

@@ -135,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- `pl.seed_everything` will now also set the seed on the `DistributedSampler` ([#7024](https://github.com/PyTorchLightning/pytorch-lightning/pull/7024))


- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


### Deprecated

- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
19 changes: 11 additions & 8 deletions docs/source/advanced/training_tricks.rst
@@ -26,10 +26,14 @@ The effect is a large effective batch size of size KxN.

Gradient Clipping
-----------------
Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm
<https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_ computed over all model parameters together.
If ``gradient_clip_algorithm`` option is set to ``value``, which is ``norm`` by default, this will
`clip the gradient value <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value_>`_ for each parameter instead.
Gradient clipping may be enabled to avoid exploding gradients. By default, this will clip the gradient norm by calling
:func:`torch.nn.utils.clip_grad_norm_` computed over all model parameters together.
If the Trainer's ``gradient_clip_algorithm`` is set to ``'value'`` (``'norm'`` by default), this will use
:func:`torch.nn.utils.clip_grad_value_` to clip each parameter's gradient instead.

.. note::
    If using mixed precision, the ``gradient_clip_val`` does not need to be changed as the gradients are unscaled
    before applying the clipping function.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -38,11 +42,10 @@ If ``gradient_clip_algorithm`` option is set to ``value``, which is ``norm`` by
# DEFAULT (ie: don't clip)
trainer = Trainer(gradient_clip_val=0)

# clip gradients with norm above 0.5
trainer = Trainer(gradient_clip_val=0.5)
# clip gradients' global norm to <=0.5
trainer = Trainer(gradient_clip_val=0.5) # gradient_clip_algorithm='norm' by default

# clip gradients with value above 0.5
# gradient_clip_algorithm types => :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType`
# clip gradients' maximum magnitude to <=0.5
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')

----------
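As a side note for readers comparing the two modes, here is a minimal standalone sketch (not part of this diff) that calls the two torch utilities the updated docs reference; the toy ``nn.Linear`` model and the threshold of 0.5 are arbitrary choices for illustration.

import torch
import torch.nn as nn


def make_grads() -> nn.Linear:
    # Build a toy model and populate .grad on every parameter.
    model = nn.Linear(4, 2)
    model(torch.randn(8, 4)).pow(2).sum().backward()
    return model


# gradient_clip_algorithm='norm' (the default): rescale all gradients together
# so that their global 2-norm is at most 0.5.
model = make_grads()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

# gradient_clip_algorithm='value': clamp every gradient element into [-0.5, 0.5].
model = make_grads()
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

``clip_grad_norm_`` returns the total norm computed before clipping, which is handy for logging; ``clip_grad_value_`` returns nothing and simply clamps each gradient element in place.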
32 changes: 7 additions & 25 deletions pytorch_lightning/accelerators/tpu.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Union
from typing import Any, Callable

from torch.optim import Optimizer

@@ -20,15 +20,16 @@
from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.utilities import _XLA_AVAILABLE, GradClipAlgorithmType
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_5, _XLA_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
    from torch_xla._patched_functions import clip_grad_norm_

    # rename to mock in a test
    xla_clip_grad_norm_ = clip_grad_norm_
    # clip_grad_norm_ was updated to not require this patch in 1.5.0
    if _TORCH_GREATER_EQUAL_1_5:
        from torch_xla._patched_functions import _apply_patches
        _apply_patches()  # patches torch.nn.utils.clip_grad_norm_


class TPUAccelerator(Accelerator):
@@ -42,8 +43,7 @@ def setup(self, trainer: 'pl.Trainer', model: 'pl.LightningModule') -> None:
"""
if isinstance(self.precision_plugin, MixedPrecisionPlugin):
raise MisconfigurationException(
"amp + tpu is not supported. "
"Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
"amp + tpu is not supported. Only bfloats are supported on TPU. Consider using TPUHalfPrecisionPlugin"
)

if not isinstance(self.training_type_plugin, (SingleTPUPlugin, TPUSpawnPlugin)):
@@ -54,21 +54,3 @@ def run_optimizer_step(
        self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    ) -> None:
        xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs})

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[float, int],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        assert gradient_clip_algorithm == GradClipAlgorithmType.NORM, \
            "Only NORM gradient clipping is supported on TPU for now"

        grad_clip_val = float(clip_val)
        if grad_clip_val <= 0:
            return

        parameters = self.model.parameters()
        norm_type = 2.0

        xla_clip_grad_norm_(parameters, grad_clip_val, norm_type)
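With the TPU-specific `clip_gradients` override removed, clipping goes through the shared precision-plugin path for every accelerator. The sketch below mirrors that delegation with simplified stand-in classes; the names `PrecisionPluginSketch` and `AcceleratorSketch` are hypothetical and the bodies are illustrative, not the actual Lightning implementation.

from enum import Enum
from typing import Union

import torch
from torch.optim import Optimizer


class GradClipAlgorithmType(Enum):
    VALUE = "value"
    NORM = "norm"


class PrecisionPluginSketch:
    def master_params(self, optimizer: Optimizer):
        # Plain fp32 case: the optimizer already holds the master parameters.
        return (p for group in optimizer.param_groups for p in group["params"])

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        if clip_val <= 0:
            return
        parameters = self.master_params(optimizer)
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
        else:
            torch.nn.utils.clip_grad_norm_(parameters, clip_val)


class AcceleratorSketch:
    """Without a TPU-specific override, every accelerator delegates clipping here."""

    def __init__(self, precision_plugin: PrecisionPluginSketch) -> None:
        self.precision_plugin = precision_plugin

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        self.precision_plugin.clip_gradients(optimizer, clip_val)


# Minimal usage: populate gradients on a toy model, then clip through the sketch.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(8, 4)).pow(2).sum().backward()
AcceleratorSketch(PrecisionPluginSketch()).clip_gradients(optimizer, clip_val=0.5)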
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/precision/mixed.py
@@ -22,6 +22,5 @@
class MixedPrecisionPlugin(PrecisionPlugin):
"""Base Class for mixed precision"""

EPSILON: float = 1e-5
backend: 'AMPType'
precision: Union[str, int] = "mixed"
30 changes: 5 additions & 25 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Callable, List, Tuple, Union

import torch
@@ -28,10 +27,9 @@
class PrecisionPlugin(Plugin):
"""
Base class for all plugins handling the precision-specific parts of the training.
The static classattributes EPSILON and precision must be overwritten in child-classes and their
default values reflect fp32 training.
The class attribute precision must be overwritten in child classes.
The default value reflects fp32 training.
"""
EPSILON: float = 1e-6
precision: Union[str, int] = 32

def master_params(self, optimizer: Optimizer) -> _PARAMETERS:
@@ -117,32 +115,14 @@ def clip_gradients(
            self.clip_grad_by_value(optimizer, clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            # TODO: there should be a mechanism to set `norm_type`
            self.clip_grad_by_norm(optimizer, clip_val, eps=self.EPSILON)
            self.clip_grad_by_norm(optimizer, clip_val)

    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by value"""
        parameters = self.master_params(optimizer)
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

    def clip_grad_by_norm(
        self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0, eps: float = 1e-6
    ) -> None:
    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm"""
        parameters = self.master_params(optimizer)

        # TODO: replace this with torch.nn.clip_grad_norm_
        parameters = list(filter(lambda p: p.grad is not None, parameters))
        device = parameters[0].device

        if norm_type == math.inf:
            total_norm = max(p.grad.data.abs().max() for p in parameters)
        else:
            out = torch.empty(len(parameters), device=device)
            for i, p in enumerate(parameters):
                torch.norm(p.grad.data.to(device), norm_type, out=out[i])
            total_norm = torch.norm(out, norm_type)

        clip_coef = torch.tensor(clip_val, device=device) / (total_norm + eps)
        clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
        for p in parameters:
            p.grad.data.mul_(clip_coef.to(p.grad.data.device))
        torch.nn.utils.clip_grad_norm_(parameters, clip_val)
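The deleted block above hand-rolled the global 2-norm clipping that `torch.nn.utils.clip_grad_norm_` already implements. The standalone sketch below (not part of the diff) restates the removed logic in a condensed helper and checks that it matches the torch utility on a toy model; the helper name and the tolerance are assumptions for illustration.

import torch
import torch.nn as nn


def manual_clip_grad_norm_(parameters, clip_val: float, norm_type: float = 2.0, eps: float = 1e-6) -> None:
    # Condensed restatement of the deleted clip_grad_by_norm body (finite-norm case).
    parameters = [p for p in parameters if p.grad is not None]
    device = parameters[0].device
    per_param = torch.stack([torch.norm(p.grad.detach(), norm_type) for p in parameters])
    total_norm = torch.norm(per_param, norm_type)
    clip_coef = torch.tensor(clip_val, device=device) / (total_norm + eps)
    clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef))
    for p in parameters:
        p.grad.detach().mul_(clip_coef.to(p.grad.device))


# Two identical toy models with identical gradients.
model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)
model_b.load_state_dict(model_a.state_dict())
x = torch.randn(8, 4)
model_a(x).pow(2).sum().backward()
model_b(x).pow(2).sum().backward()

manual_clip_grad_norm_(model_a.parameters(), clip_val=0.5)
torch.nn.utils.clip_grad_norm_(model_b.parameters(), max_norm=0.5)

# Both paths should produce (near-)identical clipped gradients.
for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(p_a.grad, p_b.grad, atol=1e-5)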
2 changes: 0 additions & 2 deletions pytorch_lightning/trainer/training_tricks.py
@@ -22,8 +22,6 @@
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters, print_nan_gradients

EPSILON = 1e-6
EPSILON_FP16 = 1e-5
log = logging.getLogger(__name__)


1 change: 1 addition & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -45,6 +45,7 @@
_NATIVE_AMP_AVAILABLE,
_OMEGACONF_AVAILABLE,
_RPC_AVAILABLE,
_TORCH_GREATER_EQUAL_1_5,
_TORCH_GREATER_EQUAL_1_6,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -66,6 +66,7 @@ def _compare_version(package: str, op, version) -> bool:
_IS_WINDOWS = platform.system() == "Windows"
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
_TORCH_GREATER_EQUAL_1_5 = _compare_version("torch", operator.ge, "1.5.0")
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
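For context on the new `_TORCH_GREATER_EQUAL_1_5` flag, here is a hedged sketch of how such a flag can be derived. The real `_compare_version` lives in `pytorch_lightning/utilities/imports.py` and may differ in details such as error handling; the use of the `packaging` library here is an assumption.

import operator
from importlib import import_module

from packaging.version import Version


def compare_version_sketch(package: str, op, version: str) -> bool:
    # Returns False when the package is not installed; otherwise applies the
    # comparison operator to the installed and required versions.
    try:
        pkg = import_module(package)
    except ImportError:
        return False
    return op(Version(pkg.__version__), Version(version))


TORCH_GREATER_EQUAL_1_5 = compare_version_sketch("torch", operator.ge, "1.5.0")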
7 changes: 3 additions & 4 deletions tests/models/test_tpu.py
@@ -224,7 +224,7 @@ def test_tpu_grad_norm(tmpdir):
@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_clip_grad_by_value(tmpdir):
"""Test if clip_gradients by value works on TPU. (It should not.)"""
"""Test if clip_gradients by value works on TPU"""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
@@ -238,8 +238,7 @@ def test_tpu_clip_grad_by_value(tmpdir):
    )

    model = BoringModel()
    with pytest.raises(AssertionError):
        tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


@RunIf(tpu=True)
@@ -383,7 +382,7 @@ def test_reduce(rank):
@RunIf(tpu=True)
@pl_multi_process_test
@pytest.mark.parametrize("clip_val", [10])
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
@mock.patch("torch.nn.utils.clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
"""
Ensure that clip gradients is only called if the value is greater than 0.
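The updated test now patches `torch.nn.utils.clip_grad_norm_` directly instead of the removed `xla_clip_grad_norm_` alias. The following standalone sketch (not part of the diff) shows the same mocking pattern with a hypothetical `clip_step` helper, asserting the clip function is only invoked when the clip value is positive.

from unittest import mock

import torch
import torch.nn as nn


def clip_step(model: nn.Module, clip_val: float) -> None:
    # Hypothetical helper standing in for the training-step clipping call.
    if clip_val > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)


@mock.patch("torch.nn.utils.clip_grad_norm_")
def check_clipping_called(mock_clip_grad_norm) -> None:
    model = nn.Linear(4, 2)
    clip_step(model, clip_val=10)
    mock_clip_grad_norm.assert_called_once()
    clip_step(model, clip_val=0)
    # Still exactly one call: clipping is skipped when the value is not positive.
    assert mock_clip_grad_norm.call_count == 1


check_clipping_called()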