
Feature: auto scale batch size #1638

Merged

Changes from all commits (48 commits)
fbacbaf
auto batch finder
Apr 27, 2020
863bce9
fix styling
Apr 27, 2020
b004826
add description
Apr 27, 2020
a83417e
add different modes
Apr 27, 2020
78d31eb
fix copy paste error
Apr 27, 2020
55b6739
better organised code
Apr 29, 2020
0df9435
fix styling
Apr 29, 2020
6371521
add tests
Apr 29, 2020
ba98712
fix
Apr 29, 2020
c776fde
fix
Apr 29, 2020
62a7432
add some documentation
Apr 29, 2020
1d1a06d
added CHANGELOG.md
Apr 29, 2020
c46c0a7
some documentation
May 4, 2020
aa855fc
update based on review
May 4, 2020
736a8b7
Update trainer.py
williamFalcon May 5, 2020
c6bd574
Update docs/source/training_tricks.rst
williamFalcon May 5, 2020
9f1d6b2
Update tests/trainer/test_trainer_tricks.py
williamFalcon May 5, 2020
7dfb6c2
Update tests/trainer/test_trainer_tricks.py
williamFalcon May 5, 2020
b770f9c
Apply suggestions from code review
Borda May 5, 2020
84bc362
use EvalModelTemplate
Borda May 6, 2020
28779ac
param tests
Borda May 6, 2020
929ca6f
rename
Borda May 6, 2020
5131ade
wrap params
Borda May 6, 2020
248f7d2
rename function
May 6, 2020
01acea9
Merge remote-tracking branch 'origin/feature/auto_batch_size' into fe…
May 6, 2020
f309bc8
rename
May 6, 2020
e5305ee
rename param
Borda May 6, 2020
67e4324
Merge branch 'feature/auto_batch_size' of https://github.com/SkafteNi…
Borda May 6, 2020
000c5a9
fix
Borda May 7, 2020
3a66219
abs
Borda May 7, 2020
9b690a4
rename
Borda May 7, 2020
f39e6d0
refactor code
May 7, 2020
4a34f7c
add docs
May 7, 2020
e56c927
merge
May 7, 2020
9d3c01e
Merge remote-tracking branch 'upstream/master' into feature/auto_batc…
May 7, 2020
33701eb
try
Borda May 7, 2020
0572fff
arg
Borda May 7, 2020
2d0413e
loop
Borda May 7, 2020
ce67950
exept
Borda May 7, 2020
cb490eb
loop
Borda May 7, 2020
83dce03
drop bool
Borda May 7, 2020
a71417e
docs
Borda May 7, 2020
0f9d195
docs
May 8, 2020
289bd34
refactor code a bit more
May 8, 2020
20e2b86
added check and test for passing dataloader to fit
May 8, 2020
522a206
styling fix
May 8, 2020
c3cdd6c
update based on review
May 8, 2020
3122cb4
Merge branch 'master' into feature/auto_batch_size
Borda May 9, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).

- Added auto scaling of batch size ([#1638](https://github.com/PyTorchLightning/pytorch-lightning/pull/1638))

- The progress bar metrics now also get updated in `training_epoch_end` ([#1724](https://github.com/PyTorchLightning/pytorch-lightning/pull/1724)).

### Changed
73 changes: 73 additions & 0 deletions docs/source/training_tricks.rst
@@ -34,3 +34,76 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_

    # clip gradients with norm above 0.5
    trainer = Trainer(gradient_clip_val=0.5)

Auto scaling of batch size
--------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into
memory. A larger batch size often yields better gradient estimates, but it may also result
in a longer training time.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

.. code-block:: python

    # DEFAULT (i.e. don't scale the batch size automatically)
    trainer = Trainer(auto_scale_batch_size=None)

    # Autoscale the batch size by setting the mode to 'power' or 'binsearch'
    trainer = Trainer(auto_scale_batch_size='power')

Currently, this feature supports two modes: `'power'` scaling and `'binsearch'`
scaling. In `'power'` scaling, the algorithm starts from a batch size of 1 and keeps
doubling it until an out-of-memory (OOM) error is encountered. Setting the argument
to `'binsearch'` additionally fine-tunes the batch size by performing a binary search.
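
For intuition, each trial only needs to detect whether a few training steps exhaust
GPU memory. The check below is a minimal sketch, not the trainer's internal code;
`step_fn` is a hypothetical callable that runs a single training step.

.. code-block:: python

    import torch

    def fits_in_memory(step_fn):
        # Run one training step and report whether it exhausted GPU memory.
        # PyTorch signals an OOM with a RuntimeError whose message contains
        # "out of memory".
        try:
            step_fn()
            return True
        except RuntimeError as err:
            if "out of memory" in str(err):
                torch.cuda.empty_cache()  # release cached blocks before the next trial
                return False
            raise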

.. note::

    This feature expects a `batch_size` field to exist in the `hparams` of your model,
    i.e. `model.hparams.batch_size`, and that field will be overridden by the result of
    this algorithm. Additionally, your `train_dataloader()` method should depend on this
    field for the feature to work, i.e.

    .. code-block:: python

        def train_dataloader(self):
            return DataLoader(train_dataset, batch_size=self.hparams.batch_size)

.. warning::

    Due to these constraints, this feature does *NOT* work when dataloaders are passed
    directly to `.fit()`.

The scaling algorithm has a number of parameters that the user can control by
invoking the trainer method `.scale_batch_size` directly (see the description below).

.. code-block:: python

    # Use default in trainer construction
    trainer = Trainer()

    # Invoke method
    new_batch_size = trainer.scale_batch_size(model, ...)

    # Override old batch size
    model.hparams.batch_size = new_batch_size

    # Fit as normal
    trainer.fit(model)
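
For example, the parameters described in the algorithm outline below can be passed as
keyword arguments. This is only a sketch; the exact argument names should be checked
against the `scale_batch_size` signature.

.. code-block:: python

    new_batch_size = trainer.scale_batch_size(
        model,
        mode='binsearch',    # 'power' or 'binsearch'
        steps_per_trial=3,   # training steps evaluated per candidate batch size
        max_trials=25,       # upper bound on the number of candidates tried
    )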

In short, the algorithm works by:

1. Dumping the current state of the model and trainer.
2. Iterating until convergence or until the maximum number of tries `max_trials` (default 25) is reached:
    - Call the `fit()` method of the trainer. This evaluates `steps_per_trial` (default 3)
      training steps. Each training step can trigger an OOM error if the tensors
      (training batch, weights, gradients, etc.) allocated during the steps have too
      large a memory footprint.
    - If an OOM error is encountered, decrease the batch size; otherwise increase it.
      How much the batch size is increased/decreased is determined by the chosen strategy.
3. Saving the found batch size to `model.hparams.batch_size`.
4. Restoring the initial state of the model and trainer.
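
A rough, illustrative sketch of this loop is shown below. It is not the library's
internal implementation; `run_trial` is a hypothetical callable standing in for the
trainer running a few training steps and reporting whether they fit in memory.

.. code-block:: python

    import copy

    def scale_batch_size_sketch(model, run_trial, mode='power', steps_per_trial=3, max_trials=25):
        # `run_trial` runs `steps_per_trial` training steps at the current batch
        # size and returns True if they fit in memory, False on an OOM error.
        initial_state = copy.deepcopy(model.state_dict())  # 1. dump the current model state
        low, high = None, None       # largest size that fit / smallest size that failed
        batch_size = model.hparams.batch_size
        for _ in range(max_trials):  # 2. iterate until convergence or `max_trials`
            model.hparams.batch_size = batch_size
            if run_trial(model, steps_per_trial):
                low = batch_size
                if high is None:
                    batch_size *= 2          # both modes start by doubling
                elif high - low <= 1:
                    break                    # the binary search has converged
                else:
                    batch_size = (low + high) // 2
            else:
                if mode == 'power':
                    break                    # 'power' stops at the first OOM
                high = batch_size            # 'binsearch' narrows the interval instead
                batch_size = ((low or 1) + high) // 2
        model.hparams.batch_size = low or 1  # 3. save the largest batch size that fit
        model.load_state_dict(initial_state) # 4. restore the initial model state
        return model.hparams.batch_size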

.. autoclass:: pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin
    :members: scale_batch_size
    :noindex:
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/__init__.py
@@ -135,6 +135,19 @@ def forward(self, x):
    # default used by the Trainer
    trainer = Trainer(amp_level='O1')

auto_scale_batch_size
^^^^^^^^^^^^^^^^^^^^^

Automatically tries to find the largest batch size that fits into memory,
before any training.

.. code-block:: python

    # default used by the Trainer (no scaling of batch size)
    trainer = Trainer(auto_scale_batch_size=None)

    # run batch size scaling, result overrides hparams.batch_size
    trainer = Trainer(auto_scale_batch_size='binsearch')

auto_lr_find
^^^^^^^^^^^^

Runs a learning rate finder algorithm (see this `paper <https://arxiv.org/abs/1506.01186>`_)
33 changes: 17 additions & 16 deletions pytorch_lightning/trainer/lr_finder.py
@@ -106,7 +106,7 @@ def lr_find(self,
"""
save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

self._lr_finder_dump_params(model)
self.__lr_finder_dump_params(model)

# Prevent going into infinite loop
self.auto_lr_find = False
@@ -170,15 +170,15 @@ def lr_find(self,
os.remove(save_path)

# Finish by resetting variables so trainer is ready to fit model
self._lr_finder_restore_params(model)
self.__lr_finder_restore_params(model)
if self.progress_bar_callback:
self.progress_bar_callback.enable()

return lr_finder

def _lr_finder_dump_params(self, model):
def __lr_finder_dump_params(self, model):
# Prevent going into infinite loop
self._params = {
self.__dumped_params = {
'auto_lr_find': self.auto_lr_find,
'callbacks': self.callbacks,
'logger': self.logger,
@@ -192,18 +192,19 @@ def _lr_finder_dump_params(self, model):
'configure_optimizers': model.configure_optimizers,
}

def _lr_finder_restore_params(self, model):
self.auto_lr_find = self._params['auto_lr_find']
self.logger = self._params['logger']
self.callbacks = self._params['callbacks']
self.max_steps = self._params['max_steps']
self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self._params['accumulate_grad_batches']
self.checkpoint_callback = self._params['checkpoint_callback']
self.early_stop_callback = self._params['early_stop_callback']
self.enable_early_stop = self._params['enable_early_stop']
self.progress_bar_callback = self._params['progress_bar_callback']
model.configure_optimizers = self._params['configure_optimizers']
def __lr_finder_restore_params(self, model):
self.auto_lr_find = self.__dumped_params['auto_lr_find']
self.logger = self.__dumped_params['logger']
self.callbacks = self.__dumped_params['callbacks']
self.max_steps = self.__dumped_params['max_steps']
self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
self.early_stop_callback = self.__dumped_params['early_stop_callback']
self.enable_early_stop = self.__dumped_params['enable_early_stop']
self.progress_bar_callback = self.__dumped_params['progress_bar_callback']
model.configure_optimizers = self.__dumped_params['configure_optimizers']
del self.__dumped_params


class _LRFinder(object):
14 changes: 13 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -129,6 +129,7 @@ def __init__(
auto_lr_find: Union[bool, str] = False,
replace_sampler_ddp: bool = True,
progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
auto_scale_batch_size: Optional[str] = None,
amp_level: str = 'O1', # backward compatible, todo: remove in v0.8.0
default_save_path=None, # backward compatible, todo: remove in v0.8.0
gradient_clip=None, # backward compatible, todo: remove in v0.8.0
@@ -293,6 +294,12 @@ def __init__(

terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

auto_scale_batch_size: If set, will run a batch size finder before any training,
trying to find the largest batch size that fits into memory. The result will be
stored in `self.hparams.batch_size` in the LightningModule. Can be set to either
`'power'`, which estimates the batch size through a power search (repeated doubling),
or `'binsearch'`, which additionally refines the estimate through a binary search.
"""

# Init callbacks
@@ -368,6 +375,7 @@ def __init__(
self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

self.auto_lr_find = auto_lr_find
self.auto_scale_batch_size = auto_scale_batch_size
self.replace_sampler_ddp = replace_sampler_ddp

self.truncated_bptt_steps = truncated_bptt_steps
@@ -474,7 +482,7 @@ def __init__(
self.show_progress_bar = show_progress_bar

self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.progress_bar_callback = None
self.progress_bar_callback = progress_bar_callback

Reviewer comment (Member):
the arg progress_bar_callback was not used anywhere - forgotten, hope it is the right place...

self.configure_progress_bar()

# logging
@@ -736,6 +744,10 @@ def fit(
# only on proc 0 because no spawn has happened yet
model.prepare_data()

# Run auto batch size scaling
if self.auto_scale_batch_size:
self.scale_batch_size(model, mode=self.auto_scale_batch_size)

# Run learning rate finder:
if self.auto_lr_find:
self._run_lr_finder_internally(model)