diff --git a/CHANGELOG.md b/CHANGELOG.md
index 46c2ec88fe52b..f28e258e4fabd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added type hints in `Trainer.fit()` and `Trainer.test()` to reflect that also a list of dataloaders can be passed in ([#1723](https://github.com/PyTorchLightning/pytorch-lightning/pull/1723)).
+- Added auto scaling of batch size ([#1638](https://github.com/PyTorchLightning/pytorch-lightning/pull/1638))
+
 - The progress bar metrics now also get updated in `training_epoch_end` ([#1724](https://github.com/PyTorchLightning/pytorch-lightning/pull/1724)).

 ### Changed
diff --git a/docs/source/training_tricks.rst b/docs/source/training_tricks.rst
index e97d7837e0eb4..76f41acea4760 100644
--- a/docs/source/training_tricks.rst
+++ b/docs/source/training_tricks.rst
@@ -34,3 +34,76 @@ norm `_
     # clip gradients with norm above 0.5
     trainer = Trainer(gradient_clip_val=0.5)
+
+Auto scaling of batch size
+--------------------------
+Auto scaling of batch size may be enabled to find the largest batch size that fits into
+memory. A larger batch size often yields better estimates of gradients, but may also result in
+longer training time.
+
+.. seealso:: :class:`~pytorch_lightning.trainer.trainer.Trainer`
+
+.. code-block:: python
+
+    # DEFAULT (ie: don't scale batch size automatically)
+    trainer = Trainer(auto_scale_batch_size=None)
+
+    # Autoscale batch size
+    trainer = Trainer(auto_scale_batch_size=None|'power'|'binsearch')
+
+Currently, this feature supports two modes: `'power'` scaling and `'binsearch'`
+scaling. In `'power'` scaling, the batch size, starting from a value of 1, keeps doubling
+until an out-of-memory (OOM) error is encountered. Setting the
+argument to `'binsearch'` additionally fine-tunes the batch size by performing
+a binary search after the first OOM error.
+
+.. note::
+
+    This feature expects a `batch_size` field in the `hparams` of your model, i.e.,
+    `model.hparams.batch_size` should exist and will be overridden by the result of this
+    algorithm. Additionally, your `train_dataloader()` method should depend on this field
+    for this feature to work, i.e.
+
+    .. code-block:: python
+
+        def train_dataloader(self):
+            return DataLoader(train_dataset, batch_size=self.hparams.batch_size)
+
+.. warning::
+
+    Due to these constraints, this feature does *NOT* work when passing dataloaders directly
+    to `.fit()`.
+
+The scaling algorithm has a number of parameters that the user can control by
+invoking the trainer method `.scale_batch_size` directly (see description below).
+
+.. code-block:: python
+
+    # Use default in trainer construction
+    trainer = Trainer()
+
+    # Invoke method
+    new_batch_size = trainer.scale_batch_size(model, ...)
+
+    # Override old batch size
+    model.hparams.batch_size = new_batch_size
+
+    # Fit as normal
+    trainer.fit(model)
+
+In short, the algorithm works by:
+    1. Dumping the current state of the model and trainer
+    2. Iteratively, until convergence or the maximum number of tries `max_trials` (default 25) has been reached:
+        - Call the `fit()` method of the trainer. This evaluates `steps_per_trial` (default 3) number of
+          training steps. Each training step can trigger an OOM error if the tensors
+          (training batch, weights, gradients, etc.) allocated during the steps have
+          too large a memory footprint.
+        - If an OOM error is encountered, decrease the batch size, else increase it.
+          How much the batch size is increased/decreased is determined by the chosen
+          strategy.
+    3. The found batch size is saved to `model.hparams.batch_size`
+    4. Restore the initial state of model and trainer
+
+.. autoclass:: pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin
+   :members: scale_batch_size
+   :noindex:
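For intuition about how the two documented search modes progress, the following self-contained sketch simulates them against a hypothetical memory limit. It is illustrative only and not part of the Lightning API; `fits_in_memory` stands in for running a few training steps and catching OOM errors.

.. code-block:: python

    # Illustrative simulation of the 'power' and 'binsearch' strategies.
    def fits_in_memory(batch_size, limit=100):
        # stand-in for the real trial: run a few training steps, catch OOM errors
        return batch_size <= limit

    def power_scaling(init_val=2, max_trials=25):
        size = init_val
        for _ in range(max_trials):
            if not fits_in_memory(size):
                return size // 2  # back off once after the first failure
            size *= 2
        return size

    def binsearch_scaling(init_val=2, max_trials=25):
        low, high, size = 1, None, init_val
        for _ in range(max_trials):
            if fits_in_memory(size):
                low = size
                if high is not None and high - low <= 1:
                    break
                size = size * 2 if high is None else (high + low) // 2
            else:
                high = size
                size = (high + low) // 2
                if high - low <= 1:
                    break
        return low

    print(power_scaling())      # 64 for a limit of 100: doubling stops at the first failure
    print(binsearch_scaling())  # 100 for a limit of 100: the binary search refines further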
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index 42f92979d6430..80f82d917396f 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -135,6 +135,19 @@ def forward(self, x):
     # default used by the Trainer
     trainer = Trainer(amp_level='O1')

+auto_scale_batch_size
+^^^^^^^^^^^^^^^^^^^^^
+Automatically tries to find the largest batch size that fits into memory,
+before any training.
+
+.. code-block:: python
+
+    # default used by the Trainer (no scaling of batch size)
+    trainer = Trainer(auto_scale_batch_size=None)
+
+    # run batch size scaling, result overrides hparams.batch_size
+    trainer = Trainer(auto_scale_batch_size='binsearch')
+
 auto_lr_find
 ^^^^^^^^^^^^
 Runs a learning rate finder algorithm (see this `paper `_)
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index 89b1c886a1db9..7de7b298f052b 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -106,7 +106,7 @@ def lr_find(self,
         """
         save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

-        self._lr_finder_dump_params(model)
+        self.__lr_finder_dump_params(model)

         # Prevent going into infinite loop
         self.auto_lr_find = False
@@ -170,15 +170,15 @@ def lr_find(self,
         os.remove(save_path)

         # Finish by resetting variables so trainer is ready to fit model
-        self._lr_finder_restore_params(model)
+        self.__lr_finder_restore_params(model)
         if self.progress_bar_callback:
             self.progress_bar_callback.enable()

         return lr_finder

-    def _lr_finder_dump_params(self, model):
+    def __lr_finder_dump_params(self, model):
         # Prevent going into infinite loop
-        self._params = {
+        self.__dumped_params = {
             'auto_lr_find': self.auto_lr_find,
             'callbacks': self.callbacks,
             'logger': self.logger,
@@ -192,18 +192,19 @@ def _lr_finder_dump_params(self, model):
             'configure_optimizers': model.configure_optimizers,
         }

-    def _lr_finder_restore_params(self, model):
-        self.auto_lr_find = self._params['auto_lr_find']
-        self.logger = self._params['logger']
-        self.callbacks = self._params['callbacks']
-        self.max_steps = self._params['max_steps']
-        self.progress_bar_refresh_rate = self._params['progress_bar_refresh_rate']
-        self.accumulate_grad_batches = self._params['accumulate_grad_batches']
-        self.checkpoint_callback = self._params['checkpoint_callback']
-        self.early_stop_callback = self._params['early_stop_callback']
-        self.enable_early_stop = self._params['enable_early_stop']
-        self.progress_bar_callback = self._params['progress_bar_callback']
-        model.configure_optimizers = self._params['configure_optimizers']
+    def __lr_finder_restore_params(self, model):
+        self.auto_lr_find = self.__dumped_params['auto_lr_find']
+        self.logger = self.__dumped_params['logger']
+        self.callbacks = self.__dumped_params['callbacks']
+        self.max_steps = self.__dumped_params['max_steps']
+        self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
+        self.accumulate_grad_batches = self.__dumped_params['accumulate_grad_batches']
+        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
+        self.early_stop_callback = self.__dumped_params['early_stop_callback']
+        self.enable_early_stop = self.__dumped_params['enable_early_stop']
+        self.progress_bar_callback = self.__dumped_params['progress_bar_callback']
+        model.configure_optimizers = self.__dumped_params['configure_optimizers']
+        del self.__dumped_params


 class _LRFinder(object):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 56df5ad3a9abf..aadd2bff30f1c 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -129,6 +129,7 @@ def __init__(
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
         progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
+        auto_scale_batch_size: Optional[str] = None,
         amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
         default_save_path=None,  # backward compatible, todo: remove in v0.8.0
         gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -293,6 +294,12 @@ def __init__(
         terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
             end of each training batch, if any of the parameters or the loss are NaN or +/-inf.
+
+        auto_scale_batch_size: If set to `power` or `binsearch`, will `initially` run a batch size
+            finder trying to find the largest batch size that fits into memory.
+            The result will be stored in self.hparams.batch_size in the LightningModule.
+            `power` estimates the batch size through a power search, while `binsearch`
+            additionally refines it through a binary search.

         """

         # Init callbacks
@@ -368,6 +375,7 @@ def __init__(
         self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

         self.auto_lr_find = auto_lr_find
+        self.auto_scale_batch_size = auto_scale_batch_size
         self.replace_sampler_ddp = replace_sampler_ddp

         self.truncated_bptt_steps = truncated_bptt_steps
@@ -474,7 +482,7 @@ def __init__(
         self.show_progress_bar = show_progress_bar

         self.progress_bar_refresh_rate = progress_bar_refresh_rate
-        self.progress_bar_callback = None
+        self.progress_bar_callback = progress_bar_callback
         self.configure_progress_bar()

         # logging
@@ -736,6 +744,10 @@ def fit(
             # only on proc 0 because no spawn has happened yet
             model.prepare_data()

+            # Run auto batch size scaling
+            if self.auto_scale_batch_size:
+                self.scale_batch_size(model, mode=self.auto_scale_batch_size)
+
             # Run learning rate finder:
             if self.auto_lr_find:
                 self._run_lr_finder_internally(model)
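The rename from `self._params` to the name-mangled `self.__dumped_params` in `lr_finder.py`, mirrored by the new batch-size mixin below, matters because the `Trainer` inherits from both mixins: with a plain shared attribute, one feature's saved state could clobber the other's. A minimal, standalone illustration of the mangling (illustrative class names, not Lightning code):

.. code-block:: python

    # Python mangles a double-underscore attribute per defining class, so each
    # mixin keeps its own private copy of the dumped state on the same instance.
    class LRFinderMixin:
        def dump(self):
            self.__dumped_params = {'source': 'lr_finder'}

    class TrainingTricksMixin:
        def dump(self):
            self.__dumped_params = {'source': 'batch_scaler'}

    class Trainer(LRFinderMixin, TrainingTricksMixin):
        pass

    t = Trainer()
    LRFinderMixin.dump(t)
    TrainingTricksMixin.dump(t)
    print(vars(t))
    # {'_LRFinderMixin__dumped_params': {'source': 'lr_finder'},
    #  '_TrainingTricksMixin__dumped_params': {'source': 'batch_scaler'}}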
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 0d86d53b7bbc4..61471a8430631 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -1,12 +1,19 @@
 import math
 import sys
 from abc import ABC, abstractmethod
+import gc
+import os
+from typing import Optional

 import torch
 from torch import Tensor
+from torch.utils.data import DataLoader

 from pytorch_lightning import _logger as log
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda

 EPSILON = 1e-6
 EPSILON_FP16 = 1e-5
@@ -18,11 +25,24 @@ class TrainerTrainingTricksMixin(ABC):
     # the proper values/initialisation should be done in child class
     gradient_clip_val: ...
     precision: ...
+    on_gpu: bool

     @abstractmethod
     def get_model(self):
         """Warning: this is just empty shell for code implemented in other class."""

+    @abstractmethod
+    def save_checkpoint(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
+    @abstractmethod
+    def restore(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
+    @abstractmethod
+    def fit(self, *args):
+        """Warning: this is just empty shell for code implemented in other class."""
+
     def clip_gradients(self):

         # this code is a modification of torch.nn.utils.clip_grad_norm_
@@ -80,3 +100,219 @@ def configure_accumulated_gradients(self, accumulate_grad_batches):
             self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
         else:
             raise TypeError("Gradient accumulation supports only int and dict types")
+
+    def scale_batch_size(self,
+                         model: LightningModule,
+                         mode: str = 'power',
+                         steps_per_trial: int = 3,
+                         init_val: int = 2,
+                         max_trials: int = 25,
+                         batch_arg_name: str = 'batch_size'):
+        r"""
+        Will iteratively try to find the largest batch size for a given model
+        that does not give an out of memory (OOM) error.
+
+        Args:
+            model: Model to fit.
+
+            mode: string setting the search mode. Either `power` or `binsearch`.
+                If mode is `power`, we keep multiplying the batch size by 2 until
+                we get an OOM error. If mode is `binsearch`, we will initially
+                also keep multiplying by 2 and, after encountering an OOM error,
+                do a binary search between the last successful batch size and the
+                batch size that failed.
+
+            steps_per_trial: number of steps to run with a given batch size.
+                Ideally 1 should be enough to test if an OOM error occurs,
+                however in practice a few are needed.
+
+            init_val: initial batch size to start the search with.
+
+            max_trials: max number of increases in batch size done before
+                the algorithm is terminated.
+
+            batch_arg_name: name of the field in `model.hparams` that stores the batch size.
+
+        """
+        if not hasattr(model.hparams, batch_arg_name):
+            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
+
+        if hasattr(model.train_dataloader, 'patch_loader_code'):
+            raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
+                                            ' passed directly to `.fit()`. Please disable the feature or'
+                                            ' incorporate the dataloader into the model.')
+
+        # Arguments we adjust during the batch size finder, save for restoring
+        self.__scale_batch_dump_params()
+
+        # Set to values that are required by the algorithm
+        self.__scale_batch_reset_params(model, steps_per_trial)
+
+        # Save initial model, that is loaded after batch size is found
+        save_path = os.path.join(self.default_root_dir, 'temp_model.ckpt')
+        self.save_checkpoint(str(save_path))
+
+        if self.progress_bar_callback:
+            self.progress_bar_callback.disable()
+
+        # Initially we just double in size until an OOM is encountered
+        new_size = _adjust_batch_size(self, value=init_val)  # initially set to init_val
+        if mode == 'power':
+            new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
+        elif mode == 'binsearch':
+            new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
+        else:
+            raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch`')
+
+        garbage_collection_cuda()
+        log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')
+
+        # Restore initial state of model
+        self.restore(str(save_path), on_gpu=self.on_gpu)
+        os.remove(save_path)
+
+        # Finish by resetting variables so trainer is ready to fit model
+        self.__scale_batch_restore_params()
+        if self.progress_bar_callback:
+            self.progress_bar_callback.enable()
+
+        return new_size
+
+    def __scale_batch_dump_params(self):
+        # Prevent going into infinite loop
+        self.__dumped_params = {
+            'max_steps': self.max_steps,
+            'weights_summary': self.weights_summary,
+            'logger': self.logger,
+            'callbacks': self.callbacks,
+            'checkpoint_callback': self.checkpoint_callback,
+            'early_stop_callback': self.early_stop_callback,
+            'enable_early_stop': self.enable_early_stop,
+            'auto_scale_batch_size': self.auto_scale_batch_size,
+            'train_percent_check': self.train_percent_check,
+            'model': self.model,
+        }
+
+    def __scale_batch_reset_params(self, model, steps_per_trial):
+        self.auto_scale_batch_size = None  # prevent recursion
+        self.max_steps = steps_per_trial  # take few steps
+        self.weights_summary = None  # not needed before full run
+        self.logger = None  # not needed before full run
+        self.callbacks = []  # not needed before full run
+        self.checkpoint_callback = False  # required for saving
+        self.early_stop_callback = None
+        self.enable_early_stop = False
+        self.train_percent_check = 1.0
+        self.optimizers, self.schedulers = [], []  # required for saving
+        self.model = model  # required for saving
+
+    def __scale_batch_restore_params(self):
+        self.max_steps = self.__dumped_params['max_steps']
+        self.weights_summary = self.__dumped_params['weights_summary']
+        self.logger = self.__dumped_params['logger']
+        self.callbacks = self.__dumped_params['callbacks']
+        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
+        self.auto_scale_batch_size = self.__dumped_params['auto_scale_batch_size']
+        self.early_stop_callback = self.__dumped_params['early_stop_callback']
+        self.enable_early_stop = self.__dumped_params['enable_early_stop']
+        self.train_percent_check = self.__dumped_params['train_percent_check']
+        self.model = self.__dumped_params['model']
+        del self.__dumped_params
+
+
+def _adjust_batch_size(trainer,
+                       batch_arg_name: str = 'batch_size',
+                       factor: float = 1.0,
+                       value: Optional[int] = None,
+                       desc: Optional[str] = None):
+    """ Function for adjusting the batch size. It is expected that the user
+    has provided a model with an hparams field called `batch_size`, i.e.
+    `model.hparams.batch_size` should exist.
+
+    Args:
+        trainer: instance of pytorch_lightning.Trainer
+
+        batch_arg_name: field where the batch size is stored in `model.hparams`
+
+        factor: value which the old batch size is multiplied by to get the
+            new batch size
+
+        value: if a value is given, will override the batch size with this value.
+            Note that the value of `factor` will not have an effect in this case
+
+        desc: either `succeeded` or `failed`. Used purely for logging
+
+    """
+    model = trainer.get_model()
+    batch_size = getattr(model.hparams, batch_arg_name)
+    if value:
+        setattr(model.hparams, batch_arg_name, value)
+        new_size = value
+        if desc:
+            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
+    else:
+        new_size = int(batch_size * factor)
+        if desc:
+            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
+        setattr(model.hparams, batch_arg_name, new_size)
+    return new_size
+
+
+def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
+    """ Batch scaling mode where the size is doubled at each iteration until an
+        OOM error is encountered. """
+    for _ in range(max_trials):
+        garbage_collection_cuda()
+        trainer.global_step = 0  # reset after each try
+        try:
+            # Try fit
+            trainer.fit(model)
+            # Double in size
+            new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
+        except RuntimeError as exception:
+            # Only these errors should trigger an adjustment
+            if is_oom_error(exception):
+                # If we fail in power mode, halve the size and return
+                garbage_collection_cuda()
+                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
+                break
+            else:
+                raise  # some other error not memory related
+    return new_size
+
+
+def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
+    """ Batch scaling mode where the size is initially doubled at each iteration
+        until an OOM error is encountered. Thereafter, the batch size is further
+        refined using a binary search. """
+    high = None
+    low = 1  # last known-good size; guards against an OOM error on the very first trial
+    count = 0
+    while True:
+        garbage_collection_cuda()
+        trainer.global_step = 0  # reset after each try
+        try:
+            # Try fit
+            trainer.fit(model)
+            count += 1
+            if count > max_trials:
+                break
+            # Double in size
+            low = new_size
+            if high:
+                if high - low <= 1:
+                    break
+                midval = (high + low) // 2
+                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
+            else:
+                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
+        except RuntimeError as exception:
+            # Only these errors should trigger an adjustment
+            if is_oom_error(exception):
+                # On an OOM error, narrow the search interval from above and try the midpoint
+                garbage_collection_cuda()
+                high = new_size
+                midval = (high + low) // 2
+                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
+                if high - low <= 1:
+                    break
+            else:
+                raise  # some other error not memory related
+    return new_size
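To make the `factor`/`value` semantics of `_adjust_batch_size` concrete, here is a small standalone sketch that mirrors the update rule using an `argparse.Namespace` stand-in for the model; it is illustrative only and not the Lightning implementation itself.

.. code-block:: python

    from argparse import Namespace

    # Stand-in model, purely to show the two update modes.
    model = Namespace(hparams=Namespace(batch_size=32))

    def adjust(factor=1.0, value=None):
        old = model.hparams.batch_size
        new = value if value is not None else int(old * factor)
        model.hparams.batch_size = new
        return new

    print(adjust(factor=2.0))  # 64 -> doubling, as in 'power' mode
    print(adjust(value=48))    # 48 -> explicit override, as in the 'binsearch' refinement
    print(adjust(factor=0.5))  # 24 -> backing off after an OOM error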
diff --git a/pytorch_lightning/utilities/memory.py b/pytorch_lightning/utilities/memory.py
index 148d28ca143ff..eed7a13ca98bf 100644
--- a/pytorch_lightning/utilities/memory.py
+++ b/pytorch_lightning/utilities/memory.py
@@ -1,3 +1,7 @@
+import gc
+import torch
+
+
 def recursive_detach(in_dict: dict) -> dict:
     """Detach all tensors in `in_dict`.
@@ -20,3 +24,34 @@ def recursive_detach(in_dict: dict) -> dict:
         else:
             out_dict.update({k: v})
     return out_dict
+
+
+def is_oom_error(exception):
+    return is_cuda_out_of_memory(exception) \
+        or is_cudnn_snafu(exception) \
+        or is_out_of_cpu_memory(exception)
+
+
+def is_cuda_out_of_memory(exception):
+    return isinstance(exception, RuntimeError) \
+        and len(exception.args) == 1 \
+        and "CUDA out of memory." in exception.args[0]
+
+
+def is_cudnn_snafu(exception):
+    return isinstance(exception, RuntimeError) \
+        and len(exception.args) == 1 \
+        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
+
+
+def is_out_of_cpu_memory(exception):
+    return isinstance(exception, RuntimeError) \
+        and len(exception.args) == 1 \
+        and "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
+
+
+def garbage_collection_cuda():
+    """Garbage collection Torch (CUDA) memory."""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
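The helpers added to `pytorch_lightning.utilities.memory` can also be combined on their own. A minimal sketch of the intended pattern (the oversized allocation is only meant to provoke an OOM error on a hypothetical device and may simply succeed on a large GPU):

.. code-block:: python

    import torch
    from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda

    def try_allocate(n):
        """Return True if an (n x n) CUDA tensor fits in memory, False on an OOM error."""
        try:
            torch.empty(n, n, device='cuda')
            return True
        except RuntimeError as exception:
            if is_oom_error(exception):
                garbage_collection_cuda()  # release cached blocks before the next attempt
                return False
            raise  # unrelated RuntimeErrors should still surface

    if torch.cuda.is_available():
        print(try_allocate(50000))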
""" + tutils.reset_seed() + + hparams = tutils.get_default_hparams() + model = EvalModelTemplate(hparams) + + before_batch_size = hparams.batch_size + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + auto_scale_batch_size=scale_arg, + ) + + trainer.fit(model) + after_batch_size = model.hparams.batch_size + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + +@pytest.mark.parametrize('scale_method', ['power', 'binsearch']) +def test_call_to_trainer_method(tmpdir, scale_method): + """ Test that calling the trainer method itself works. """ + tutils.reset_seed() + + hparams = tutils.get_default_hparams() + model = EvalModelTemplate(hparams) + + before_batch_size = hparams.batch_size + # logger file to get meta + trainer = Trainer( + default_save_path=tmpdir, + max_epochs=1, + ) + + after_batch_size = trainer.scale_batch_size(model, mode=scale_method, max_trials=5) + model.hparams.batch_size = after_batch_size + trainer.fit(model) + + assert before_batch_size != after_batch_size, \ + 'Batch size was not altered after running auto scaling of batch size' + + +def test_error_on_dataloader_passed_to_fit(tmpdir): + """Verify that when the auto scale batch size feature raises an error + if a train dataloader is passed to fit """ + + # only train passed to fit + model = EvalModelTemplate(tutils.get_default_hparams()) + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + val_percent_check=0.1, + train_percent_check=0.2, + auto_scale_batch_size='power' + ) + fit_options = dict(train_dataloader=model.dataloader(train=True)) + + with pytest.raises(MisconfigurationException): + trainer.fit(model, **fit_options)