update batch size in LightningModule.datamodule when auto scaling batch size #3266

Merged · 16 commits · Sep 3, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `GpuUsageLogger` to work on different platforms ([#3008](https://github.com/PyTorchLightning/pytorch-lightning/pull/3008))
 
+- Fixed setting batch size in `LightningModule.datamodule` when using `auto_scale_batch_size` ([#3266](https://github.com/PyTorchLightning/pytorch-lightning/pull/3266))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
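For context on the entry added above, here is a minimal sketch of the user-facing behaviour it describes, written against the 0.9-era API. `LitClassifier` and `RandomDataModule` are hypothetical stand-ins for illustration, not classes from the repository; with `auto_scale_batch_size=True`, the tuned `batch_size` should now also be written back to the attached datamodule.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.batch_size = batch_size  # attribute the batch-size finder reads and updates

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


class RandomDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size  # the attribute this PR keeps in sync with the tuner

    def train_dataloader(self):
        data = TensorDataset(torch.randn(512, 32), torch.randint(0, 2, (512,)))
        return DataLoader(data, batch_size=self.batch_size)


if __name__ == '__main__':
    dm = RandomDataModule(batch_size=2)
    trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size=True)
    trainer.fit(LitClassifier(), dm)
    # After this fix, the scaled batch size is also visible on the datamodule:
    print(dm.batch_size)
```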
20 changes: 9 additions & 11 deletions pytorch_lightning/trainer/training_tricks.py
@@ -259,7 +259,9 @@ def _adjust_batch_size(trainer,
                        desc: str = None):
     """ Function for adjusting the batch size. It is expected that the user
     has provided a model that has a hparam field called `batch_size` i.e.
-    `model.hparams.batch_size` should exist.
+    `model.hparams.batch_size` should exist. Additionally there can be a
+    datamodule attached to either Trainer or model, in which case the attribute
+    also gets updated when present.
 
     Args:
         trainer: instance of pytorch_lightning.Trainer
@@ -277,16 +279,12 @@ def _adjust_batch_size(trainer,
     """
     model = trainer.get_model()
     batch_size = lightning_getattr(model, batch_arg_name)
-    if value:
-        lightning_setattr(model, batch_arg_name, value)
-        new_size = value
-        if desc:
-            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-    else:
-        new_size = int(batch_size * factor)
-        if desc:
-            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
-        lightning_setattr(model, batch_arg_name, new_size)
+    new_size = value if value is not None else int(batch_size * factor)
+    if desc:
+        log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
+    lightning_setattr(model, batch_arg_name, new_size)
+    if trainer.datamodule is not None and hasattr(trainer.datamodule, batch_arg_name):
+        setattr(trainer.datamodule, batch_arg_name, new_size)
     return new_size

Review thread on the `setattr(trainer.datamodule, batch_arg_name, new_size)` line:

Reviewer (Member): Is this necessary, should `lightning_setattr` not take care of this?

Author: you mean accessing it through `model.trainer.datamodule`? Ok, I'll try that
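To make the adjusted control flow concrete, here is a standalone paraphrase of the new logic using `types.SimpleNamespace` stand-ins; it is not the library implementation, only the same idea in isolation: the resolved size is applied to the model and, when the trainer carries a datamodule exposing the same attribute, mirrored onto it as well.

```python
from types import SimpleNamespace


def adjust_batch_size_sketch(trainer, model, batch_arg_name='batch_size',
                             factor=1.0, value=None, desc=None):
    # Resolve the new size exactly as the refactored branch above does:
    # an explicit `value` wins, otherwise scale the current size by `factor`.
    batch_size = getattr(model, batch_arg_name)
    new_size = value if value is not None else int(batch_size * factor)
    if desc:
        print(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
    # Write the new size to the model, then mirror it onto the datamodule if present.
    setattr(model, batch_arg_name, new_size)
    if trainer.datamodule is not None and hasattr(trainer.datamodule, batch_arg_name):
        setattr(trainer.datamodule, batch_arg_name, new_size)
    return new_size


model = SimpleNamespace(batch_size=16)
trainer = SimpleNamespace(datamodule=SimpleNamespace(batch_size=16))
adjust_batch_size_sketch(trainer, model, factor=2.0, desc='succeeded')
assert model.batch_size == 32 and trainer.datamodule.batch_size == 32
```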
34 changes: 24 additions & 10 deletions pytorch_lightning/utilities/parsing.py
@@ -175,8 +175,8 @@ def __repr__(self):
 
 
 def lightning_hasattr(model, attribute):
-    """ Special hasattr for lightning. Checks for attribute in model namespace
-    and the old hparams namespace/dict """
+    """ Special hasattr for lightning. Checks for attribute in model namespace,
+    the old hparams namespace/dict, and the datamodule. """
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = True
@@ -186,15 +186,17 @@ def lightning_hasattr(model, attribute):
             attr = attribute in model.hparams
         else:
             attr = hasattr(model.hparams, attribute)
+    elif hasattr(model.datamodule, attribute):
+        attr = True
     else:
         attr = False
 
     return attr
 
 
 def lightning_getattr(model, attribute):
-    """ Special getattr for lightning. Checks for attribute in model namespace
-    and the old hparams namespace/dict """
+    """ Special getattr for lightning. Checks for attribute in model namespace,
+    the old hparams namespace/dict, and the datamodule. """
     # Check if attribute in model
     if hasattr(model, attribute):
         attr = getattr(model, attribute)

Review comment on the `elif hasattr(model.datamodule, attribute):` line in `lightning_hasattr` (Member): this is a bit confusing regarding the line above... mind adding a comment what is this case about...
@@ -204,24 +206,36 @@ def lightning_getattr(model, attribute):
             attr = model.hparams[attribute]
         else:
             attr = getattr(model.hparams, attribute)
+    elif hasattr(model.datamodule, attribute):
+        attr = getattr(model.datamodule, attribute)
     else:
-        raise ValueError(f'{attribute} is not stored in the model namespace'
-                         ' or the `hparams` namespace/dict.')
+        raise ValueError(f'{attribute} is neither stored in the model namespace'
+                         ' nor the `hparams` namespace/dict, nor the datamodule.')
     return attr
 
 
 def lightning_setattr(model, attribute, value):
     """ Special setattr for lightning. Checks for attribute in model namespace
-    and the old hparams namespace/dict """
+    and the old hparams namespace/dict.
+    Will also set the attribute on datamodule, if it exists.
+    """
+    found = False
     # Check if attribute in model
     if hasattr(model, attribute):
         setattr(model, attribute, value)
+        found = True
     # Check if attribute in model.hparams, either namespace or dict
     elif hasattr(model, 'hparams'):
         if isinstance(model.hparams, dict):
             model.hparams[attribute] = value
         else:
             setattr(model.hparams, attribute, value)
-    else:
-        raise ValueError(f'{attribute} is not stored in the model namespace'
-                         ' or the `hparams` namespace/dict.')
+        found = True
+    # Check if attribute in datamodule
+    if hasattr(model.datamodule, attribute):
+        setattr(model.datamodule, attribute, value)
+        found = True
+
+    if not found:
+        raise ValueError(f'{attribute} is neither stored in the model namespace'
+                         ' nor the `hparams` namespace/dict, nor the datamodule.')
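A toy illustration of the lookup and update order these helpers implement (model attribute first, then the `hparams` dict or namespace, then `model.datamodule`). The `toy_lightning_setattr` function and the `SimpleNamespace` objects below are stand-ins for demonstration, not the library code:

```python
from types import SimpleNamespace


def toy_lightning_setattr(model, attribute, value):
    found = False
    # 1) plain attribute on the model itself
    if hasattr(model, attribute):
        setattr(model, attribute, value)
        found = True
    # 2) legacy hparams container (dict or namespace)
    elif hasattr(model, 'hparams'):
        if isinstance(model.hparams, dict):
            model.hparams[attribute] = value
        else:
            setattr(model.hparams, attribute, value)
        found = True
    # 3) independently of 1)/2), mirror the value onto an attached datamodule
    datamodule = getattr(model, 'datamodule', None)
    if datamodule is not None and hasattr(datamodule, attribute):
        setattr(datamodule, attribute, value)
        found = True
    if not found:
        raise ValueError(f'{attribute} not found on the model, its hparams, or its datamodule')


model = SimpleNamespace(hparams={'batch_size': 8},
                        datamodule=SimpleNamespace(batch_size=8))
toy_lightning_setattr(model, 'batch_size', 64)
assert model.hparams['batch_size'] == 64
assert model.datamodule.batch_size == 64
```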
5 changes: 4 additions & 1 deletion tests/trainer/test_trainer_tricks.py
@@ -8,6 +8,7 @@
 from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
+from tests.base.datamodules import MNISTDataModule
 
 
 def test_num_training_batches(tmpdir):
@@ -230,11 +231,13 @@ def dataloader(self, *args, **kwargs):
 
     model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
     model = model_class(**hparams)
+    datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
-    trainer.fit(model)
+    trainer.fit(model, datamodule)
     after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
     assert before_batch_size != after_batch_size
+    assert datamodule.batch_size == after_batch_size
 
 
 def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
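For reference, a rough, hypothetical approximation of what the test relies on from `tests/base/datamodules.py`: a datamodule whose dataloader reads `self.batch_size`, so the assertion `datamodule.batch_size == after_batch_size` can only hold if the tuner writes the scaled value back to the datamodule. This sketch uses torchvision's MNIST and is not the helper's actual code.

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModuleSketch(pl.LightningDataModule):
    """Hypothetical stand-in for tests.base.datamodules.MNISTDataModule."""

    def __init__(self, data_dir: str = './', batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size  # mutated in place by the batch-size finder

    def prepare_data(self):
        MNIST(self.data_dir, train=True, download=True)

    def train_dataloader(self):
        dataset = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
        return DataLoader(dataset, batch_size=self.batch_size)
```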