diff --git a/CHANGELOG.md b/CHANGELOG.md index c65c87371bda4..c45df04c66d78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417)) +- Added `gradient_clip_algorithm` argument to Trainer for gradient clipping by value ([#6123](https://github.com/PyTorchLightning/pytorch-lightning/pull/6123)) + + - Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470)) diff --git a/docs/source/advanced/training_tricks.rst b/docs/source/advanced/training_tricks.rst index 4f7452c2da1de..c3b232b41c13c 100644 --- a/docs/source/advanced/training_tricks.rst +++ b/docs/source/advanced/training_tricks.rst @@ -26,8 +26,10 @@ The effect is a large effective batch size of size KxN. Gradient Clipping ----------------- -Gradient clipping may be enabled to avoid exploding gradients. Specifically, this will `clip the gradient -norm `_ computed over all model parameters together. +Gradient clipping may be enabled to avoid exploding gradients. By default, this will `clip the gradient norm +`_ computed over all model parameters together. +If the ``gradient_clip_algorithm`` option is set to ``value`` (the default is ``norm``), this will +`clip the gradient value `_ for each parameter instead. .. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` @@ -39,6 +41,10 @@ norm `_ # clip gradients with norm above 0.5 trainer = Trainer(gradient_clip_val=0.5) + # clip gradients with value above 0.5 + # available gradient_clip_algorithm types are defined in :class:`~pytorch_lightning.utilities.enums.GradClipAlgorithmType` + trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value') + ---------- Stochastic Weight Averaging diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 66bbdc7fc3750..e97450fdbd885 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -24,7 +24,7 @@ from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device -from pytorch_lightning.utilities.enums import AMPType, LightningEnum +from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum if TYPE_CHECKING: from torch.cuda.amp import GradScaler @@ -315,10 +315,14 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt model_ref = self.lightning_module model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + ) -> None: """clips all the optimizer parameters to the given value""" - - self.precision_plugin.clip_gradients(self.model, optimizer, clip_val) + self.precision_plugin.clip_gradients(self.model, optimizer, clip_val, gradient_clip_algorithm) def on_train_epoch_end(self, outputs: Sequence[_STEP_OUTPUT_TYPE]) -> None: """Hook to do something on the end of an training epoch diff --git a/pytorch_lightning/plugins/precision/apex_amp.py
b/pytorch_lightning/plugins/precision/apex_amp.py index b600eca5e6bc2..4c82307bd6c74 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -30,6 +30,7 @@ class ApexMixedPrecisionPlugin(MixedPrecisionPlugin): """Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex)""" def __init__(self, amp_level: str = "O2") -> None: + super().__init__() self.backend = AMPType.APEX self.amp_level = amp_level diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 6a8357229a6e6..32ca52f8873c0 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -16,6 +16,7 @@ import torch from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache @@ -80,7 +81,7 @@ def clip_gradients( model: 'LightningModule', optimizer: 'Optimizer', clip_val: Union[int, float], - norm_type: float = 2.0 + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """ DeepSpeed handles clipping gradients via the training type plugin. diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py index 6e37c79f2b163..c8f84d4928abf 100644 --- a/pytorch_lightning/plugins/precision/double.py +++ b/pytorch_lightning/plugins/precision/double.py @@ -67,6 +67,7 @@ class DoublePrecisionPlugin(PrecisionPlugin): precision: int = 64 def __init__(self) -> None: + super().__init__() self.patches: List[_DoublePrecisionPatch] = [] def connect( diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 3c83945c8a1b7..cf762bdefcb0e 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -30,6 +30,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin): def __init__(self) -> None: + super().__init__() if not _NATIVE_AMP_AVAILABLE: raise MisconfigurationException( "You have asked for native AMP but your PyTorch version does not support it." diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 19eb4e0cfb21b..a7f4d7101ad53 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -17,6 +17,7 @@ import torch from pytorch_lightning.plugins.base_plugin import Plugin +from pytorch_lightning.utilities import GradClipAlgorithmType if TYPE_CHECKING: from torch.nn import Module @@ -33,6 +34,13 @@ class PrecisionPlugin(Plugin): EPSILON: float = 1e-6 precision: Union[str, int] = 32 + def __init__(self) -> None: + super().__init__() + self.clip_grad_funcs = { + GradClipAlgorithmType.VALUE: self.clip_grad_by_value, + GradClipAlgorithmType.NORM: self.clip_grad_by_norm, + } + def master_params(self, optimizer: 'Optimizer') -> Generator[torch.Tensor, None, None]: """The master params of the model. Returns the plain model params here. Maybe different in other precision plugins. 
@@ -103,20 +111,29 @@ def clip_gradients( model: 'LightningModule', optimizer: 'Optimizer', clip_val: Union[int, float], - norm_type: float = 2.0 + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: - """Clips the gradients to a specific value""" + """Clips the gradients""" if clip_val is None: return - grad_clip_val = float(clip_val) - - if grad_clip_val <= 0: + clip_val = float(clip_val) + if clip_val <= 0: return + clip_grad_func = self.clip_grad_funcs[gradient_clip_algorithm] + clip_grad_func(optimizer, clip_val) # type: ignore + + def clip_grad_by_value(self, optimizer: 'Optimizer', clip_val: Union[int, float]) -> None: + """Clip gradients by value""" parameters = list(self.master_params(optimizer)) + torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val) - max_norm = grad_clip_val + def clip_grad_by_norm(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + """Clip gradients by norm""" + # TODO: separate TPU case from here + parameters = list(self.master_params(optimizer)) + max_norm = clip_val if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index b9326b665c00d..03b2117191d82 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -23,8 +23,6 @@ if TYPE_CHECKING: from torch.optim import Optimizer - from pytorch_lightning.core import LightningModule - class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin): """Mixed Precision for Sharded Training @@ -34,15 +32,11 @@ def __init__(self) -> None: super().__init__() self.scaler = ShardedGradScaler() - def clip_gradients( + def clip_grad_by_norm( self, - model: 'LightningModule', optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0 ) -> None: - if clip_val <= 0: - return - optimizer = cast(OSS, optimizer) optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index dd7aad8cd6d88..899ffbf56e8fd 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from pytorch_lightning.callbacks import GradientAccumulationScheduler +from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -23,6 +24,7 @@ def __init__(self, trainer): def on_trainer_init( self, gradient_clip_val, + gradient_clip_algorithm, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, @@ -32,7 +34,12 @@ def on_trainer_init( self.trainer.terminate_on_nan = terminate_on_nan # gradient clipping + if gradient_clip_algorithm not in list(GradClipAlgorithmType): + raise MisconfigurationException( + f"gradient_clip_algorithm should be in {list(GradClipAlgorithmType)}" + ) self.trainer.gradient_clip_val = gradient_clip_val + self.trainer.gradient_clip_algorithm = gradient_clip_algorithm # gradient norm tracking if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf': diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 27dcd6fe9aa0d..340b0a81d7ab8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -91,6 +91,7 @@ def __init__( callbacks: Optional[Union[List[Callback], Callback]] = None, default_root_dir: Optional[str] = None, gradient_clip_val: float = 0, + gradient_clip_algorithm: str = 'norm', process_position: int = 0, num_nodes: int = 1, num_processes: int = 1, @@ -197,6 +198,8 @@ def __init__( gradient_clip_val: 0 means don't clip. + + gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm' + limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches) limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches) @@ -347,7 +350,12 @@ def __init__( # init training tricks self.training_tricks_connector.on_trainer_init( - gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps, terminate_on_nan + gradient_clip_val, + gradient_clip_algorithm, + track_grad_norm, + accumulate_grad_batches, + truncated_bptt_steps, + terminate_on_nan, ) self.train_loop.on_trainer_init( max_epochs, diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 28cb05bc06f2d..a6d549e3827bc 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -23,7 +23,13 @@ rank_zero_only, rank_zero_warn, ) -from pytorch_lightning.utilities.enums import AMPType, DeviceType, DistributedType, LightningEnum # noqa: F401 +from pytorch_lightning.utilities.enums import ( # noqa: F401 + AMPType, + DeviceType, + DistributedType, + GradClipAlgorithmType, + LightningEnum, +) from pytorch_lightning.utilities.imports import ( # noqa: F401 _APEX_AVAILABLE, _BOLTS_AVAILABLE, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 169481fa63e67..391d87abfb6cc 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -94,3 +94,16 @@ class DeviceType(LightningEnum): CPU = 'CPU' GPU = 'GPU' TPU = 'TPU' + + +class GradClipAlgorithmType(LightningEnum): + """ Define gradient_clip_algorithm types - training-tricks. + NORM type means "clipping gradients by norm". This is computed over all model parameters together. + VALUE type means "clipping gradients by value". This will clip the gradient value for each parameter.
+ + References: + clip_by_norm: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm + clip_by_value: https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_value + """ + VALUE = 'value' + NORM = 'norm' diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index 5b4a700babd1d..49e4b04933eab 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -50,7 +50,7 @@ def _run_horovod(trainer_options, on_gpu=False): trainer_options.update(gpus=1 if on_gpu else None) tutils.reset_seed() # todo: Find why coverage breaks CI. - # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # noqa E265 + # append = '-a' if '.coverage' in os.listdir(_PROJECT_ROOT) else '' # str(num_processes), sys.executable, '-m', 'coverage', 'run', '--source', 'pytorch_lightning', append, # noqa E265 cmdline = [ 'horovodrun', '-np', @@ -80,6 +80,24 @@ def test_horovod_cpu(tmpdir): _run_horovod(trainer_options) +@RunIf(skip_windows=True, horovod=True) +def test_horovod_cpu_clip_grad_by_value(tmpdir): + """Test Horovod running multi-process on CPU.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + gradient_clip_algorithm='value', + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + accelerator='horovod', + deterministic=True, + ) + _run_horovod(trainer_options) + + @RunIf(skip_windows=True, horovod=True) def test_horovod_cpu_implicit(tmpdir): """Test Horovod without specifying a backend, inferring from env set by `horovodrun`.""" @@ -114,6 +132,25 @@ def test_horovod_multi_gpu(tmpdir): _run_horovod(trainer_options, on_gpu=True) +@RunIf(min_gpus=2, skip_windows=True, horovod_nccl=True) +def test_horovod_multi_gpu_grad_by_value(tmpdir): + """Test Horovod with multi-GPU support.""" + trainer_options = dict( + default_root_dir=str(tmpdir), + weights_save_path=str(tmpdir), + gradient_clip_val=1.0, + gradient_clip_algorithm='value', + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.4, + limit_val_batches=0.2, + gpus=2, + deterministic=True, + accelerator='horovod', + ) + _run_horovod(trainer_options, on_gpu=True) + + # https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994 # Check with (tgaddair) on Horovod issues if this feature is needed @pytest.mark.skip(reason="Horovod currently doesn't work with Apex") # todo diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index b2ed0db87d8d5..e6fb0c96ef403 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -219,6 +219,26 @@ def test_tpu_grad_norm(tmpdir): tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) +@RunIf(tpu=True) +@pl_multi_process_test +def test_tpu_clip_grad_by_value(tmpdir): + """Test if clip_gradients by value works on TPU.""" + tutils.reset_seed() + trainer_options = dict( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=4, + tpu_cores=1, + limit_train_batches=4, + limit_val_batches=4, + gradient_clip_val=0.5, + gradient_clip_algorithm='value' + ) + + model = BoringModel() + tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False) + + @RunIf(tpu=True) @pl_multi_process_test def test_dataloaders_passed_to_fit(tmpdir): diff --git a/tests/plugins/test_precision_plugin.py b/tests/plugins/test_precision_plugin.py deleted file mode 100644 index fc00f22a6413e..0000000000000 --- a/tests/plugins/test_precision_plugin.py +++ /dev/null @@ -1,26 +0,0 @@ -# 
Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from inspect import signature - -from pytorch_lightning.plugins.precision import PrecisionPlugin - - -def test_precision_clip_gradients_signature(): - - expected_params_list = ['self', 'model', 'optimizer', 'clip_val', 'norm_type'] - - params = signature(PrecisionPlugin.clip_gradients).parameters - params_list = [param.name for param in params.values()] - - assert params_list == expected_params_list diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0c0488009d5af..d3861e211203e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -875,6 +875,46 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde trainer.fit(model) +def test_gradient_clipping_by_value(tmpdir): + """ + Test gradient clipping by value + """ + tutils.reset_seed() + + model = BoringModel() + + grad_clip_val = 0.0001 + trainer = Trainer( + max_steps=10, + max_epochs=1, + gradient_clip_val=grad_clip_val, + gradient_clip_algorithm='value', + default_root_dir=tmpdir, + ) + + trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + + def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + # test that gradient is clipped correctly + ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + parameters = model.parameters() + grad_max_list = [torch.max(p.grad.detach().abs()) for p in parameters] + grad_max = torch.max(torch.stack(grad_max_list)) + assert round(grad_max.item(), 6) <= grad_clip_val, \ + f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val} ." 
+ + return ret_val + + trainer.train_loop.training_step_and_backward = training_step_and_backward + # for the test + model.prev_called_batch_idx = 0 + + trainer.fit(model) + + @RunIf(min_gpus=1, amp_native=True) def test_gradient_clipping_fp16(tmpdir): """ @@ -913,6 +953,46 @@ def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hidde trainer.fit(model) +@RunIf(min_gpus=1, amp_native=True) +def test_gradient_clipping_by_value_fp16(tmpdir): + """ + Test gradient clipping by value with fp16 + """ + tutils.reset_seed() + + model = BoringModel() + grad_clip_val = 0.0001 + trainer = Trainer( + max_steps=10, + max_epochs=1, + precision=16, + gpus=1, + gradient_clip_val=grad_clip_val, + gradient_clip_algorithm='value', + default_root_dir=tmpdir, + ) + + trainer.train_loop.old_training_step_and_backward = trainer.train_loop.training_step_and_backward + + def training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens): + """ + wrap the forward step in a closure so second order methods work + """ + # test that gradient is clipped correctly + ret_val = trainer.train_loop.old_training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens) + parameters = model.parameters() + grad_max = torch.max(torch.stack([p.grad.detach().abs().max() for p in parameters])) + assert round(grad_max.item(), 6) <= grad_clip_val, \ + f"Gradient max value {grad_max} > grad_clip_val {grad_clip_val}." + + return ret_val + + trainer.train_loop.training_step_and_backward = training_step_and_backward + model.prev_called_batch_idx = 0 + + trainer.fit(model) + + def test_gpu_choice(tmpdir): trainer_options = dict(default_root_dir=tmpdir) # Only run if CUDA is available
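# --- Illustrative sketch (editor's note, not part of the patch) ---------------
# What the new `gradient_clip_algorithm` argument amounts to: the precision
# plugin now dispatches to one of the two torch.nn.utils clipping helpers
# wrapped by clip_grad_by_value / clip_grad_by_norm. The toy model, optimizer
# and data below are placeholders chosen only for this example.
import torch

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(4, 32)).sum().backward()

clip_val = 0.5
algorithm = 'value'  # mirrors Trainer(gradient_clip_val=0.5, gradient_clip_algorithm='value')

# roughly what master_params(optimizer) yields for the default 32-bit plugin
parameters = [p for group in optimizer.param_groups for p in group['params'] if p.grad is not None]

if algorithm == 'value':
    # clamp every gradient element into [-clip_val, clip_val]
    torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
else:
    # rescale gradients so their global 2-norm is at most clip_val
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=clip_val, norm_type=2.0)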