Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TPU gradient clipping. #963

Merged
merged 6 commits into from
Feb 27, 2020
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod

import torch
import math

from pytorch_lightning.callbacks import GradientAccumulationScheduler

Expand All @@ -19,9 +20,30 @@ def get_model(self):
pass

def clip_gradients(self):
# this code is a modification of torch.nn.utils.clip_grad_norm_
# with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
if self.gradient_clip_val > 0:
model = self.get_model()
torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
parameters = model.parameters()
max_norm = self.gradient_clip_val
norm_type = 2
srush marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm)
norm_type = float(norm_type)
srush marked this conversation as resolved.
Show resolved Hide resolved
if norm_type == math.inf:
total_norm = max(p.grad.data.abs().max() for p in parameters)
else:
device = parameters[0].device
total_norm = torch.zeros([], device=device if parameters else None)
for p in parameters:
param_norm = p.grad.data.norm(norm_type) ** norm_type
total_norm.add_(param_norm)
total_norm = (total_norm ** (1. / norm_type))
clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
srush marked this conversation as resolved.
Show resolved Hide resolved
for p in parameters:
p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

def print_nan_gradients(self):
model = self.get_model()
Expand Down