Fix accumulate_grad_batches for last batch #2853

Merged
18 commits, merged Aug 15, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -86,6 +86,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `accumulate_grad_batches` for last batch ([#2853](https://github.com/PyTorchLightning/pytorch-lightning/pull/2853))

- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624))

- Fixed local rank zero casting ([#2640](https://github.com/PyTorchLightning/pytorch-lightning/pull/2640))
@@ -3,6 +3,7 @@
====================

Change gradient accumulation factor according to scheduling.
The Trainer also calls ``optimizer.step()`` for the last batch of the epoch, even when the number of batches is not divisible by the accumulation factor.

"""

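As background for the schedule semantics described above, here is a minimal sketch of how an ``{epoch: factor}`` dict resolves to the accumulation factor in effect for a given epoch. The helper name ``resolve_accumulation_factor`` is invented for illustration; this is not Lightning's actual implementation.

# Hypothetical helper, for illustration only: pick the factor of the latest
# schedule entry whose starting epoch has been reached (default factor is 1).
def resolve_accumulation_factor(schedule, current_epoch):
    factor = 1
    for start_epoch in sorted(schedule):
        if current_epoch >= start_epoch:
            factor = schedule[start_epoch]
    return factor

# With schedule {1: 2, 3: 4}: epoch 0 -> 1, epochs 1-2 -> 2, epoch 3 onward -> 4.
assert resolve_accumulation_factor({1: 2, 3: 4}, 0) == 1
assert resolve_accumulation_factor({1: 2, 3: 4}, 2) == 2
assert resolve_accumulation_factor({1: 2, 3: 4}, 5) == 4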
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -168,6 +168,7 @@ def forward(self, x):
accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^
Accumulates grads every k batches or as set up in the dict.
The Trainer also calls ``optimizer.step()`` for the last batch of the epoch, even when the number of batches is not divisible by the accumulation factor.

.. testcode::

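For reference, a short usage sketch of the ``accumulate_grad_batches`` flag described above. It mirrors the documented ``Trainer`` argument but is not the PR's own ``testcode`` example, which is collapsed in this view.

from pytorch_lightning import Trainer

# accumulate gradients over 4 batches in every epoch
trainer = Trainer(accumulate_grad_batches=4)

# or follow a schedule: no accumulation before epoch 5, accumulate 3 batches
# from epoch 5, and 20 batches from epoch 10 onward
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})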
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -590,7 +590,8 @@ def check_checkpoint_callback(self, should_check_val):
[c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]

def update_train_loop_lr_schedulers(self):
-if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
+        or (self.batch_idx + 1) == self.num_training_batches):
# update lr
self.update_learning_rates(interval='step')

@@ -737,7 +738,8 @@ def sync_horovod(self):

def increment_accumulated_grad_global_step(self):
# progress global step according to grads progress
-if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
+        or (self.batch_idx + 1) == self.num_training_batches):
self.global_step += 1
self.total_batch_idx += 1

@@ -878,7 +880,8 @@ def run_training_batch(self, batch, batch_idx):
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
-if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
+        or (self.batch_idx + 1) == self.num_training_batches):

# backward
grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
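The three hunks above add the same clause in three places (learning-rate scheduling, global-step counting, and the backward pass). The following plain-PyTorch sketch, which is illustrative and not Lightning's training loop, shows why the extra ``or (batch_idx + 1) == num_training_batches`` check matters: without it, the gradients of the final, partial accumulation window would never be applied.

import torch

def run_epoch(model, optimizer, batches, accumulate_grad_batches):
    # step on every k-th batch, and also on the final batch even if the epoch
    # length is not divisible by k, so the last partial window is not dropped
    num_training_batches = len(batches)
    optimizer.zero_grad()
    for batch_idx, batch in enumerate(batches):
        loss = model(batch).sum() / accumulate_grad_batches
        loss.backward()
        if ((batch_idx + 1) % accumulate_grad_batches == 0
                or (batch_idx + 1) == num_training_batches):
            optimizer.step()
            optimizer.zero_grad()

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [torch.randn(4, 8) for _ in range(7)]  # 7 batches, factor 4: steps at batch_idx 3 and 6
run_epoch(model, optimizer, batches, accumulate_grad_batches=4)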
72 changes: 65 additions & 7 deletions tests/trainer/test_trainer.py
@@ -5,6 +5,7 @@
import sys
import types
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch

@@ -196,7 +197,7 @@ def test_gradient_accumulation_scheduling(tmpdir, schedule, expected):

trainer = Trainer(
accumulate_grad_batches=schedule,
-limit_train_batches=0.8,
+limit_train_batches=0.7,  # intentionally not divisible by accumulate_grad_batches
limit_val_batches=0.8,
max_epochs=4,
default_root_dir=tmpdir,
@@ -216,8 +217,15 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
# use this opportunity to test once
assert trainer.accumulate_grad_batches == expected[0]

-assert batch_idx == model.prev_called_batch_idx
-model.prev_called_batch_idx += expected[0]
+# separate check for last batch with accumulate 1 step
+if expected[0] == 1 and (batch_idx + 1) == trainer.num_training_batches:
+    assert batch_idx == model.prev_called_batch_idx
+elif (batch_idx + 1) == trainer.num_training_batches:
+    # prev_called_batch_idx - schedule + modulus remainder
+    assert batch_idx == (model.prev_called_batch_idx - expected[0] + (batch_idx + 1) % expected[0])
+else:
+    assert batch_idx == model.prev_called_batch_idx
+    model.prev_called_batch_idx += expected[0]

elif 1 <= epoch <= 2:
# reset counter when starting epoch
@@ -227,8 +235,12 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
# use this opportunity to test once
assert trainer.accumulate_grad_batches == expected[1]

-assert batch_idx == model.prev_called_batch_idx
-model.prev_called_batch_idx += expected[1]
+if trainer.num_training_batches == batch_idx + 1:
+    # prev_called_batch_idx - schedule + modulus remainder
+    assert batch_idx == (model.prev_called_batch_idx - expected[1] + (batch_idx + 1) % expected[1])
+else:
+    assert batch_idx == model.prev_called_batch_idx
+    model.prev_called_batch_idx += expected[1]

else:
if batch_idx == expected[2] - 1:
@@ -237,8 +249,12 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
# use this opportunity to test once
assert trainer.accumulate_grad_batches == expected[2]

-assert batch_idx == model.prev_called_batch_idx
-model.prev_called_batch_idx += expected[2]
+if (batch_idx + 1) == trainer.num_training_batches:
+    # prev_called_batch_idx - schedule + modulus remainder
+    assert batch_idx == (model.prev_called_batch_idx - expected[2] + (batch_idx + 1) % expected[2])
+else:
+    assert batch_idx == model.prev_called_batch_idx
+    model.prev_called_batch_idx += expected[2]

optimizer.step()

@@ -252,6 +268,48 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
trainer.fit(model)


@pytest.mark.parametrize(
    ['accumulate_grad_batches', 'limit_train_batches'],
    [
        pytest.param({1: 2, 3: 4}, 1.0),
        pytest.param({1: 2, 3: 4}, 0.5),  # intentionally not divisible by accumulate_grad_batches
        pytest.param(3, 1.0),
        pytest.param(3, 0.8),  # intentionally not divisible by accumulate_grad_batches
        pytest.param(4, 1.0),
        pytest.param(4, 0.7),  # intentionally not divisible by accumulate_grad_batches
    ],
)
def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches):
    """ Verify that optimizer.step() is called on the last batch while accumulating gradients """

    class CurrentModel(EvalModelTemplate):

        def on_after_backward(self):
            self.loss_backward = deepcopy(self.state_dict())

        def on_before_zero_grad(self, optimizer):
            self.opt_step = self.state_dict()

        def on_train_batch_end(self, batch, batch_idx, dataloader_idx):
            _exclude_keys = ['num_batches_tracked', 'running_mean', 'running_var']

            if (batch_idx + 1) == self.trainer.num_training_batches:
                for key in self.loss_backward.keys():
                    # exclude the check for batch_norm parameters
                    if not any(k in key for k in _exclude_keys):
                        assert not torch.equal(self.loss_backward[key], self.opt_step[key])

    model = CurrentModel()

    trainer = Trainer(
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=4,
        limit_train_batches=limit_train_batches,
        default_root_dir=tmpdir,
    )

    trainer.fit(model)


def test_loading_meta_tags(tmpdir):
""" test for backward compatibility to meta_tags.csv """
tutils.reset_seed()
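For readers puzzling over the ``prev_called_batch_idx - schedule + modulus remainder`` assertion in the updated test above, a worked example with invented numbers (an accumulation factor of 4 and 7 training batches; these values are not taken from the PR):

accumulate = 4
num_training_batches = 7

# with the new condition, optimizer.step() fires at batch_idx 3 (full window)
# and at batch_idx 6 (final, partial window of 3 batches)
step_indices = [
    batch_idx for batch_idx in range(num_training_batches)
    if (batch_idx + 1) % accumulate == 0 or (batch_idx + 1) == num_training_batches
]
assert step_indices == [3, 6]

# after the regular step at batch_idx 3, the test's counter expects the next full
# window at index 3 + 4 = 7; the last-batch branch therefore predicts the final
# step index as that counter minus one window plus the leftover batches
prev_called_batch_idx = 7
last_batch_idx = num_training_batches - 1
assert last_batch_idx == prev_called_batch_idx - accumulate + num_training_batches % accumulate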