From c5bf99fab12dbdf2497f99f8d6412d0f3c07f938 Mon Sep 17 00:00:00 2001
From: Alex Rush
Date: Thu, 27 Feb 2020 00:48:26 -0500
Subject: [PATCH 1/6] clip

---
 pytorch_lightning/trainer/training_tricks.py | 24 +++++++++++++++++++-
 1 file changed, 23 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 7fa4059afc3e2..c55454d9431f8 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 
 import torch
+import math
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 
@@ -19,9 +20,30 @@ def get_model(self):
         pass
 
     def clip_gradients(self):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if self.gradient_clip_val > 0:
             model = self.get_model()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
+            parameters = model.parameters()
+            max_norm = self.gradient_clip_val
+            norm_type = 2
+            if isinstance(parameters, torch.Tensor):
+                parameters = [parameters]
+            parameters = list(filter(lambda p: p.grad is not None, parameters))
+            max_norm = float(max_norm)
+            norm_type = float(norm_type)
+            if norm_type == math.inf:
+                total_norm = max(p.grad.data.abs().max() for p in parameters)
+            else:
+                device = parameters[0].device
+                total_norm = torch.zeros([], device=device if parameters else None)
+                for p in parameters:
+                    param_norm = p.grad.data.norm(norm_type) ** norm_type
+                    total_norm.add_(param_norm)
+                total_norm = (total_norm ** (1. / norm_type))
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
+            for p in parameters:
+                p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
 
     def print_nan_gradients(self):
         model = self.get_model()

From 5d14129244b6fa3e9021a6b452b0f4f1cff454f9 Mon Sep 17 00:00:00 2001
From: srush
Date: Thu, 27 Feb 2020 09:16:59 -0500
Subject: [PATCH 2/6] Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec

---
 pytorch_lightning/trainer/training_tricks.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index c55454d9431f8..ee69bf779e0f6 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -25,8 +25,8 @@ def clip_gradients(self):
         if self.gradient_clip_val > 0:
             model = self.get_model()
             parameters = model.parameters()
-            max_norm = self.gradient_clip_val
-            norm_type = 2
+            max_norm = float(self.gradient_clip_val)
+            norm_type = float(2.0)
             if isinstance(parameters, torch.Tensor):
                 parameters = [parameters]
             parameters = list(filter(lambda p: p.grad is not None, parameters))

From 41ff747e9e53210c692db22373ca79ffb62d107e Mon Sep 17 00:00:00 2001
From: srush
Date: Thu, 27 Feb 2020 09:17:13 -0500
Subject: [PATCH 3/6] Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec

---
 pytorch_lightning/trainer/training_tricks.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index ee69bf779e0f6..7e9bef6beb8de 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -30,8 +30,6 @@ def clip_gradients(self):
             if isinstance(parameters, torch.Tensor):
                 parameters = [parameters]
             parameters = list(filter(lambda p: p.grad is not None, parameters))
-            max_norm = float(max_norm)
-            norm_type = float(norm_type)
             if norm_type == math.inf:
                 total_norm = max(p.grad.data.abs().max() for p in parameters)
             else:

From aa1b50c8c0130c4ef5a6c5d615d9126194d54c4c Mon Sep 17 00:00:00 2001
From: Sasha
Date: Thu, 27 Feb 2020 11:10:57 -0500
Subject: [PATCH 4/6] pull out epsilon

---
 pytorch_lightning/trainer/training_tricks.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 7e9bef6beb8de..e0a7f89df2082 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -6,6 +6,8 @@
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 
+EPSILON = 1e-6
+
 
 class TrainerTrainingTricksMixin(ABC):
 
@@ -39,7 +41,7 @@ def clip_gradients(self):
                     param_norm = p.grad.data.norm(norm_type) ** norm_type
                     total_norm.add_(param_norm)
                 total_norm = (total_norm ** (1. / norm_type))
-            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON)
             for p in parameters:
                 p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

From a3ef8f24833575ec157f223e17803c52263b7ece Mon Sep 17 00:00:00 2001
From: Sasha
Date: Thu, 27 Feb 2020 11:31:09 -0500
Subject: [PATCH 5/6] add fp16 case

---
 pytorch_lightning/trainer/training_tricks.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index e0a7f89df2082..1186372313d60 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -7,6 +7,7 @@
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 
 EPSILON = 1e-6
+EPSILON_FP16 = 1e-5
 
 
 class TrainerTrainingTricksMixin(ABC):
@@ -41,7 +42,10 @@ def clip_gradients(self):
                     param_norm = p.grad.data.norm(norm_type) ** norm_type
                     total_norm.add_(param_norm)
                 total_norm = (total_norm ** (1. / norm_type))
-            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON)
+            if self.precision == 16:
+                clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON_FP16)
+            else:
+                clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON)
             for p in parameters:
                 p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

From cd8617939973c083301f2c5aa4030fb0241d3a08 Mon Sep 17 00:00:00 2001
From: srush
Date: Thu, 27 Feb 2020 11:55:20 -0500
Subject: [PATCH 6/6] Update pytorch_lightning/trainer/training_tricks.py

Co-Authored-By: Jirka Borovec

---
 pytorch_lightning/trainer/training_tricks.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 1186372313d60..6171e487e74d7 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -42,10 +42,8 @@ def clip_gradients(self):
                     param_norm = p.grad.data.norm(norm_type) ** norm_type
                     total_norm.add_(param_norm)
                 total_norm = (total_norm ** (1. / norm_type))
-            if self.precision == 16:
-                clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON_FP16)
-            else:
-                clip_coef = torch.tensor(max_norm, device=device) / (total_norm + EPSILON)
+            eps = EPSILON_FP16 if self.precision == 16 else EPSILON
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
             for p in parameters:
                 p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
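
Note (not part of the patch series above): the net effect of these commits is to replace the direct torch.nn.utils.clip_grad_norm_ call with a clipping routine whose norm accumulation stays in tensor operations, in line with the XLA troubleshooting guide referenced in patch 1/6, plus a precision-dependent epsilon. Below is a minimal standalone sketch of the equivalent logic for experimentation outside the Trainer mixin; the function name clip_grad_norm_tensor, the fp16 flag, and the toy Linear model are illustrative assumptions, not identifiers from the PR.

import math

import torch

EPSILON = 1e-6        # mirrors the constant introduced in patch 4/6
EPSILON_FP16 = 1e-5   # mirrors the constant introduced in patch 5/6


def clip_grad_norm_tensor(parameters, max_norm, norm_type=2.0, fp16=False):
    # Scale gradients in place so their total norm is at most max_norm,
    # accumulating the norm as a tensor (no Python-float conversions that
    # would force a device-to-host sync on XLA/TPU).
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    if not parameters:
        return
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    device = parameters[0].grad.device
    if norm_type == math.inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = torch.zeros([], device=device)
        for p in parameters:
            total_norm.add_(p.grad.data.norm(norm_type) ** norm_type)
        total_norm = total_norm ** (1. / norm_type)
    eps = EPSILON_FP16 if fp16 else EPSILON
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
    for p in parameters:
        p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))


# Usage sketch: backprop through a toy model, then clip its gradients.
model = torch.nn.Linear(4, 2)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()
clip_grad_norm_tensor(model.parameters(), max_norm=0.5)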