Apply suggestions from code review
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Borda and awaelchli committed May 5, 2020
1 parent b5d6f77 commit 810ecc8
Showing 5 changed files with 18 additions and 19 deletions.
12 changes: 6 additions & 6 deletions docs/source/training_tricks.rst
@@ -36,7 +36,7 @@ norm <https://pytorch.org/docs/stable/nn.html#torch.nn.utils.clip_grad_norm_>`_
trainer = Trainer(gradient_clip_val=0.5)

Auto scaling of batch size
--------------------------------------
+--------------------------
Auto scaling of batch size may be enabled to find the largest batch size that fits into
memory. Larger batch size often yields better estimates of gradients, but may also result in
longer training time.
@@ -52,17 +52,17 @@ longer training time.
trainer = Trainer(auto_scale_batch_size=True|'power'|'binsearch')
Setting the feature to `True` enables `'power'` scaling, that starting from a
-batch size of 1 keeps double the batch size until an out-of-memory (OMM) error is
-encountered. Setting the argument to `'binsearch'` continue to finetune the batch
-size by duing a binary search.
+batch size of 1 keeps doubling the batch size until an out-of-memory (OOM) error is
+encountered. Setting the argument to `'binsearch'` continues to finetune the batch
+size by performing a binary search.

.. note::

-This feature expects that a `batch_size` field exist in the `hparams` of your model i.e.
+This feature expects that a `batch_size` field in the `hparams` of your model, i.e.,
`model.hparams.batch_size` should exist and will be overridden by the results of this
algorithm.

-The scaling algorithm has a number of parameters, that the user can control by
+The scaling algorithm has a number of parameters that the user can control by
invoking the trainer method `.scale_batch_size` themself.

.. code-block:: python
…
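The doc text above describes the two scaling modes. As a rough sketch of the idea only, not Lightning's actual implementation: the loop below doubles the batch size until the first failure and, in `'binsearch'` mode, then refines between the last success and that failure. `fits_in_memory` is a hypothetical stand-in for running a few training steps and catching an out-of-memory error.

.. code-block:: python

    def find_batch_size(fits_in_memory, init_val=2, n_max_try=25, mode='power'):
        """Double the batch size until OOM; optionally binary-search afterwards."""
        size, last_ok = init_val, None
        for _ in range(n_max_try):
            if fits_in_memory(size):
                last_ok = size
                size *= 2  # 'power' scaling: keep doubling
            else:
                break  # first failure: `size` is now a known-bad upper bound
        if mode == 'binsearch' and last_ok is not None:
            low, high = last_ok, size
            while low + 1 < high:  # refine between last success and first failure
                mid = (low + high) // 2
                if fits_in_memory(mid):
                    low = mid
                else:
                    high = mid
            last_ok = low
        return last_ok

    # Example: pretend anything up to 100 samples fits in memory.
    print(find_batch_size(lambda bs: bs <= 100, mode='binsearch'))  # -> 100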
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/__init__.py
@@ -142,15 +142,15 @@ def forward(self, x):
.. code-block:: python
-# defeault use by the Trainer (no scaling of batch size)
+# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=False)
Example::
-# run batch size scaling, result override hparams.batch_size
+# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size=True)
-# run batch size scaling, result override hparams.my_batch_size_arg
+# run batch size scaling, result overrides hparams.my_batch_size_arg
trainer = Trainer(auto_scale_batch_size='my_batch_size_arg')
auto_lr_find
…
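One precondition implied by the string form above: the named attribute must already exist on the model's hparams before fitting. A minimal sketch of that setup, assuming the Namespace-style hparams of this era of the library; the field name comes from the docstring above, and the initial value of 16 is arbitrary.

.. code-block:: python

    from argparse import Namespace

    from pytorch_lightning import Trainer

    # The string names the hparams attribute the finder reads and overwrites,
    # so the model's hparams must carry it (here via an argparse Namespace).
    hparams = Namespace(my_batch_size_arg=16)
    trainer = Trainer(auto_scale_batch_size='my_batch_size_arg')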
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -297,9 +297,9 @@ def __init__(
auto_scale_batch_size: If set to True, will `initially` run a batch size
finder trying to find the largest batch size that fits into memory.
-The results will be stored in self.hparams.batch_size in the lightning module.
+The result will be stored in self.hparams.batch_size in the LightningModule.
Additionally, can be set to either `power` (same as `True`) that
-estimates the batch size through a power search or `binseach` that
+estimates the batch size through a power search or `binsearch` that
estimates the batch size through a binary search.
"""

…
11 changes: 5 additions & 6 deletions pytorch_lightning/trainer/training_tricks.py
@@ -95,7 +95,7 @@ def scale_batch_size(self,
n_max_try: int = 25):
r"""
Will iteratively try to find the largest batch size for a given model
-that does not not give an out of memory (OOM) error
+that does not give an out of memory (OOM) error.
Args:
model: Model to fit.
@@ -104,22 +104,21 @@ def scale_batch_size(self,
If mode is `power` we keep multiplying the batch size by 2, until
we get an OOM error. If mode is 'binsearch', we will initially
also keep multiplying by 2 and after encountering an OOM error
-do a binary search between the last succeded batch size and the
+do a binary search between the last successful batch size and the
batch size that failed.
n_step_per_try: number of steps to run with a given batch size.
Idealy 1 should be enough to test if a OOM error occurs,
-however in practise a few is needed
+however in practise a few are needed
-init_val: initial batch size to do the search from
+init_val: initial batch size to start the search with
n_max_try: max number of increase in batch size done before
algorithm is terminated
"""
if mode not in ['power', 'binsearch']:
-raise ValueError('mode in method `scale_batch_size`'
-' can only be `power` or `binsearch')
+raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch')

# Arguments we adjust during the batch size finder, save for restoring
max_steps = self.max_steps
…
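Tying the documented arguments together, here is a hedged usage sketch of calling the method directly. `DemoModel`, its layer shapes, and the dataset size are invented for illustration; the keyword arguments mirror the docstring above (`mode`, `n_step_per_try`, `init_val`, `n_max_try`), and per the docs the result lands in `model.hparams.batch_size`.

.. code-block:: python

    from argparse import Namespace

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset


    class DemoModel(pl.LightningModule):
        """Toy module whose dataloader reads `hparams.batch_size`."""

        def __init__(self, hparams):
            super().__init__()
            self.hparams = hparams
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return {'loss': torch.nn.functional.cross_entropy(self(x), y)}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.02)

        def train_dataloader(self):
            data = TensorDataset(torch.randn(6400, 32), torch.randint(0, 2, (6400,)))
            return DataLoader(data, batch_size=self.hparams.batch_size)


    model = DemoModel(Namespace(batch_size=2))
    trainer = pl.Trainer()
    trainer.scale_batch_size(model, mode='binsearch', n_step_per_try=3,
                             init_val=2, n_max_try=25)
    print(model.hparams.batch_size)  # overridden with the largest size that fit

Because the finder temporarily adjusts trainer state while probing, the tests below check that those parameters are restored afterwards.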
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer_tricks.py
@@ -42,7 +42,7 @@ class CurrentTestModel(


def test_trainer_reset_correctly(tmpdir):
-''' Check that all trainer parameters are reset correctly after scaling batch size'''
+""" Check that all trainer parameters are reset correctly after scaling batch size. """
tutils.reset_seed()

class CurrentTestModel(
…
@@ -79,7 +79,7 @@ class CurrentTestModel(


def test_trainer_arg_bool(tmpdir):
-''' Check that trainer arg works with bool input '''
+""" Check that trainer arg works with bool input. """
tutils.reset_seed()

class CurrentTestModel(
…
