Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update batch size in LightningModule.datamodule when auto scaling batch size #3266

Merged
merged 16 commits into from
Sep 3, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008))

- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,13 @@ def tune(
if self.auto_scale_batch_size:
if isinstance(self.auto_scale_batch_size, bool):
self.auto_scale_batch_size = 'power'
self.scale_batch_size(model, mode=self.auto_scale_batch_size)
self.scale_batch_size(
model,
mode=self.auto_scale_batch_size,
train_dataloader=train_dataloader,
val_dataloaders=val_dataloaders,
datamodule=datamodule,
)
model.logger = self.logger # reset logger binding

# Run learning rate finder:
Expand Down
37 changes: 19 additions & 18 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def scale_batch_size(self,
steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
batch_arg_name: str = 'batch_size'):
batch_arg_name: str = 'batch_size',
**fit_kwargs):
r"""
Will iteratively try to find the largest batch size for a given model
that does not give an out of memory (OOM) error.
Expand All @@ -158,6 +159,10 @@ def scale_batch_size(self,
max_trials: max number of increase in batch size done before
algorithm is terminated

batch_arg_name: name of the attribute that stores the batch size.

**fit_kwargs: remaining arguments to be passed to .fit() when, e.g., dataloader
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
or datamodule.
"""
if not lightning_hasattr(model, batch_arg_name):
raise MisconfigurationException(
Expand Down Expand Up @@ -190,9 +195,9 @@ def scale_batch_size(self,
# Initially we just double in size until an OOM is encountered
new_size = _adjust_batch_size(self, value=init_val) # initially set to init_val
if mode == 'power':
new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
elif mode == 'binsearch':
new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials, **fit_kwargs)
else:
raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch')

Expand Down Expand Up @@ -259,7 +264,9 @@ def _adjust_batch_size(trainer,
desc: str = None):
""" Function for adjusting the batch size. It is expected that the user
has provided a model that has a hparam field called `batch_size` i.e.
`model.hparams.batch_size` should exist.
`model.hparams.batch_size` should exist. Additionally there can be a
datamodule attached to either Trainer or model, in which case the attribute
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
also gets updated when present.

Args:
trainer: instance of pytorch_lightning.Trainer
Expand All @@ -277,28 +284,22 @@ def _adjust_batch_size(trainer,
"""
model = trainer.get_model()
batch_size = lightning_getattr(model, batch_arg_name)
if value:
lightning_setattr(model, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
else:
new_size = int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
lightning_setattr(model, batch_arg_name, new_size)
new_size = value if value is not None else int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
lightning_setattr(model, batch_arg_name, new_size)
return new_size


def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
""" Batch scaling mode where the size is doubled at each iteration until an
OOM error is encountered. """
for _ in range(max_trials):
garbage_collection_cuda()
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model)
trainer.fit(model, **fit_kwargs)
# Double in size
new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
except RuntimeError as exception:
Expand All @@ -313,7 +314,7 @@ def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
return new_size


def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials, **fit_kwargs):
""" Batch scaling mode where the size is initially is doubled at each iteration
until an OOM error is encountered. Hereafter, the batch size is further
refined using a binary search """
Expand All @@ -324,7 +325,7 @@ def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)
trainer.global_step = 0 # reset after each try
try:
# Try fit
trainer.fit(model)
trainer.fit(model, **fit_kwargs)
count += 1
if count > max_trials:
break
Expand Down
41 changes: 31 additions & 10 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,10 @@ def __repr__(self):


def lightning_hasattr(model, attribute):
""" Special hasattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
""" Special hasattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = model.trainer

# Check if attribute in model
if hasattr(model, attribute):
attr = True
Expand All @@ -186,15 +188,20 @@ def lightning_hasattr(model, attribute):
attr = attribute in model.hparams
else:
attr = hasattr(model.hparams, attribute)
# Check if attribute in datamodule (datamodule gets registerd in Trainer)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
attr = getattr(trainer.datamodule, attribute)
else:
attr = False

return attr


def lightning_getattr(model, attribute):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
""" Special getattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
trainer = model.trainer

# Check if attribute in model
if hasattr(model, attribute):
attr = getattr(model, attribute)
Expand All @@ -204,24 +211,38 @@ def lightning_getattr(model, attribute):
attr = model.hparams[attribute]
else:
attr = getattr(model.hparams, attribute)

# Check if attribute in datamodule (datamodule gets registerd in Trainer)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
elif trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
attr = getattr(trainer.datamodule, attribute)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')
return attr


def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict """
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
if not lightning_hasattr(model, attribute):
raise ValueError(f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.')

trainer = model.trainer

# Check if attribute in model
if hasattr(model, attribute):
setattr(model, attribute, value)

# Check if attribute in model.hparams, either namespace or dict
elif hasattr(model, 'hparams'):
if isinstance(model.hparams, dict):
model.hparams[attribute] = value
else:
setattr(model.hparams, attribute, value)
else:
raise ValueError(f'{attribute} is not stored in the model namespace'
' or the `hparams` namespace/dict.')

# Check if attribute in datamodule (datamodule gets registerd in Trainer)
awaelchli marked this conversation as resolved.
Show resolved Hide resolved
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
setattr(trainer.datamodule, attribute, value)
11 changes: 10 additions & 1 deletion tests/trainer/test_trainer_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import MNISTDataModule


def test_num_training_batches(tmpdir):
Expand Down Expand Up @@ -228,13 +229,21 @@ def dataloader(self, *args, **kwargs):
del self.batch_size
return dataloader

datamodule_model = MNISTDataModule(data_dir=tmpdir, batch_size=111) # this datamodule should get ignored!
datamodule_fit = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)

model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
model = model_class(**hparams)
model.datamodule = datamodule_model # unused when another module gets passed to .tune() / .fit()

trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
trainer.tune(model)
trainer.tune(model, datamodule_fit)
assert trainer.datamodule == datamodule_fit
after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
assert before_batch_size != after_batch_size
assert datamodule_fit.batch_size == after_batch_size
# should be left unchanged, since it was not passed to .tune()
assert datamodule_model.batch_size == 111


def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
Expand Down