Feature: auto scale batch size (#1638)
* auto batch finder

* fix styling

* add description

* add different modes

* fix copy paste error

* better organised code

* fix styling

* add tests

* fix

* fix

* add some documentation

* added CHANGELOG.md

* some documentation

* update based on review

* Update trainer.py

* Update docs/source/training_tricks.rst

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update tests/trainer/test_trainer_tricks.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update tests/trainer/test_trainer_tricks.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* use EvalModelTemplate

* param tests

* rename

* wrap params

* rename function

* rename

* rename param

* fix

* abs

* rename

* refactor code

* add docs

* try

* arg

* loop

* except

* loop

* drop bool

* docs

* docs

* added check and test for passing dataloader to fit

* styling fix

* update based on review

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
Co-authored-by: William Falcon <waf2107@columbia.edu>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
6 people committed May 9, 2020
1 parent 25bbd05 commit 4970927
Showing 8 changed files with 519 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).

- Added auto scaling of batch size ([#1638](https://github.com/PyTorchLightning/pytorch-lightning/pull/1638))

- The progress bar metrics now also get updated in `training_epoch_end` ([#1724](https://github.com/PyTorchLightning/pytorch-lightning/pull/1724)).

### Changed
73 changes: 73 additions & 0 deletions docs/source/training_tricks.rst
@@ -34,3 +34,76 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_

    # clip gradients with norm above 0.5
    trainer = Trainer(gradient_clip_val=0.5)

Auto scaling of batch size
--------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into
memory. A larger batch size often yields better gradient estimates, but may also result in
longer training time.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

.. code-block:: python

    # DEFAULT (ie: don't scale batch size automatically)
    trainer = Trainer(auto_scale_batch_size=None)

    # Autoscale batch size
    trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch')

Currently, this feature supports two modes: `'power'` scaling and `'binsearch'`
scaling. In `'power'` scaling, the batch size starts at 1 and is doubled until an
out-of-memory (OOM) error is encountered. Setting the argument to `'binsearch'`
additionally fine-tunes the batch size by performing a binary search after the
doubling phase.
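
For illustration only (`power_mode_trials` is a hypothetical helper, not part of
Lightning), the sequence of batch sizes tried by `'power'` mode can be computed as:

.. code-block:: python

    # Hypothetical helper, for illustration only: list the batch sizes that
    # 'power' mode tries, given the size at which the first OOM error occurs.
    def power_mode_trials(oom_at: int) -> list:
        size, trials = 1, []
        while size < oom_at:
            trials.append(size)
            size *= 2
        trials.append(size)  # this attempt triggers the OOM; the previous size is kept
        return trials

    print(power_mode_trials(256))  # [1, 2, 4, 8, 16, 32, 64, 128, 256]

With `'binsearch'`, the search would presumably continue between the last size that
fit (128 here) and the first size that failed (256).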

.. note::

    This feature expects a `batch_size` field to exist in the `hparams` of your model, i.e.,
    `model.hparams.batch_size`; this field will be overridden by the result of the
    algorithm. Additionally, your `train_dataloader()` method should depend on this field
    for the feature to work, i.e.:

    .. code-block:: python

        def train_dataloader(self):
            return DataLoader(train_dataset, batch_size=self.hparams.batch_size)

.. warning::

    Due to these constraints, this feature does *NOT* work when passing dataloaders directly
    to `.fit()`.

The scaling algorithm has a number of parameters that the user can control by
invoking the trainer method `.scale_batch_size` themselves (see description below).

.. code-block:: python

    # Use default in trainer construction
    trainer = Trainer()

    # Invoke method
    new_batch_size = trainer.scale_batch_size(model, ...)

    # Override old batch size
    model.hparams.batch_size = new_batch_size

    # Fit as normal
    trainer.fit(model)

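For instance, to request the binary-search strategy when calling the method directly
(a sketch; `mode` is the keyword argument that `Trainer.fit()` itself forwards in this
commit, while any further tuning knobs are elided with `...` above and documented in
the API reference below):

.. code-block:: python

    # assumes `model` exposes `hparams.batch_size` as described in the note above
    new_batch_size = trainer.scale_batch_size(model, mode='binsearch')
    model.hparams.batch_size = new_batch_size
    trainer.fit(model)
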
In short, the algorithm works by:

1. Dumping the current state of the model and trainer
2. Iterating until convergence or until the maximum number of tries, `max_trials` (default 25), has been reached:
    - Call the `fit()` method of the trainer. This evaluates `steps_per_trial` (default 3)
      training steps. Each training step can trigger an OOM error if the tensors
      (training batch, weights, gradients, etc.) allocated during the steps have
      too large a memory footprint.
    - If an OOM error is encountered, decrease the batch size; else increase it.
      How much the batch size is increased/decreased is determined by the chosen
      strategy.
3. Saving the found batch size to `model.hparams.batch_size`
4. Restoring the initial state of the model and trainer
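
A minimal, self-contained sketch of the doubling phase described above (the names
`find_batch_size_power_sketch` and `try_steps` are hypothetical; the real implementation
instead reuses `trainer.fit()` together with the dumped and restored trainer state):

.. code-block:: python

    import torch

    def find_batch_size_power_sketch(try_steps, max_trials=25):
        """Double the batch size until `try_steps(batch_size)` runs out of memory."""
        batch_size = 1
        for _ in range(max_trials):
            try:
                try_steps(batch_size)   # runs a few training steps; may raise an OOM RuntimeError
                batch_size *= 2         # it fits, so try twice the size
            except RuntimeError as err:
                if 'out of memory' in str(err).lower():
                    torch.cuda.empty_cache()              # release the failed allocation
                    batch_size = max(1, batch_size // 2)  # back off to the last size that fit
                    break
                raise                   # unrelated error: re-raise
        return batch_size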

.. autoclass:: pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin
:members: scale_batch_size
:noindex:
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -135,6 +135,19 @@ def forward(self, x):
# default used by the Trainer
trainer = Trainer(amp_level='O1')

auto_scale_batch_size
^^^^^^^^^^^^^^^^^^^^^

Automatically tries to find the largest batch size that fits into memory,
before any training.

.. code-block:: python

    # default used by the Trainer (no scaling of batch size)
    trainer = Trainer(auto_scale_batch_size=None)

    # run batch size scaling, result overrides hparams.batch_size
    trainer = Trainer(auto_scale_batch_size='binsearch')

auto_lr_find
^^^^^^^^^^^^
Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
33 changes: 17 additions & 16 deletions pytorch_lightning/trainer/lr_finder.py
@@ -106,7 +106,7 @@ def lr_find(self,
"""
save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

self._lr_finder_dump_params(model)
self.__lr_finder_dump_params(model)

# Prevent going into infinite loop
self.auto_lr_find = False
@@ -170,15 +170,15 @@ def lr_find(self,
os.remove(save_path)

# Finish by resetting variables so trainer is ready to fit model
self._lr_finder_restore_params(model)
self.__lr_finder_restore_params(model)
if self.progress_bar_callback:
self.progress_bar_callback.enable()

return lr_finder

def _lr_finder_dump_params(self, model):
def __lr_finder_dump_params(self, model):
# Prevent going into infinite loop
self._params = {
self.__dumped_params = {
'auto_lr_find': self.auto_lr_find,
'callbacks': self.callbacks,
'logger': self.logger,
@@ -192,18 +192,19 @@ def _lr_finder_dump_params(self, model):
'configure_optimizers': model.configure_optimizers,
}

def _lr_finder_restore_params(self, model):
self.auto_lr_find = self._params['auto_lr_find']
self.logger = self._params['logger']
self.callbacks = self._params['callbacks']
self.max_steps = self._params['max_steps']
self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self._params['accumulate_grad_batches']
self.checkpoint_callback = self._params['checkpoint_callback']
self.early_stop_callback = self._params['early_stop_callback']
self.enable_early_stop = self._params['enable_early_stop']
self.progress_bar_callback = self._params['progress_bar_callback']
model.configure_optimizers = self._params['configure_optimizers']
def __lr_finder_restore_params(self, model):
self.auto_lr_find = self.__dumped_params['auto_lr_find']
self.logger = self.__dumped_params['logger']
self.callbacks = self.__dumped_params['callbacks']
self.max_steps = self.__dumped_params['max_steps']
self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
self.progress_bar_callback = self.__dumped_params['progress_bar_callback']
model.configure_optimizers = self.__dumped_params['configure_optimizers']
del self.__dumped_params


class _LRFinder(object):
14 changes: 13 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -129,6 +129,7 @@ def __init__(
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
auto_scale_batch_size: Optional[str] = None,
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
@@ -293,6 +294,12 @@ def __init__(
terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
The result will be stored in self.hparams.batch_size in the LightningModule.
Additionally, can be set to either `power` that estimates the batch size through
a power search or `binsearch` that estimates the batch size through a binary search.
"""

# Init callbacks
@@ -368,6 +375,7 @@ def __init__(
self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

self.auto_lr_find = auto_lr_find
self.auto_scale_batch_size = auto_scale_batch_size
self.replace_sampler_ddp = replace_sampler_ddp

self.truncated_bptt_steps = truncated_bptt_steps
@@ -474,7 +482,7 @@ def __init__(
self.show_progress_bar = show_progress_bar

self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar_callback = None
self.progress_bar_callback = progress_bar_callback
self.configure_progress_bar()

# logging
@@ -736,6 +744,10 @@ def fit(
# only on proc 0 because no spawn has happened yet
model.prepare_data()

# Run auto batch size scaling
if self.auto_scale_batch_size:
self.scale_batch_size(model, mode=self.auto_scale_batch_size)

# Run learning rate finder:
if self.auto_lr_find:
self._run_lr_finder_internally(model)
(Diffs for the remaining 3 changed files did not load.)
