Mistake in parameters' grad norm tracking (#2012)
* fix grad norm formula

* grad-norm tracker test

* fixed seed and explicit rtol in grad norm tracking test

* a docstring for grad-norms and forced cast to float of norm_type

* support for inf-norm

* renamed the grad norm test

* docs

* fixed language in docstring

* Apply suggestions from code review

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people authored Jun 2, 2020
1 parent a699003 commit e85a646
Showing 4 changed files with 148 additions and 26 deletions.
55 changes: 33 additions & 22 deletions pytorch_lightning/core/grads.py
@@ -1,30 +1,41 @@
 """
 Module to describe gradients
 """
-from typing import Dict
+from typing import Dict, Union
 
-from torch import nn
+import torch
 
 
-class GradInformation(nn.Module):
+class GradInformation(torch.nn.Module):
 
-    def grad_norm(self, norm_type: float) -> Dict[str, int]:
-        results = {}
-        total_norm = 0
+    def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
+        """Compute each parameter's gradient's norm and their overall norm.
+        The overall norm is computed over all gradients together, as if they
+        were concatenated into a single vector.
+        Args:
+            norm_type: The type of the used p-norm, cast to float if necessary.
+                Can be ``'inf'`` for infinity norm.
+        Return:
+            norms: The dictionary of p-norms of each parameter's gradient and
+                a special entry for the total p-norm of the gradients viewed
+                as a single vector.
+        """
+        norm_type = float(norm_type)
+
+        norms, all_norms = {}, []
         for name, p in self.named_parameters():
-            if p.requires_grad:
-                try:
-                    param_norm = p.grad.data.norm(norm_type)
-                    total_norm += param_norm ** norm_type
-                    norm = param_norm ** (1 / norm_type)
-
-                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
-                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
-                except Exception:
-                    # this param had no grad
-                    pass
-
-        total_norm = total_norm ** (1. / norm_type)
-        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
-        results['grad_{}_norm_total'.format(norm_type)] = grad
-        return results
+            if p.grad is None:
+                continue
+
+            param_norm = float(p.grad.data.norm(norm_type))
+            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
+
+            all_norms.append(param_norm)
+
+        total_norm = float(torch.tensor(all_norms).norm(norm_type))
+        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
+
+        return norms
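
For readers following the formula change, here is a small standalone sketch (not part of the commit; the tensors and variable names are illustrative) of the property the new docstring relies on: the p-norm of the per-parameter gradient norms equals the p-norm of all gradient entries concatenated into a single vector, which is exactly what the rewritten grad_norm computes via torch.tensor(all_norms).norm(norm_type).

import torch

# two stand-in gradients; in practice these come from model parameters
grads = [torch.randn(3, 4), torch.randn(5)]
p = 2.0  # any p >= 1 works, including float('inf')

per_param_norms = [float(g.norm(p)) for g in grads]        # what the loop collects
total_from_norms = torch.tensor(per_param_norms).norm(p)   # what grad_norm() reports as the total
total_from_concat = torch.cat([g.flatten() for g in grads]).norm(p)

assert torch.allclose(total_from_norms, total_from_concat)
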
11 changes: 8 additions & 3 deletions pytorch_lightning/trainer/trainer.py
@@ -100,7 +100,7 @@ def __init__(
             log_gpu_memory: Optional[str] = None,
             progress_bar_refresh_rate: int = 1,
             overfit_pct: float = 0.0,
-            track_grad_norm: int = -1,
+            track_grad_norm: Union[int, float, str] = -1,
             check_val_every_n_epoch: int = 1,
             fast_dev_run: bool = False,
             accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
@@ -204,7 +204,7 @@ def __init__(
             overfit_pct: How much of training-, validation-, and test dataset to check.
-            track_grad_norm: -1 no tracking. Otherwise tracks that norm
+            track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm.
             check_val_every_n_epoch: Check val every n train epochs.
@@ -340,7 +340,12 @@ def __init__(
         self.gradient_clip = gradient_clip
 
         self.check_val_every_n_epoch = check_val_every_n_epoch
-        self.track_grad_norm = track_grad_norm
+
+        if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
+            raise MisconfigurationException(
+                "track_grad_norm can be an int, a float or 'inf' (infinity norm).")
+        self.track_grad_norm = float(track_grad_norm)
 
         self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
 
         # tpu config
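
A minimal usage sketch of the widened argument (hypothetical, not part of the commit): any int or float selects that p-norm, the string 'inf' selects the infinity norm, and the default -1 keeps tracking off; anything else now raises a MisconfigurationException.

from pytorch_lightning import Trainer

trainer = Trainer(track_grad_norm=2)        # log the 2-norm of each parameter's gradient
# trainer = Trainer(track_grad_norm='inf')  # log infinity norms instead
# trainer = Trainer(track_grad_norm=-1)     # default: no gradient-norm tracking
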
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -628,7 +628,7 @@ def optimizer_closure():
 
                 # track gradient norms when requested
                 if batch_idx % self.row_log_interval == 0:
-                    if self.track_grad_norm > 0:
+                    if float(self.track_grad_norm) > 0:
                         model = self.get_model()
                         grad_norm_dic = model.grad_norm(
                             self.track_grad_norm)
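
The cast keeps this check meaningful for every accepted value (a sketch of the reasoning, not commit code): the disabled sentinel -1 stays non-positive, while any valid p, including the infinity norm, compares greater than zero.

assert float('inf') > 0      # inf-norm tracking passes the check
assert not (float(-1) > 0)   # the -1 sentinel still disables tracking
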
106 changes: 106 additions & 0 deletions tests/models/test_grad_norm.py
@@ -0,0 +1,106 @@
+import torch
+import pytest
+import numpy as np
+
+from pytorch_lightning import Trainer, seed_everything
+
+from pytorch_lightning.loggers import LightningLoggerBase
+from pytorch_lightning.utilities import rank_zero_only
+
+from tests.base import EvalModelTemplate
+from tests.base.utils import reset_seed
+
+
+class OnlyMetricsListLogger(LightningLoggerBase):
+    def __init__(self):
+        super().__init__()
+        self.metrics = []
+
+    @rank_zero_only
+    def log_metrics(self, metrics, step):
+        self.metrics.append(metrics)
+
+    @property
+    def experiment(self):
+        return 'test'
+
+    @rank_zero_only
+    def log_hyperparams(self, params):
+        pass
+
+    @rank_zero_only
+    def finalize(self, status):
+        pass
+
+    @property
+    def name(self):
+        return 'name'
+
+    @property
+    def version(self):
+        return '1'
+
+
+class ModelWithManualGradTracker(EvalModelTemplate):
+    def __init__(self, norm_type, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stored_grad_norms, self.norm_type = [], float(norm_type)
+
+    # validation spoils logger's metrics with `val_loss` records
+    validation_step = None
+    val_dataloader = None
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        # just return a loss, no log or progress bar meta
+        x, y = batch
+        loss_val = self.loss(y, self(x.flatten(1, -1)))
+        return {'loss': loss_val}
+
+    def on_after_backward(self):
+        out, norms = {}, []
+        prefix = f'grad_{self.norm_type}_norm_'
+        for name, p in self.named_parameters():
+            if p.grad is None:
+                continue
+
+            # `np.linalg.norm` implementation likely uses fp64 intermediates
+            flat = p.grad.data.cpu().numpy().ravel()
+            norm = np.linalg.norm(flat, self.norm_type)
+            norms.append(norm)
+
+            out[prefix + name] = round(norm, 3)
+
+        # handle total norm
+        norm = np.linalg.norm(norms, self.norm_type)
+        out[prefix + 'total'] = round(norm, 3)
+        self.stored_grad_norms.append(out)
+
+
+@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
+def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
+    # rtol=5e-3 respects the 3 decimals rounding in `.grad_norm` and above
+
+    reset_seed()
+
+    # use a custom grad tracking module and a list logger
+    model = ModelWithManualGradTracker(norm_type)
+    logger = OnlyMetricsListLogger()
+
+    trainer = Trainer(
+        max_epochs=3,
+        logger=logger,
+        track_grad_norm=norm_type,
+        row_log_interval=1,  # request grad_norms every batch
+    )
+    result = trainer.fit(model)
+
+    assert result == 1, "Training failed"
+    assert len(logger.metrics) == len(model.stored_grad_norms)
+
+    # compare the logged metrics against tracked norms on `.backward`
+    for mod, log in zip(model.stored_grad_norms, logger.metrics):
+        common = mod.keys() & log.keys()
+
+        log, mod = [log[k] for k in common], [mod[k] for k in common]
+
+        assert np.allclose(log, mod, rtol=rtol)
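
The comparison in this test presumes that np.linalg.norm and torch's Tensor.norm agree on vector p-norms, including the infinity norm. A standalone sanity sketch of that assumption (not part of the commit; the array and tolerance are illustrative):

import numpy as np
import torch

x = torch.randn(100, dtype=torch.float64)
for p in (1.0, 1.25, 2.0, float('inf')):
    # numpy and torch compute the same vector p-norm up to floating point error
    assert np.isclose(np.linalg.norm(x.numpy(), p), float(x.norm(p)), rtol=1e-6)
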
