From 48c22c8bad9a47141c7160d92f2edc9e2e4ad159 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 3 Sep 2020 22:07:49 +0200
Subject: [PATCH] update batch size in DataModule when auto scaling batch size
 (#3266)

* fix datamodule hasattr

* fix patch check

* fix setattr

* update docs

* revert patch fix

* changelog

* fix datamodule passed in as fit arg

* docs

* set datamodule batch size in lightning_setattr

* fix merge

* check with has_attr

* access datamodule via trainer

* pass fit args down to tuner

* docs

* fix typos in docs

Co-authored-by: Rohit Gupta
Co-authored-by: Rohit Gupta
---
 CHANGELOG.md                                 |  2 +
 pytorch_lightning/trainer/trainer.py         |  8 +++-
 pytorch_lightning/trainer/training_tricks.py | 37 +++++++++---------
 pytorch_lightning/utilities/parsing.py       | 41 +++++++++++++++-----
 tests/trainer/test_trainer_tricks.py         | 11 +++++-
 5 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a39ad11f69d59..192da88a41085 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008))
 
+- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3554a9ae786a0..2a0bd00886d31 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -951,7 +951,13 @@ def tune(
         if self.auto_scale_batch_size:
             if isinstance(self.auto_scale_batch_size, bool):
                 self.auto_scale_batch_size = 'power'
-            self.scale_batch_size(model, mode=self.auto_scale_batch_size)
+            self.scale_batch_size(
+                model,
+                mode=self.auto_scale_batch_size,
+                train_dataloader=train_dataloader,
+                val_dataloaders=val_dataloaders,
+                datamodule=datamodule,
+            )
             model.logger = self.logger  # reset logger binding
 
         # Run learning rate finder:
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 32d0c59434c7a..705abc6343d49 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -134,7 +134,8 @@ def scale_batch_size(self,
                          steps_per_trial: int = 3,
                          init_val: int = 2,
                          max_trials: int = 25,
-                         batch_arg_name: str = 'batch_size'):
+                         batch_arg_name: str = 'batch_size',
+                         **fit_kwargs):
         r"""
         Will iteratively try to find the largest batch size for a
         given model that does not give an out of memory (OOM) error.
@@ -158,6 +159,10 @@ def scale_batch_size(self,
             max_trials: max number of increase in batch size done before
                algorithm is terminated
 
+            batch_arg_name: name of the attribute that stores the batch size.
+
+            **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
+                or datamodule.
         """
         if not lightning_hasattr(model, batch_arg_name):
             raise MisconfigurationException(
@@ -190,9 +195,9 @@ def scale_batch_size(self,
         # Initially we just double in size until an OOM is encountered
         new_size = _adjust_batch_size(self, value=init_val)  # initially set to init_val
         if mode == 'power':
-            new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
+            new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
         elif mode == 'binsearch':
-            new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
+            new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
         else:
             raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch')
 
@@ -259,7 +264,9 @@ def _adjust_batch_size(trainer,
                        desc: str = None):
     """ Function for adjusting the batch size. It is expected that the user
         has provided a model that has a hparam field called `batch_size` i.e.
-        `model.hparams.batch_size` should exist.
+        `model.hparams.batch_size` should exist. Additionally there can be a
+        datamodule attached to either Trainer or model, in that case the attribute
+        also gets updated when present.
 
     Args:
         trainer: instance of pytorch_lightning.Trainer
@@ -277,20 +284,14 @@ def _adjust_batch_size(trainer,
     """
     model = trainer.get_model()
     batch_size = lightning_getattr(model, batch_arg_name)
-    if value:
-        lightning_setattr(model, batch_arg_name, value)
-        new_size = value
-        if desc:
-            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-    else:
-        new_size = int(batch_size * factor)
-        if desc:
-            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        lightning_setattr(model, batch_arg_name, new_size)
+    new_size = value if value is not None else int(batch_size * factor)
+    if desc:
+        log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
+    lightning_setattr(model, batch_arg_name, new_size)
     return new_size
 
 
-def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
+def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
     """ Batch scaling mode where the size is doubled at each iteration until an
         OOM error is encountered. """
     for _ in range(max_trials):
@@ -298,7 +299,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
         trainer.global_step = 0  # reset after each try
         try:
             # Try fit
-            trainer.fit(model)
+            trainer.fit(model, **fit_kwargs)
             # Double in size
             new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
         except RuntimeError as exception:
@@ -313,7 +314,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
     return new_size
 
 
-def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
+def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
     """ Batch scaling mode where the size is initially is doubled at each iteration
         until an OOM error is encountered.
         Hereafter, the batch size is further refined using a binary search """
@@ -324,7 +325,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
         trainer.global_step = 0  # reset after each try
         try:
             # Try fit
-            trainer.fit(model, **fit_kwargs)
+            trainer.fit(model, **fit_kwargs)
             count += 1
             if count > max_trials:
                 break
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index 6f57e63e48fc9..dab1127579b87 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -175,8 +175,10 @@ def __repr__(self):
 
 
 def lightning_hasattr(model, attribute):
-    """ Special hasattr for lightning. Checks for attribute in model namespace
-        and the old hparams namespace/dict """
+    """ Special hasattr for lightning. Checks for attribute in model namespace,
+        the old hparams namespace/dict, and the datamodule. """
+    trainer = model.trainer
+
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = True
@@ -186,6 +188,9 @@ def lightning_hasattr(model, attribute):
             attr = attribute in model.hparams
         else:
             attr = hasattr(model.hparams, attribute)
+    # Check if the attribute in datamodule (datamodule gets registered in Trainer)
+    elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
+        attr = getattr(trainer.datamodule, attribute)
     else:
         attr = False
 
@@ -193,8 +198,10 @@ def lightning_hasattr(model, attribute):
 
 
 def lightning_getattr(model, attribute):
-    """ Special getattr for lightning. Checks for attribute in model namespace
-        and the old hparams namespace/dict """
+    """ Special getattr for lightning. Checks for attribute in model namespace,
+        the old hparams namespace/dict, and the datamodule. """
+    trainer = model.trainer
+
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = getattr(model, attribute)
@@ -204,24 +211,38 @@ def lightning_getattr(model, attribute):
             attr = model.hparams[attribute]
         else:
             attr = getattr(model.hparams, attribute)
+
+    # Check if the attribute in datamodule (datamodule gets registered in Trainer)
+    elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
+        attr = getattr(trainer.datamodule, attribute)
     else:
-        raise ValueError(f'{attribute} is not stored in the model namespace'
-                         ' or the `hparams` namespace/dict.')
+        raise ValueError(f'{attribute} is neither stored in the model namespace'
+                         ' nor the `hparams` namespace/dict, nor the datamodule.')
     return attr
 
 
 def lightning_setattr(model, attribute, value):
     """ Special setattr for lightning. Checks for attribute in model namespace
-        and the old hparams namespace/dict """
+        and the old hparams namespace/dict.
+        Will also set the attribute on datamodule, if it exists.
+    """
+    if not lightning_hasattr(model, attribute):
+        raise ValueError(f'{attribute} is neither stored in the model namespace'
+                         ' nor the `hparams` namespace/dict, nor the datamodule.')
+
+    trainer = model.trainer
+
     # Check if attribute in model
     if hasattr(model, attribute):
        setattr(model, attribute, value)
+    # Check if attribute in model.hparams, either namespace or dict
     elif hasattr(model, 'hparams'):
         if isinstance(model.hparams, dict):
             model.hparams[attribute] = value
         else:
             setattr(model.hparams, attribute, value)
-    else:
-        raise ValueError(f'{attribute} is not stored in the model namespace'
-                         ' or the `hparams` namespace/dict.')
+
+    # Check if the attribute in datamodule (datamodule gets registered in Trainer)
+    if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
+        setattr(trainer.datamodule, attribute, value)
 
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index 75dd3beab85db..85121f5946c20 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -8,6 +8,7 @@
 from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
+from tests.base.datamodules import MNISTDataModule
 
 
 def test_num_training_batches(tmpdir):
@@ -228,13 +229,21 @@ def dataloader(self, *args, **kwargs):
             del self.batch_size
             return dataloader
 
+    datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111)  # this datamodule should get ignored!
+    datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)
+
     model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
     model = model_class(**hparams)
+    model.datamodule = datamodule_model  # unused when another module gets passed to .tune() / .fit()
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
-    trainer.tune(model)
+    trainer.tune(model, datamodule_fit)
+    assert trainer.datamodule == datamodule_fit
     after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
     assert before_batch_size != after_batch_size
+    assert datamodule_fit.batch_size == after_batch_size
+    # should be left unchanged, since it was not passed to .tune()
+    assert datamodule_model.batch_size == 111
 
 
 def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):