From 810ecc8125b19baeeab77c4cde5968bf6556f858 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 5 May 2020 23:07:01 +0200 Subject: [PATCH] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- docs/source/training_tricks.rst | 12 ++++++------ pytorch_lightning/trainer/__init__.py | 6 +++--- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/trainer/training_tricks.py | 11 +++++------ tests/trainer/test_trainer_tricks.py | 4 ++-- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/docs/source/training_tricks.rst b/docs/source/training_tricks.rst index 485c2e92772707..65a300adf74fc3 100644 --- a/docs/source/training_tricks.rst +++ b/docs/source/training_tricks.rst @@ -36,7 +36,7 @@ norm `_ trainer = Trainer(gradient_clip_val=0.5) Auto scaling of batch size -------------------------------------- +-------------------------- Auto scaling of batch size may be enabled to find the largest batch size that fits into memory. Larger batch size often yields better estimates of gradients, but may also result in longer training time. @@ -52,17 +52,17 @@ longer training time. trainer = Trainer(auto_scale_batch_size=True|'power'|'binsearch') Setting the feature to `True` enables `'power'` scaling, that starting from a -batch size of 1 keeps double the batch size until an out-of-memory (OMM) error is -encountered. Setting the argument to `'binsearch'` continue to finetune the batch -size by duing a binary search. +batch size of 1 keeps doubling the batch size until an out-of-memory (OOM) error is +encountered. Setting the argument to `'binsearch'` continues to finetune the batch +size by performing a binary search. .. note:: - This feature expects that a `batch_size` field exist in the `hparams` of your model i.e. + This feature expects that a `batch_size` field in the `hparams` of your model, i.e., `model.hparams.batch_size` should exist and will be overridden by the results of this algorithm. -The scaling algorithm has a number of parameters, that the user can control by +The scaling algorithm has a number of parameters that the user can control by invoking the trainer method `.scale_batch_size` themself. .. code-block:: python diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index cce58c3cdc6ebd..6b7d6ee99f17ff 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -142,15 +142,15 @@ def forward(self, x): .. code-block:: python - # defeault use by the Trainer (no scaling of batch size) + # default used by the Trainer (no scaling of batch size) trainer = Trainer(auto_scale_batch_size=False) Example:: - # run batch size scaling, result override hparams.batch_size + # run batch size scaling, result overrides hparams.batch_size trainer = Trainer(auto_scale_batch_size=True) - # run batch size scaling, result override hparams.my_batch_size_arg + # run batch size scaling, result overrides hparams.my_batch_size_arg trainer = Trainer(auto_scale_batch_size='my_batch_size_arg') auto_lr_find diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6ab86d6390abec..0d83234025f4c8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -297,9 +297,9 @@ def __init__( auto_scale_batch_size: If set to True, will `initially` run a batch size finder trying to find the largest batch size that fits into memory. 
- The results will be stored in self.hparams.batch_size in the lightning module. + The result will be stored in self.hparams.batch_size in the LightningModule. Additionally, can be set to either `power` (same as `True`) that - estimates the batch size through a power search or `binseach` that + estimates the batch size through a power search or `binsearch` that estimates the batch size through a binary search. """ diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index 1542790e1d5e00..199a30887e64ca 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -95,7 +95,7 @@ def scale_batch_size(self, n_max_try: int = 25): r""" Will iteratively try to find the largest batch size for a given model - that does not not give an out of memory (OOM) error + that does not give an out of memory (OOM) error. Args: model: Model to fit. @@ -104,22 +104,21 @@ def scale_batch_size(self, If mode is `power` we keep multiplying the batch size by 2, until we get an OOM error. If mode is 'binsearch', we will initially also keep multiplying by 2 and after encountering an OOM error - do a binary search between the last succeded batch size and the + do a binary search between the last successful batch size and the batch size that failed. n_step_per_try: number of steps to run with a given batch size. Idealy 1 should be enough to test if a OOM error occurs, - however in practise a few is needed + however in practise a few are needed - init_val: initial batch size to do the search from + init_val: initial batch size to start the search with n_max_try: max number of increase in batch size done before algorithm is terminated """ if mode not in ['power', 'binsearch']: - raise ValueError('mode in method `scale_batch_size`' - ' can only be `power` or `binsearch') + raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch') # Arguments we adjust during the batch size finder, save for restoring max_steps = self.max_steps diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index e7c93de65997a8..f378b9e3c261ff 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -42,7 +42,7 @@ class CurrentTestModel( def test_trainer_reset_correctly(tmpdir): - ''' Check that all trainer parameters are reset correctly after scaling batch size''' + """ Check that all trainer parameters are reset correctly after scaling batch size. """ tutils.reset_seed() class CurrentTestModel( @@ -79,7 +79,7 @@ class CurrentTestModel( def test_trainer_arg_bool(tmpdir): - ''' Check that trainer arg works with bool input ''' + """ Check that trainer arg works with bool input. """ tutils.reset_seed() class CurrentTestModel(
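
The docs and docstring hunks above describe two ways of driving the batch size finder: the `auto_scale_batch_size` Trainer argument and the `scale_batch_size` trainer method, but the patch only shows fragments of the surrounding usage. Below is a minimal sketch of that workflow. Only the Trainer argument values (`True|'power'|'binsearch'`) and the parameter names documented in the `training_tricks.py` hunk (`mode`, `n_step_per_try`, `init_val`, `n_max_try`) are taken from the patch; the model class and the concrete values are illustrative assumptions, not part of the commit.

.. code-block:: python

    from pytorch_lightning import Trainer

    # Assumed: a LightningModule whose `hparams` define `batch_size`.
    # The finder overwrites `model.hparams.batch_size` with the value it finds.
    model = MyLightningModule(hparams)  # hypothetical model, not from the patch

    # Variant 1: let the Trainer run the finder before training.
    # True is shorthand for 'power' scaling (keep doubling until an OOM error);
    # 'binsearch' then binary-searches between the last successful batch size
    # and the one that failed.
    trainer = Trainer(auto_scale_batch_size='binsearch')
    trainer.fit(model)

    # Variant 2: invoke the finder manually to control its parameters
    # (names as documented in the scale_batch_size docstring above).
    trainer = Trainer()
    trainer.scale_batch_size(
        model,
        mode='binsearch',   # 'power' or 'binsearch'
        n_step_per_try=3,   # steps run at each candidate batch size
        init_val=2,         # batch size the search starts from
        n_max_try=25,       # stop after this many batch size increases
    )
    trainer.fit(model)      # trains with the updated model.hparams.batch_size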