limit auto scaling batch size to the size of the training dataset (#3271)

* fix

* fix and test

* fix merge error

* test for max dataset size

* changelog

* update docs

* fix merge

* unused imports

* imports
awaelchli authored Sep 9, 2020
1 parent 0c2e315 commit e245065
Showing 4 changed files with 42 additions and 23 deletions.
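
Before the file-by-file diff, a self-contained sketch of the user-visible effect, assuming the PyTorch Lightning API as exercised by the updated test below; the model and data here are illustrative, not part of this PR. With only 64 training samples, the batch size finder can no longer suggest a value above 64.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """Minimal module exposing a `batch_size` attribute for the tuner to scale."""

    def __init__(self, batch_size: int = 2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {"loss": torch.nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # only 64 samples, so any batch size above 64 would be meaningless
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=self.batch_size)


model = TinyModel()
trainer = Trainer(max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model)            # runs the batch size finder before fitting
assert model.batch_size <= 64  # with this commit, clamped to the dataset size
```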
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed Horovod distributed backend compatibility with native AMP ([#3404](https://github.com/PyTorchLightning/pytorch-lightning/pull/3404))

- Fixed batch size auto scaling exceeding the size of the dataset ([#3271](https://github.com/PyTorchLightning/pytorch-lightning/pull/3271))

## [0.9.0] - YYYY-MM-DD

### Added
8 changes: 0 additions & 8 deletions pytorch_lightning/trainer/training_tricks.py
@@ -12,22 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr

try:
from apex import amp
52 changes: 38 additions & 14 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -13,13 +13,14 @@
# limitations under the License
import os
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.data import has_len
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_getattr, lightning_setattr
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning import _logger as log
from typing import Optional
from typing import Optional, Tuple


def scale_batch_size(trainer,
@@ -55,6 +56,13 @@ def scale_batch_size(trainer,
algorithm is terminated
batch_arg_name: name of the attribute that stores the batch size.
It is expected that the user has provided a model or datamodule that has a hyperparameter
with that name. We will look for this attribute name in the following places
- `model`
- `model.hparams`
- `model.datamodule`
- `trainer.datamodule` (the datamodule passed to the tune method)
**fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
or datamodule.
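
The lookup order described in the docstring above (the model itself, then `model.hparams`, then the datamodules) can be sketched roughly as follows. This is a simplified stand-in for what `lightning_getattr` does, not the library implementation:

```python
def find_batch_size_holder(model, trainer, batch_arg_name: str = "batch_size"):
    """Return the first object that exposes `batch_arg_name`, in the documented order."""
    candidates = (
        model,                                 # plain attribute on the LightningModule
        getattr(model, "hparams", None),       # model.hparams
        getattr(model, "datamodule", None),    # datamodule attached to the model
        getattr(trainer, "datamodule", None),  # datamodule passed to .tune()
    )
    for holder in candidates:
        if holder is not None and hasattr(holder, batch_arg_name):
            return holder
    raise AttributeError(f"`{batch_arg_name}` not found on the model, its hparams, or a datamodule")
```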
@@ -165,16 +173,19 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **f
# Try fit
trainer.fit(model, **fit_kwargs)
# Double in size
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
except RuntimeError as exception:
# Only these errors should trigger an adjustment
if is_oom_error(exception):
# If we fail in power mode, half the size and return
garbage_collection_cuda()
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
new_size, _ = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
break
else:
raise # some other error not memory related

if not changed:
break
return new_size
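
The `changed` flag returned by `_adjust_batch_size` is what now terminates power scaling once the clamp kicks in. Purely illustrative arithmetic (no training or OOM handling), assuming a 100-sample dataset:

```python
dataset_len = 100
batch_size = 2
for _ in range(25):                             # max_trials
    new_size = min(batch_size * 2, dataset_len)  # double, then clamp to the dataset size
    changed = new_size != batch_size
    batch_size = new_size
    if not changed:                              # the size stopped growing -> stop searching
        break
print(batch_size)  # 100; before this fix the size kept doubling past the dataset
```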


@@ -199,39 +210,40 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials,
if high - low <= 1:
break
midval = (high + low) // 2
new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
else:
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')

if not changed:
break

except RuntimeError as exception:
# Only these errors should trigger an adjustment
if is_oom_error(exception):
# If we fail in power mode, half the size and return
garbage_collection_cuda()
high = new_size
midval = (high + low) // 2
new_size = _adjust_batch_size(trainer, value=midval, desc='failed')
new_size, _ = _adjust_batch_size(trainer, value=midval, desc='failed')
if high - low <= 1:
break
else:
raise # some other error not memory related

return new_size


def _adjust_batch_size(trainer,
batch_arg_name: str = 'batch_size',
factor: float = 1.0,
value: Optional[int] = None,
desc: str = None):
""" Function for adjusting the batch size. It is expected that the user
has provided a model that has a hparam field called `batch_size` i.e.
`model.hparams.batch_size` should exist. Additionally there can be a
datamodule attached to either Trainer or model, in that case the attribute
also gets updated when present.
desc: str = None) -> Tuple[int, bool]:
""" Helper function for adjusting the batch size.
Args:
trainer: instance of pytorch_lightning.Trainer
batch_arg_name: field where batch_size is stored in `model.hparams`
batch_arg_name: name of the field where batch_size is stored.
factor: value which the old batch size is multiplied by to get the
new batch size
@@ -241,11 +253,23 @@ def _adjust_batch_size(trainer,
desc: either `succeeded` or `failed`. Used purely for logging
Returns:
The new batch size for the next trial and a bool that signals whether the
new value is different than the previous batch size.
"""
model = trainer.get_model()
batch_size = lightning_getattr(model, batch_arg_name)
new_size = value if value is not None else int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')

if not _is_valid_batch_size(new_size, trainer.train_dataloader):
new_size = min(new_size, len(trainer.train_dataloader.dataset))

changed = new_size != batch_size
lightning_setattr(model, batch_arg_name, new_size)
return new_size
return new_size, changed


def _is_valid_batch_size(current_size, dataloader):
return not has_len(dataloader) or current_size <= len(dataloader)
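
A minimal behavioural sketch of the guard above, assuming a map-style dataset wrapped in a plain DataLoader. It is not the library's `_is_valid_batch_size` (which also handles loaders without a length via `has_len`), just the clamping idea:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def clamp_batch_size(new_size: int, dataloader: DataLoader) -> int:
    """Cap the suggested batch size at the dataset length when the dataset is sized."""
    dataset = dataloader.dataset
    if hasattr(dataset, "__len__"):  # iterable-style datasets without a length are left alone
        new_size = min(new_size, len(dataset))
    return new_size


loader = DataLoader(TensorDataset(torch.randn(100, 3)), batch_size=2)
print(clamp_batch_size(256, loader))  # 100 -> never larger than the dataset
print(clamp_batch_size(64, loader))   # 64  -> already valid, unchanged
```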
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer_tricks.py
@@ -238,9 +238,10 @@ def dataloader(self, *args, **kwargs):

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model, datamodule_fit)
assert trainer.datamodule == datamodule_fit
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert trainer.datamodule == datamodule_fit
assert before_batch_size != after_batch_size
assert after_batch_size <= len(trainer.train_dataloader.dataset)
assert datamodule_fit.batch_size == after_batch_size
# should be left unchanged, since it was not passed to .tune()
assert datamodule_model.batch_size == 111
