update based on review
Nicki Skafte authored and Borda committed May 5, 2020
1 parent 8bd07a6 commit 2d3988a
Showing 4 changed files with 33 additions and 34 deletions.
18 changes: 9 additions & 9 deletions docs/source/training_tricks.rst
@@ -38,7 +38,8 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
Auto scaling of batch size
-------------------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into
memory. Larger batch size often give better estimates of
memory. Larger batch sizes often give better estimates of gradients, but may also result in
longer training time.

.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`

@@ -58,27 +59,26 @@ size by doing a binary search.
.. note::

This feature expects that a `batch_size` field exists in the `hparams` of your model, i.e.
`model.hparams.batch_size` should exist and will be overriden by the results of this
algorithm. Settin
`model.hparams.batch_size` should exist and will be overridden by the results of this
algorithm.

The scaling algorithm has a number of parameters that the user can control by
invoking the `.scale_batch_size` method themself.
invoking the trainer method `.scale_batch_size` themselves.

.. code-block:: python
# Use default in trainer construction
trainer = Trainer()
# Invoke method
new_batch_size = trainer.scale_batch_size(...)
new_batch_size = trainer.scale_batch_size(model, ...)
# Override old batch size
model.hparams.batch_size = new_batch_size
# Fit as normal
trainer.fit(model)
Below

.. autoclass:
.. autoclass:: pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin
:members: scale_batch_size
:noindex:
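
For readers skimming the rendered docs, here is a minimal sketch of how the two entry points described above fit together, assuming the `auto_scale_batch_size` Trainer flag and the `scale_batch_size(model, ...)` signature shown in this commit; `MyLightningModule` and `hparams` are placeholders for any module exposing `hparams.batch_size`:

.. code-block:: python

    from pytorch_lightning import Trainer

    model = MyLightningModule(hparams)  # placeholder module with `hparams.batch_size`

    # Option 1: let the Trainer run the finder just before fitting
    trainer = Trainer(auto_scale_batch_size='binsearch')  # or 'power' / True
    trainer.fit(model)

    # Option 2: invoke the finder manually and apply the result yourself
    trainer = Trainer()
    new_batch_size = trainer.scale_batch_size(model, mode='power', init_val=2)
    model.hparams.batch_size = new_batch_size
    trainer.fit(model)
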
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -298,7 +298,9 @@ def __init__(
auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
The results will be stored in self.hparams.batch_size in the lightning module.
To use a different key, set a string instead of True with the key name.
Additionally, can be set to either `power` (same as `True`), which
estimates the batch size through a power search, or `binsearch`, which
estimates the batch size through a binary search.
"""

# Init callbacks
@@ -745,7 +747,9 @@ def fit(

# Run auto batch size scaling
if self.auto_scale_batch_size:
_ = self.scale_batch_size(model)
if self.auto_scale_batch_size is True:
self.auto_scale_batch_size = 'power'
_ = self.scale_batch_size(model, mode=self.auto_scale_batch_size)

# Run learning rate finder:
if self.auto_lr_find:
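
The change above also determines which search mode is handed to `scale_batch_size`. A tiny stand-alone sketch of that flag normalization, with an illustrative helper name that is not part of the library:

.. code-block:: python

    def resolve_scale_mode(auto_scale_batch_size):
        """Illustrative only: map the Trainer flag to a search mode."""
        if auto_scale_batch_size is True:
            return 'power'  # a bare True falls back to the default power search
        if auto_scale_batch_size in ('power', 'binsearch'):
            return auto_scale_batch_size
        raise ValueError("auto_scale_batch_size must be True, 'power' or 'binsearch'")
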
34 changes: 15 additions & 19 deletions pytorch_lightning/trainer/training_tricks.py
@@ -92,9 +92,10 @@ def scale_batch_size(self,
mode: str = 'power',
n_step_per_try: int = 3,
init_val: int = 2,
n_max_try: int = 20):
r""" Will iteratively try to find the largest batch size for a given model
that does not not give an out of memory (OOM) error
n_max_try: int = 25):
r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error
Args:
model: Model to fit.
@@ -110,9 +111,9 @@ def scale_batch_size(self,
Ideally 1 should be enough to test if an OOM error occurs,
however in practice a few are needed
init_val: initial batch size to do the search from
init_val: initial batch size to do the search from
n_max_try: max number of increase/decreases in batch size done before
n_max_try: max number of increases in batch size done before the
algorithm is terminated
"""
@@ -121,7 +122,6 @@
' can only be `power` or `binsearch`')

# Arguments we adjust during the batch size finder, save for restoring
trainer_arg = self.auto_scale_batch_size
max_steps = self.max_steps
weights_summary = self.weights_summary
logger = self.logger
@@ -142,7 +142,7 @@
self.save_checkpoint(str(save_path))

# Initially we just double in size until an OOM is encountered
new_size = _adjust_batch_size(self, trainer_arg, value=init_val) # initially set to init_val
new_size = _adjust_batch_size(self, value=init_val) # initially set to init_val
high = None
count = 0
while True:
@@ -157,15 +157,15 @@

# Double in size
low = new_size
new_size = _adjust_batch_size(self, trainer_arg, factor=2.0, string='succeeded')
new_size = _adjust_batch_size(self, factor=2.0, string='succeeded')
except RuntimeError as exception:
# Only these errors should trigger an adjustment
if is_OOM_error(exception):
# If we fail in power mode, halve the size and return
garbage_collection_cuda()
high = new_size
if mode != 'binsearch':
new_size = _adjust_batch_size(self, trainer_arg, factor=0.5, string='failed')
new_size = _adjust_batch_size(self, factor=0.5, string='failed')
break
else:
raise # some other error not memory related
@@ -185,7 +185,7 @@
# Adjust batch size
low = new_size
midval = (high + low) // 2
new_size = _adjust_batch_size(self, trainer_arg, value=midval, string='succeeded')
new_size = _adjust_batch_size(self, value=midval, string='succeeded')
except RuntimeError as exception:
# Only these errors should trigger an adjustment
if is_OOM_error(exception):
@@ -194,7 +194,7 @@
if high - low <= 1:
break
midval = (high + low) // 2
new_size = _adjust_batch_size(self, trainer_arg, value=midval, string='failed')
new_size = _adjust_batch_size(self, value=midval, string='failed')
else:
raise # some other error not memory related

@@ -216,20 +216,16 @@


def _adjust_batch_size(trainer,
trainer_arg: str,
factor: float = 1.0,
value: Optional[int] = None,
string: str = None):
""" Function for adjusting the batch size
""" Function for adjusting the batch size. It is expected that the user
has provided a model with a hparams field called `batch_size`, i.e.
`model.hparams.batch_size` should exist.
Args:
trainer: instance of pytorch_lightning.Trainer
trainer_arg: trainer_arg, either True or a string, determines location
to save value of batch size. If True, will save the newly calculated
batch size to `model.hparams.batch_size`. If a string will save
value to `model.hparams.string`
factor: value which the old batch size is multiplied by to get the
new batch size
@@ -239,7 +235,7 @@ def _adjust_batch_size(trainer,
string: either `succeeded` or `failed`. Used purely for logging
"""
trainer_arg = trainer_arg if isinstance(trainer_arg, str) else 'batch_size'
trainer_arg = 'batch_size'

model = trainer.get_model()

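
To condense the two hunks above into one picture (double the batch size until an out-of-memory error, then optionally binary-search between the last success and the first failure), here is a stand-alone sketch; `run_steps` is a hypothetical callable standing in for the trainer's trial fitting steps, and none of these names are the library's API:

.. code-block:: python

    def find_max_batch_size(run_steps, mode='power', init_val=2, n_max_try=25):
        """Illustrative sketch of the 'power' / 'binsearch' scaling loop."""
        low, high = None, None
        size = init_val
        for _ in range(n_max_try):
            try:
                run_steps(size)              # stand-in for fitting a few steps
                low, size = size, size * 2   # success: double the batch size
            except RuntimeError:             # the real loop checks is_OOM_error here
                high = size
                if mode != 'binsearch' or low is None:
                    return max(size // 2, 1) # power mode: halve the failing size
                break
        if mode == 'binsearch' and high is not None:
            while high - low > 1:            # narrow the success/failure window
                mid = (high + low) // 2
                try:
                    run_steps(mid)
                    low = mid
                except RuntimeError:
                    high = mid
        return low
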
7 changes: 3 additions & 4 deletions tests/trainer/test_trainer_tricks.py
@@ -115,18 +115,17 @@ class CurrentTestModel(
pass

hparams = tutils.get_default_hparams()
hparams.__dict__['my_fancy_batch_size'] = 17 # update with non-standard field
model = CurrentTestModel(hparams)
before_batch_size = hparams.my_fancy_batch_size
before_batch_size = hparams.batch_size
# logger file to get meta
trainer = Trainer(
default_save_path=tmpdir,
max_epochs=1,
auto_lr_find='my_fancy_batch_size'
auto_scale_batch_size='binsearch'
)

trainer.fit(model)
after_batch_size = model.hparams.my_fancy_batch_size
after_batch_size = model.hparams.batch_size
assert before_batch_size != after_batch_size, \
'Batch size was not altered after running auto scaling of batch size'

