TPU Gradient clipping #959

Closed · wants to merge 8 commits

pytorch_lightning/trainer/evaluation_loop.py (10 additions & 0 deletions)
@@ -126,6 +126,7 @@
 from typing import Callable
 
 import sys
+import gc
 from abc import ABC, abstractmethod
 
 import torch
@@ -134,6 +135,7 @@
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
+    import torch_xla
     import torch_xla.distributed.parallel_loader as xla_pl
     import torch_xla.core.xla_model as xm
 
@@ -375,9 +377,17 @@ def run_evaluation(self, test_mode: bool = False):
             self.val_progress_bar.close()
 
         # model checkpointing
+        gc.collect()
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test_mode:
+            print("checkpointing")
             self.checkpoint_callback.on_validation_end(self, self.get_model())
+            print("done checkpointing")
+
+        # wait for all models to checkpoint
+        if self.on_tpu and XLA_AVAILABLE:
+            # wait for all processes to catch up
+            torch_xla.core.xla_model.rendezvous("pl.TrainerEvaluationLoopMixin.run_evaluation")
 
         # Validation/Test end callbacks
         if test_mode:
            self.on_test_end()
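
For context, a minimal sketch of the synchronization pattern this hunk introduces: only the master ordinal writes the checkpoint, then every TPU process blocks on the same rendezvous tag, so no replica runs ahead while the file is still being written. It assumes torch_xla is installed and the function runs inside processes started by xmp.spawn(); save_and_sync and its state argument are illustrative names, not part of the PR.

import torch
import torch_xla.core.xla_model as xm

def save_and_sync(state, filepath="example.ckpt"):
    # only the master process writes the file
    if xm.get_ordinal() == 0:
        cpu_state = {k: v.cpu() for k, v in state.items()}  # move XLA tensors to host before saving
        torch.save(cpu_state, filepath)
    # every replica blocks here until all ranks reach the same tag
    xm.rendezvous("pl.TrainerEvaluationLoopMixin.run_evaluation")
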
pytorch_lightning/trainer/training_io.py (3 additions & 1 deletion)
@@ -100,7 +100,7 @@
 import torch
 import torch.distributed as dist
-
+import gc
 from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
@@ -287,7 +287,9 @@ def _atomic_save(self, checkpoint, filepath):
         This points to the file that the checkpoint will be stored in.
         """
         tmp_path = str(filepath) + ".part"
+        gc.collect()
         torch.save(checkpoint, tmp_path)
+        gc.collect()
         os.replace(tmp_path, filepath)
 
     def save_checkpoint(self, filepath):
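
A self-contained sketch of the atomic-save pattern the two gc.collect() calls are wrapped around: the checkpoint is written to a temporary ".part" file and then swapped into place with os.replace, so an interrupted save never leaves a truncated checkpoint behind. atomic_save and the demo payload are illustrative, not part of the PR.

import gc
import os
import torch

def atomic_save(obj, filepath):
    tmp_path = str(filepath) + ".part"
    gc.collect()                    # drop dead Python objects before serialising
    torch.save(obj, tmp_path)
    gc.collect()                    # release buffers created during the save
    os.replace(tmp_path, filepath)  # atomic swap when both paths are on the same filesystem

atomic_save({"step": 0, "weights": torch.zeros(3)}, "demo.ckpt")
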
pytorch_lightning/trainer/training_tricks.py (23 additions & 1 deletion)
@@ -2,6 +2,7 @@
 from abc import ABC, abstractmethod
 
 import torch
+import math
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler

@@ -19,9 +20,30 @@ def get_model(self):
         pass
 
     def clip_gradients(self):
+        # this code is a modification of torch.nn.utils.clip_grad_norm_
+        # with TPU support based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
         if self.gradient_clip_val > 0:
             model = self.get_model()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
+            parameters = model.parameters()
+            max_norm = self.gradient_clip_val
+            norm_type = 2
+            if isinstance(parameters, torch.Tensor):
+                parameters = [parameters]
+            parameters = list(filter(lambda p: p.grad is not None, parameters))
+            max_norm = float(max_norm)
+            norm_type = float(norm_type)
+            if norm_type == math.inf:
+                total_norm = max(p.grad.data.abs().max() for p in parameters)
+            else:
+                device = parameters[0].device
+                total_norm = torch.zeros([], device=device if parameters else None)
+                for p in parameters:
+                    param_norm = p.grad.data.norm(norm_type) ** norm_type
+                    total_norm.add_(param_norm)
+                total_norm = (total_norm ** (1. / norm_type))
+            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
+            for p in parameters:
+                p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))
 
     def print_nan_gradients(self):
         model = self.get_model()
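
A CPU-runnable sketch of the same norm-based clipping, pulled out into a free function so it can be checked against torch.nn.utils.clip_grad_norm_. The point of the change is that the total norm stays a tensor (no .item() call that would force a device-to-host sync on TPU). clip_grad_norm_tensorwise and the two toy Linear models are illustrative only, and unlike the PR code this version also defines device for the infinity-norm branch so it runs standalone.

import math
import torch

def clip_grad_norm_tensorwise(parameters, max_norm, norm_type=2.0):
    # compute the total gradient norm as a tensor and scale every gradient in place
    parameters = [p for p in parameters if p.grad is not None]
    max_norm, norm_type = float(max_norm), float(norm_type)
    device = parameters[0].device
    if norm_type == math.inf:
        total_norm = max(p.grad.data.abs().max() for p in parameters)
    else:
        total_norm = torch.zeros([], device=device)
        for p in parameters:
            total_norm.add_(p.grad.data.norm(norm_type) ** norm_type)
        total_norm = total_norm ** (1.0 / norm_type)
    clip_coef = torch.tensor(max_norm, device=device) / (total_norm + 1e-6)
    for p in parameters:
        p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1.0, device=device)))
    return total_norm

# quick check against the reference implementation on identical gradients
model_a, model_b = torch.nn.Linear(4, 4), torch.nn.Linear(4, 4)
model_b.load_state_dict(model_a.state_dict())
x = torch.randn(2, 4)
model_a(x).sum().backward()
model_b(x).sum().backward()
torch.nn.utils.clip_grad_norm_(model_a.parameters(), max_norm=0.5)
clip_grad_norm_tensorwise(model_b.parameters(), max_norm=0.5)
print(all(torch.allclose(pa.grad, pb.grad, atol=1e-6)
          for pa, pb in zip(model_a.parameters(), model_b.parameters())))
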