Fix accumulate_grad_batches for last batch (#2853)
* first attempt

* update changelog

* fix pep8 and tests

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* added new tests

* fixed tests

* Apply suggestions from code review

* used num_training_batches

* fixed pep8

* fixed with is_last_batch suggested by @awaelchli

* fixed with num_training_batches

* fixed with num_training_batches

* cleanup

* fix test and update docs

* fixed for alignment, update docs

* minor changes

* update doc

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
4 people committed Aug 15, 2020
1 parent 5c35db9 commit 73ebd10
Showing 5 changed files with 75 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -86,6 +86,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `accumulate_grad_batches` for last batch ([#2853](https://github.com/PyTorchLightning/pytorch-lightning/pull/2853))

- Fixed setup call while testing ([#2624](https://github.com/PyTorchLightning/pytorch-lightning/pull/2624))

- Fixed local rank zero casting ([#2640](https://github.com/PyTorchLightning/pytorch-lightning/pull/2640))
@@ -3,6 +3,7 @@
====================
Change gradient accumulation factor according to scheduling.
Trainer also calls ``optimizer.step()`` for the last indivisible step number.
"""

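The docstring hunk above (its file header is not shown in this capture) appears to belong to the gradient accumulation scheduler callback module. A minimal sketch of how such a schedule might be attached to a `Trainer`, assuming the `GradientAccumulationScheduler` callback and its `scheduling` dict mapping an epoch to an accumulation factor; the epoch and factor values below are illustrative only:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# from epoch 5 onward, accumulate gradients over 4 batches before each optimizer.step()
accumulator = GradientAccumulationScheduler(scheduling={5: 4})
trainer = Trainer(callbacks=[accumulator])
```

With this fix, a trailing group of batches that does not fill a whole accumulation window still triggers `optimizer.step()` at the end of the epoch.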
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -168,6 +168,7 @@ def forward(self, x):
accumulate_grad_batches
^^^^^^^^^^^^^^^^^^^^^^^
Accumulates grads every k batches or as set up in the dict.
Trainer also calls ``optimizer.step()`` for the last indivisible step number.
.. testcode::
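The body of the ``.. testcode::`` block is not shown in this view. A minimal sketch of the kind of usage the flag documentation describes; the concrete factors and epochs are illustrative and not taken from the elided block:

```python
from pytorch_lightning import Trainer

# accumulate gradients over 4 batches, i.e. roughly a 4x larger effective batch size
trainer = Trainer(accumulate_grad_batches=4)

# or schedule it per epoch: accumulate 3 batches from epoch 5, and 20 from epoch 10 onward
trainer = Trainer(accumulate_grad_batches={5: 3, 10: 20})
```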
9 changes: 6 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -590,7 +590,8 @@ def check_checkpoint_callback(self, should_check_val):
[c.on_validation_end(self, self.get_model()) for c in checkpoint_callbacks]

def update_train_loop_lr_schedulers(self):
- if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+ if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
+         or (self.batch_idx + 1) == self.num_training_batches):
# update lr
self.update_learning_rates(interval='step')

@@ -737,7 +738,8 @@ def sync_horovod(self):

def increment_accumulated_grad_global_step(self):
# progress global step according to grads progress
- if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+ if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
+         or (self.batch_idx + 1) == self.num_training_batches):
self.global_step += 1
self.total_batch_idx += 1

@@ -878,7 +880,8 @@ def run_training_batch(self, batch, batch_idx):
# BACKWARD PASS
# ------------------------------
# gradient update with accumulated gradients
- if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+ if ((self.batch_idx + 1) % self.accumulate_grad_batches == 0
+         or (self.batch_idx + 1) == self.num_training_batches):

# backward
grad_norm_dic = self.run_batch_backward_pass(split_batch, batch_idx, opt_idx, optimizer)
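All three hunks above add the same second clause, so the accumulation window is also flushed on the final batch of the epoch when `num_training_batches` is not a multiple of `accumulate_grad_batches`. A standalone sketch (plain Python, not Lightning internals; both counts are made-up example values) of which batch indices now trigger a step:

```python
# Assumed example values, chosen so the epoch length is not divisible by the factor.
accumulate_grad_batches = 4
num_training_batches = 10

stepped_at = [
    batch_idx
    for batch_idx in range(num_training_batches)
    if (batch_idx + 1) % accumulate_grad_batches == 0
    or (batch_idx + 1) == num_training_batches
]
print(stepped_at)  # [3, 7, 9]; batch_idx 9 now closes the trailing partial window (batches 8-9)
```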
72 changes: 65 additions & 7 deletions tests/trainer/test_trainer.py
@@ -5,6 +5,7 @@
import sys
import types
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch

@@ -196,7 +197,7 @@ def test_gradient_accumulation_scheduling(tmpdir, schedule, expected):

trainer = Trainer(
accumulate_grad_batches=schedule,
- limit_train_batches=0.8,
+ limit_train_batches=0.7,  # not to be divisible by accumulate_grad_batches on purpose
limit_val_batches=0.8,
max_epochs=4,
default_root_dir=tmpdir,
@@ -216,8 +217,15 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
# use this opportunity to test once
assert trainer.accumulate_grad_batches == expected[0]

- assert batch_idx == model.prev_called_batch_idx
- model.prev_called_batch_idx += expected[0]
+ # separate check for last batch with accumulate 1 step
+ if expected[0] == 1 and (batch_idx + 1) == trainer.num_training_batches:
+     assert batch_idx == model.prev_called_batch_idx
+ elif (batch_idx + 1) == trainer.num_training_batches:
+     # prev_called_batch_idx - schedule + modulus remainder
+     assert batch_idx == (model.prev_called_batch_idx - expected[0] + (batch_idx + 1) % expected[0])
+ else:
+     assert batch_idx == model.prev_called_batch_idx
+     model.prev_called_batch_idx += expected[0]

elif 1 <= epoch <= 2:
# reset counter when starting epoch
@@ -227,8 +235,12 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
# use this opportunity to test once
assert trainer.accumulate_grad_batches == expected[1]

- assert batch_idx == model.prev_called_batch_idx
- model.prev_called_batch_idx += expected[1]
+ if trainer.num_training_batches == batch_idx + 1:
+     # prev_called_batch_idx - schedule + modulus remainder
+     assert batch_idx == (model.prev_called_batch_idx - expected[1] + (batch_idx + 1) % expected[1])
+ else:
+     assert batch_idx == model.prev_called_batch_idx
+     model.prev_called_batch_idx += expected[1]

else:
if batch_idx == expected[2] - 1:
@@ -237,8 +249,12 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
# use this opportunity to test once
assert trainer.accumulate_grad_batches == expected[2]

- assert batch_idx == model.prev_called_batch_idx
- model.prev_called_batch_idx += expected[2]
+ if (batch_idx + 1) == trainer.num_training_batches:
+     # prev_called_batch_idx - schedule + modulus remainder
+     assert batch_idx == (model.prev_called_batch_idx - expected[2] + (batch_idx + 1) % expected[2])
+ else:
+     assert batch_idx == model.prev_called_batch_idx
+     model.prev_called_batch_idx += expected[2]

optimizer.step()

@@ -252,6 +268,48 @@ def _optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
trainer.fit(model)
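
A worked example (with assumed numbers) of the modulus-remainder check used in `_optimizer_step` above: with an accumulation factor of 4 and 7 training batches, the regular step fires at `batch_idx` 3, after which the test advances its counter to 3 + 4 = 7; the extra step for the last, partial window fires at `batch_idx` 6, and the expression recovers exactly that index:

```python
expected = 4                           # assumed accumulation factor
num_training_batches = 7               # assumed epoch length
prev_called_batch_idx = 3 + expected   # counter state after the regular step at batch_idx 3

last_batch_idx = num_training_batches - 1
# 7 - 4 + (6 + 1) % 4 == 3 + 3 == 6
assert last_batch_idx == prev_called_batch_idx - expected + (last_batch_idx + 1) % expected
```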


@pytest.mark.parametrize(
    ['accumulate_grad_batches', 'limit_train_batches'],
    [
        pytest.param({1: 2, 3: 4}, 1.0),
        pytest.param({1: 2, 3: 4}, 0.5),  # not to be divisible by accumulate_grad_batches on purpose
        pytest.param(3, 1.0),
        pytest.param(3, 0.8),  # not to be divisible by accumulate_grad_batches on purpose
        pytest.param(4, 1.0),
        pytest.param(4, 0.7),  # not to be divisible by accumulate_grad_batches on purpose
    ],
)
def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_batches, limit_train_batches):
    """ Verify optimizer.step() applied to last batch while grad accumulation """

    class CurrentModel(EvalModelTemplate):
        def on_after_backward(self):
            self.loss_backward = deepcopy(self.state_dict())

        def on_before_zero_grad(self, optimizer):
            self.opt_step = self.state_dict()

        def on_train_batch_end(self, batch, batch_idx, dataloader_idx):
            _exclude_keys = ['num_batches_tracked', 'running_mean', 'running_var']

            if (batch_idx + 1) == self.trainer.num_training_batches:
                for key in self.loss_backward.keys():
                    # exclude the check for batch_norm parameters
                    if not any([k in key for k in _exclude_keys]):
                        assert not torch.equal(self.loss_backward[key], self.opt_step[key])

    model = CurrentModel()

    trainer = Trainer(
        accumulate_grad_batches=accumulate_grad_batches,
        max_epochs=4,
        limit_train_batches=limit_train_batches,
        default_root_dir=tmpdir
    )

    trainer.fit(model)
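
The new test relies on parameters changing only when `optimizer.step()` actually runs: `on_after_backward` snapshots the weights after the backward pass, `on_before_zero_grad` captures them once the step has been applied, and a difference between the two at the last batch shows that the step happened. A tiny self-contained illustration of that idea with a toy `torch.nn.Linear` (not part of the test above):

```python
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

before = {k: v.clone() for k, v in model.state_dict().items()}
loss = model(torch.randn(8, 4)).sum()
loss.backward()   # gradients accumulate; the weights themselves are untouched
assert all(torch.equal(before[k], v) for k, v in model.state_dict().items())

opt.step()        # only now do the weights move
assert any(not torch.equal(before[k], v) for k, v in model.state_dict().items())
```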


def test_loading_meta_tags(tmpdir):
""" test for backward compatibility to meta_tags.csv """
tutils.reset_seed()
