Mistake in parameters' grad norm tracking (#2012)
* fix grad norm formula
* grad-norm tracker test
* fixed seed and explicit rtol in grad norm tracking test
* a docstring for grad-norms and forced cast to float of norm_type
* support for inf-norm
* renamed the grad norm test
* docs
* fixed language in docstring
* Apply suggestions from code review

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
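For context on the "fix grad norm formula" item above, a minimal sketch (not part of the commit; toy values assumed) of the corrected formula: the overall p-norm must treat all gradients as one concatenated vector, which equals the p-norm of the vector of per-parameter norms.

import torch

# toy gradients, chosen arbitrarily for illustration
grads = [torch.tensor([3.0, 4.0]), torch.tensor([1.0, 2.0, 2.0])]
p = 2.0

# per-parameter norms, as grad_norm reports them
per_param = [float(g.norm(p)) for g in grads]

# total norm: the p-norm of the per-parameter norms ...
total = torch.tensor(per_param).norm(p)

# ... which equals the p-norm of all gradients concatenated into one vector
assert torch.isclose(total, torch.cat(grads).norm(p))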
1 parent a699003 · commit e85a646
Showing 4 changed files with 148 additions and 26 deletions.
@@ -1,30 +1,41 @@
 """
 Module to describe gradients
 """
-from typing import Dict
+from typing import Dict, Union

-from torch import nn
+import torch


-class GradInformation(nn.Module):
+class GradInformation(torch.nn.Module):

-    def grad_norm(self, norm_type: float) -> Dict[str, int]:
-        results = {}
-        total_norm = 0
+    def grad_norm(self, norm_type: Union[float, int, str]) -> Dict[str, float]:
+        """Compute each parameter's gradient's norm and their overall norm.
+
+        The overall norm is computed over all gradients together, as if they
+        were concatenated into a single vector.
+
+        Args:
+            norm_type: The type of the used p-norm, cast to float if necessary.
+                Can be ``'inf'`` for infinity norm.
+
+        Return:
+            norms: The dictionary of p-norms of each parameter's gradient and
+                a special entry for the total p-norm of the gradients viewed
+                as a single vector.
+        """
+        norm_type = float(norm_type)
+
+        norms, all_norms = {}, []
         for name, p in self.named_parameters():
-            if p.requires_grad:
-                try:
-                    param_norm = p.grad.data.norm(norm_type)
-                    total_norm += param_norm ** norm_type
-                    norm = param_norm ** (1 / norm_type)
-
-                    grad = round(norm.data.cpu().numpy().flatten()[0], 3)
-                    results['grad_{}_norm_{}'.format(norm_type, name)] = grad
-                except Exception:
-                    # this param had no grad
-                    pass
-
-        total_norm = total_norm ** (1. / norm_type)
-        grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
-        results['grad_{}_norm_total'.format(norm_type)] = grad
-        return results
+            if p.grad is None:
+                continue
+
+            param_norm = float(p.grad.data.norm(norm_type))
+            norms[f'grad_{norm_type}_norm_{name}'] = round(param_norm, 3)
+
+            all_norms.append(param_norm)
+
+        total_norm = float(torch.tensor(all_norms).norm(norm_type))
+        norms[f'grad_{norm_type}_norm_total'] = round(total_norm, 3)
+
+        return norms
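A hedged usage sketch of the updated method (not part of the commit; `ToyModel` is an illustrative name and the import path is assumed from the source tree, not shown in this diff):

import torch
from pytorch_lightning.core.grads import GradInformation  # path assumed, not shown in this diff


class ToyModel(GradInformation):
    # illustrative module, not from the commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)


model = ToyModel()
model(torch.randn(8, 4)).sum().backward()
print(model.grad_norm(2))       # per-parameter norms plus a 'grad_2.0_norm_total' entry
print(model.grad_norm('inf'))   # infinity norm is now supported; keys read 'grad_inf_norm_...'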
@@ -0,0 +1,106 @@ (new file)
import torch
import pytest
import numpy as np

from pytorch_lightning import Trainer, seed_everything

from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only

from tests.base import EvalModelTemplate
from tests.base.utils import reset_seed


class OnlyMetricsListLogger(LightningLoggerBase):
    def __init__(self):
        super().__init__()
        self.metrics = []

    @rank_zero_only
    def log_metrics(self, metrics, step):
        self.metrics.append(metrics)

    @property
    def experiment(self):
        return 'test'

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    def finalize(self, status):
        pass

    @property
    def name(self):
        return 'name'

    @property
    def version(self):
        return '1'


class ModelWithManualGradTracker(EvalModelTemplate):
    def __init__(self, norm_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stored_grad_norms, self.norm_type = [], float(norm_type)

    # validation spoils logger's metrics with `val_loss` records
    validation_step = None
    val_dataloader = None

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        # just return a loss, no log or progress bar meta
        x, y = batch
        loss_val = self.loss(y, self(x.flatten(1, -1)))
        return {'loss': loss_val}

    def on_after_backward(self):
        out, norms = {}, []
        prefix = f'grad_{self.norm_type}_norm_'
        for name, p in self.named_parameters():
            if p.grad is None:
                continue

            # `np.linalg.norm` implementation likely uses fp64 intermediates
            flat = p.grad.data.cpu().numpy().ravel()
            norm = np.linalg.norm(flat, self.norm_type)
            norms.append(norm)

            out[prefix + name] = round(norm, 3)

        # handle total norm
        norm = np.linalg.norm(norms, self.norm_type)
        out[prefix + 'total'] = round(norm, 3)
        self.stored_grad_norms.append(out)


@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
    # rtol=5e-3 respects the 3-decimal rounding in `grad_norm` and in the tracker above

    reset_seed()

    # use a custom grad tracking module and a list logger
    model = ModelWithManualGradTracker(norm_type)
    logger = OnlyMetricsListLogger()

    trainer = Trainer(
        max_epochs=3,
        logger=logger,
        track_grad_norm=norm_type,
        row_log_interval=1,  # request grad_norms every batch
    )
    result = trainer.fit(model)

    assert result == 1, "Training failed"
    assert len(logger.metrics) == len(model.stored_grad_norms)

    # compare the logged metrics against tracked norms on `.backward`
    for mod, log in zip(model.stored_grad_norms, logger.metrics):
        common = mod.keys() & log.keys()

        log, mod = [log[k] for k in common], [mod[k] for k in common]

        assert np.allclose(log, mod, rtol=rtol)
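As a side note, a small sketch (not part of the commit; toy values assumed) of the property the comparison above relies on: NumPy and PyTorch agree on p-norms, including the infinity norm, within the test's tolerance.

import numpy as np
import torch

# toy gradient values, chosen arbitrarily for illustration
flat = np.array([0.5, -1.5, 2.0, -0.25])

for p in (1.0, 1.25, 2.0, float('inf')):
    np_norm = np.linalg.norm(flat, p)
    torch_norm = float(torch.from_numpy(flat).norm(p))
    assert np.allclose(np_norm, torch_norm, rtol=5e-3)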