From 8dd9b80d7a192117783195a748ddce6c33d556f3 Mon Sep 17 00:00:00 2001
From: Alex Sergeev
Date: Thu, 9 Apr 2020 18:08:28 -0700
Subject: [PATCH] Fix gradient clipping (#1438)

* Fix gradient clipping

* Relax accuracy constraint
---
 pytorch_lightning/trainer/training_tricks.py |  2 +-
 tests/trainer/test_trainer.py                | 27 ++++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 9dd43e193a2be..f5523210efb41 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -41,7 +41,7 @@ def clip_gradients(self):
                 total_norm = torch.zeros([], device=device if parameters else None)
                 for p in parameters:
                     param_norm = p.grad.data.norm(norm_type) ** norm_type
-                total_norm.add_(param_norm)
+                    total_norm.add_(param_norm)
                 total_norm = (total_norm ** (1. / norm_type))
             eps = EPSILON_FP16 if self.precision == 16 else EPSILON
             clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c26da3b7b280c..8ae0530af270e 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -658,3 +658,30 @@ def on_batch_start(self, trainer, pl_module):
     assert not trainer.interrupted
     trainer.fit(model)
     assert trainer.interrupted
+
+
+def test_gradient_clipping(tmpdir):
+    """
+    Test gradient clipping
+    """
+    tutils.reset_seed()
+
+    hparams = tutils.get_default_hparams()
+    model = LightningTestModel(hparams)
+
+    # test that gradient is clipped correctly
+    def _optimizer_step(*args, **kwargs):
+        parameters = model.parameters()
+        grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+        assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
+
+    trainer = Trainer(max_steps=1,
+                      max_epochs=1,
+                      gradient_clip_val=1.0,
+                      default_save_path=tmpdir)
+
+    # for the test
+    model.optimizer_step = _optimizer_step
+    model.prev_called_batch_idx = 0
+
+    trainer.fit(model)
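
The fixed clip_gradients accumulates every parameter's gradient norm raised to norm_type inside the loop, takes the 1/norm_type root to get the total norm, and then scales all gradients by max_norm / (total_norm + eps) without ever scaling them up. A minimal standalone sketch of that computation against plain PyTorch rather than the Trainer internals (the name clip_total_norm, the eps value, and the clamp-based cap are illustrative choices, not taken from the patch):

import torch

def clip_total_norm(grads, max_norm, norm_type=2.0, eps=1e-6):
    # Accumulate ||g||_p ** p for every gradient, then take the p-th root.
    device = grads[0].device
    total_norm = torch.zeros([], device=device)
    for g in grads:
        total_norm.add_(g.norm(norm_type) ** norm_type)
    total_norm = total_norm ** (1.0 / norm_type)
    # Scale by max_norm / total_norm, but never scale gradients up.
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
    clip_coef = torch.clamp(clip_coef, max=1.0)
    return [g * clip_coef for g in grads], total_norm

# Two fake gradients with combined L2 norm 5.0 are scaled down to norm ~1.0.
grads = [torch.tensor([3.0, 0.0]), torch.tensor([0.0, 4.0])]
clipped, _ = clip_total_norm(grads, max_norm=1.0)
new_norm = torch.norm(torch.stack([g.norm(2) for g in clipped]), 2)
assert (new_norm - 1.0).abs() < 0.01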
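
The new test_gradient_clipping runs a single training step with gradient_clip_val=1.0 and replaces the model's optimizer_step with a hook that recomputes the global gradient norm and asserts it is within 0.01 of 1.0. The same style of check can be reproduced outside the Trainer with torch.nn.utils.clip_grad_norm_; the tiny model and scaled loss below are assumptions chosen only so the pre-clip norm is guaranteed to exceed 1.0:

import torch
import torch.nn as nn

# A small model and a loss scaled so the gradient norm is well above 1.0.
model = nn.Linear(10, 1)
loss = 10.0 * model(torch.randn(8, 10)).sum()
loss.backward()

# Clip to a total norm of 1.0, then recompute the norm the same way the test does.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
grad_norm = torch.norm(torch.stack([p.grad.detach().norm(2) for p in model.parameters()]), 2)
assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)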