From 2d3988af02fffb2c2906fac7ec459d7e1329635e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 4 May 2020 15:19:14 +0200 Subject: [PATCH] update based on review --- docs/source/training_tricks.rst | 18 +++++------ pytorch_lightning/trainer/trainer.py | 8 +++-- pytorch_lightning/trainer/training_tricks.py | 34 +++++++++----------- tests/trainer/test_trainer_tricks.py | 7 ++-- 4 files changed, 33 insertions(+), 34 deletions(-) diff --git a/docs/source/training_tricks.rst b/docs/source/training_tricks.rst index f9acaaf0f83d88..cd28c429b4b209 100644 --- a/docs/source/training_tricks.rst +++ b/docs/source/training_tricks.rst @@ -38,7 +38,8 @@ norm `_ Auto scaling of batch size ------------------------------------- Auto scaling of batch size may be enabled to find the largest batch size that fits into -memory. Larger batch size often give better estimates of +memory. Larger batch sizes often give better estimates of gradients, but may also give +longer training time. .. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer` @@ -58,11 +59,11 @@ size by duing a binary search. .. note:: This feature expects that a `batch_size` field exist in the `hparams` of your model i.e. - `model.hparams.batch_size` should exist and will be overriden by the results of this - algorithm. Settin + `model.hparams.batch_size` should exist and will be overridden by the results of this + algorithm. The scaling algorithm has a number of parameters, that the user can control by -invoking the `.scale_batch_size` method themself. +invoking the trainer method `.scale_batch_size` themselves. .. code-block:: python @@ -70,7 +71,7 @@ invoking the `.scale_batch_size` method themself. trainer = Trainer() # Invoke method - new_batch_size = trainer.scale_batch_size(...) + new_batch_size = trainer.scale_batch_size(model, ...) # Override old batch size model.hparams.batch_size = new_batch_size @@ -78,7 +79,6 @@ invoking the `.scale_batch_size` method themself. 
# Fit as normal trainer.fit(model) -Below - -.. autoclass: - +.. autoclass:: pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin + :members: scale_batch_size + :noindex: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ddbdba529a4b4c..e1109b80df6f21 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -298,7 +298,9 @@ def __init__( auto_scale_batch_size: If set to True, will `initially` run a batch size finder trying to find the largest batch size that fits into memory. The results will be stored in self.hparams.batch_size in the lightning module. - To use a different key, set a string instead of True with the key name. + Additionally, can be set to either `power` (same as `True`) that + estimates the batch size through a power search or `binsearch` that + estimates the batch size through a binary search. """ # Init callbacks @@ -745,7 +747,9 @@ def fit( # Run auto batch size scaling if self.auto_scale_batch_size: - _ = self.scale_batch_size(model) + if self.auto_scale_batch_size is True: + self.auto_scale_batch_size = 'power' + _ = self.scale_batch_size(model, mode = self.auto_scale_batch_size) # Run learning rate finder: if self.auto_lr_find: diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index b0f5efa0550a27..1542790e1d5e00 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -92,9 +92,10 @@ def scale_batch_size(self, mode: str = 'power', n_step_per_try: int = 3, init_val: int = 2, - n_max_try: int = 20): - r""" Will iteratively try to find the largest batch size for a given model - that does not not give an out of memory (OOM) error + n_max_try: int = 25): + r""" + Will iteratively try to find the largest batch size for a given model + that does not give an out of memory (OOM) error Args: model: Model to fit. 
@@ -110,9 +111,9 @@ def scale_batch_size(self, Idealy 1 should be enough to test if a OOM error occurs, however in practise a few is needed - init_val: initial batch size to do the search from + init_val: initial batch size to do the search from - n_max_try: max number of increase/decreases in batch size done before + n_max_try: max number of increase in batch size done before algorithm is terminated """ @@ -121,7 +122,6 @@ def scale_batch_size(self, ' can only be `power` or `binsearch') # Arguments we adjust during the batch size finder, save for restoring - trainer_arg = self.auto_scale_batch_size max_steps = self.max_steps weights_summary = self.weights_summary logger = self.logger @@ -142,7 +142,7 @@ def scale_batch_size(self, self.save_checkpoint(str(save_path)) # Initially we just double in size until an OOM is encountered - new_size = _adjust_batch_size(self, trainer_arg, value=init_val) # initially set to init_val + new_size = _adjust_batch_size(self, value=init_val) # initially set to init_val high = None count = 0 while True: @@ -157,7 +157,7 @@ def scale_batch_size(self, # Double in size low = new_size - new_size = _adjust_batch_size(self, trainer_arg, factor=2.0, string='succeeded') + new_size = _adjust_batch_size(self, factor=2.0, string='succeeded') except RuntimeError as exception: # Only these errors should trigger an adjustment if is_OOM_error(exception): @@ -165,7 +165,7 @@ def scale_batch_size(self, garbage_collection_cuda() high = new_size if mode != 'binsearch': - new_size = _adjust_batch_size(self, trainer_arg, factor=0.5, string='failed') + new_size = _adjust_batch_size(self, factor=0.5, string='failed') break else: raise # some other error not memory related @@ -185,7 +185,7 @@ def scale_batch_size(self, # Adjust batch size low = new_size midval = (high + low) // 2 - new_size = _adjust_batch_size(self, trainer_arg, value=midval, string='succeeded') + new_size = _adjust_batch_size(self, value=midval, string='succeeded') except RuntimeError as 
exception: # Only these errors should trigger an adjustment if is_OOM_error(exception): @@ -194,7 +194,7 @@ def scale_batch_size(self, if high - low <= 1: break midval = (high + low) // 2 - new_size = _adjust_batch_size(self, trainer_arg, value=midval, string='failed') + new_size = _adjust_batch_size(self, value=midval, string='failed') else: raise # some other error not memory related @@ -216,20 +216,16 @@ def scale_batch_size(self, def _adjust_batch_size(trainer, - trainer_arg: str, factor: float = 1.0, value: Optional[int] = None, string: str = None): - """ Function for adjusting the batch size + """ Function for adjusting the batch size. It is expected that the user + has provided a model that has a hparam field called `batch_size` i.e. + `model.hparams.batch_size` should exist. Args: trainer: instance of pytorch_lightning.Trainer - trainer_arg: trainer_arg, either True or a string, determines location - to save value of batch size. If True, will save the newly calculated - batch size to `model.hparams.batch_size`. If a string will save - value to `model.hparams.string` - factor: value which the old batch size is multiplied by to get the new batch size @@ -239,7 +235,7 @@ def _adjust_batch_size(trainer, string: either `succeeded` or `failed`. 
Used purely for logging """ - trainer_arg = trainer_arg if isinstance(trainer_arg, str) else 'batch_size' + trainer_arg = 'batch_size' model = trainer.get_model() diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py index a109294eee2a54..495237d8dc91c1 100755 --- a/tests/trainer/test_trainer_tricks.py +++ b/tests/trainer/test_trainer_tricks.py @@ -115,18 +115,17 @@ class CurrentTestModel( pass hparams = tutils.get_default_hparams() - hparams.__dict__['my_fancy_batch_size'] = 17 # update with non-standard field model = CurrentTestModel(hparams) - before_batch_size = hparams.my_fancy_batch_size + before_batch_size = hparams.batch_size # logger file to get meta trainer = Trainer( default_save_path=tmpdir, max_epochs=1, - auto_lr_find='my_fancy_batch_size' + auto_scale_batch_size='binsearch' ) trainer.fit(model) - after_batch_size = model.hparams.my_fancy_batch_size + after_batch_size = model.hparams.batch_size assert before_batch_size != after_batch_size, \ 'Batch size was not altered after running auto scaling of batch size'