Fix gradient clipping (#1438)
* Fix gradient clipping

* Relax accuracy constraint
alsrgv committed Apr 10, 2020
1 parent b2707c9 commit 8dd9b80
Showing 2 changed files with 28 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_tricks.py
@@ -41,7 +41,7 @@ def clip_gradients(self):
                 total_norm = torch.zeros([], device=device if parameters else None)
                 for p in parameters:
                     param_norm = p.grad.data.norm(norm_type) ** norm_type
-                total_norm.add_(param_norm)
+                    total_norm.add_(param_norm)
                 total_norm = (total_norm ** (1. / norm_type))
             eps = EPSILON_FP16 if self.precision == 16 else EPSILON
             clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
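For context, the patched loop computes the global gradient norm as (sum over parameters of ||grad||^p)^(1/p) and the surrounding code then rescales every gradient by max_norm / (total_norm + eps). Below is a minimal standalone sketch of that computation on toy parameters; the toy model, loss, and eps value are illustrative assumptions, not taken from the Lightning code.

```python
import torch

# Toy parameters standing in for model.parameters(); purely illustrative.
params = [torch.nn.Parameter(torch.randn(3, 3)) for _ in range(4)]
loss = sum((p ** 2).sum() for p in params)
loss.backward()

max_norm, norm_type, eps = 1.0, 2.0, 1e-6

# Accumulate each parameter's norm ** norm_type inside the loop, then take
# the norm_type-th root to get the global gradient norm.
total_norm = torch.zeros([])
for p in params:
    param_norm = p.grad.data.norm(norm_type) ** norm_type
    total_norm.add_(param_norm)
total_norm = total_norm ** (1. / norm_type)

# Rescale every gradient so the global norm does not exceed max_norm.
clip_coef = max_norm / (total_norm + eps)
for p in params:
    p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1.)))

# The gradients of this toy loss are large, so after clipping the global
# 2-norm should sit right at max_norm.
clipped = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in params]), 2)
assert (clipped - max_norm).abs() < 0.01, clipped
```

Accumulating the per-parameter norms inside the loop is what makes total_norm a true global norm; if the accumulation runs only once after the loop, the clip coefficient reflects a single parameter and the gradients are clipped to the wrong scale.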
27 changes: 27 additions & 0 deletions tests/trainer/test_trainer.py
@@ -658,3 +658,30 @@ def on_batch_start(self, trainer, pl_module):
     assert not trainer.interrupted
     trainer.fit(model)
     assert trainer.interrupted
+
+
+def test_gradient_clipping(tmpdir):
+    """
+    Test gradient clipping
+    """
+    tutils.reset_seed()
+
+    hparams = tutils.get_default_hparams()
+    model = LightningTestModel(hparams)
+
+    # test that gradient is clipped correctly
+    def _optimizer_step(*args, **kwargs):
+        parameters = model.parameters()
+        grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2)
+        assert (grad_norm - 1.0).abs() < 0.01, "Gradient norm != 1.0: {grad_norm}".format(grad_norm=grad_norm)
+
+    trainer = Trainer(max_steps=1,
+                      max_epochs=1,
+                      gradient_clip_val=1.0,
+                      default_save_path=tmpdir)
+
+    # for the test
+    model.optimizer_step = _optimizer_step
+    model.prev_called_batch_idx = 0
+
+    trainer.fit(model)
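The new test verifies clipping by overriding model.optimizer_step, so the norm check runs after backpropagation and clipping but before the optimizer touches the weights; the equality check against 1.0 relies on the unclipped norm exceeding gradient_clip_val. Below is a hedged, Lightning-free sketch of the same pattern; the model, data, and the looser "<=" tolerance are illustrative assumptions, not part of the commit.

```python
import torch

# Plain-PyTorch sketch of the test's idea: clip, then check the global
# gradient norm before stepping. Model and data are illustrative only.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_clip_val = 1.0

x, y = torch.randn(16, 8), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val)

# Same norm computation as the test's _optimizer_step hook.
grad_norm = torch.norm(
    torch.stack([torch.norm(p.grad.detach(), 2) for p in model.parameters()]), 2)
assert grad_norm <= gradient_clip_val + 0.01, grad_norm

optimizer.step()
```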
