update batch size in LightningModule.datamodule when auto scaling batch size #3266

Merged: 16 commits, Sep 3, 2020
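For orientation, here is the end-to-end behaviour this PR targets, condensed into a self-contained sketch. TinyModel, RandomDataModule, and the tensor shapes are illustrative stand-ins for the EvalModelTemplate / MNISTDataModule pair used in the test below, and the sketch assumes a Lightning version from this era, where auto_scale_batch_size runs as part of trainer.fit:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class RandomDataModule(pl.LightningDataModule):
    # Illustrative datamodule: what matters for this PR is that it exposes `batch_size`.
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size
        self.dataset = TensorDataset(torch.randn(640, 8), torch.randn(640, 1))

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)


class TinyModel(pl.LightningModule):
    # Illustrative model that also stores a `batch_size` of its own, like the test template.
    def __init__(self, batch_size=2):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(8, 1)
        self.dataset = TensorDataset(torch.randn(640, 8), torch.randn(640, 1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.mse_loss(self.layer(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=self.batch_size)


model = TinyModel(batch_size=2)
model.datamodule = RandomDataModule(batch_size=2)  # attached directly, as in the updated test

trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size=True)
trainer.fit(model)

# With this PR, the tuned batch size is written back to the datamodule as well:
assert model.datamodule.batch_size == model.batch_size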
34 changes: 24 additions & 10 deletions pytorch_lightning/utilities/parsing.py
@@ -175,8 +175,8 @@ def __repr__(self):


def lightning_hasattr(model, attribute):
-    """ Special hasattr for lightning. Checks for attribute in model namespace
-    and the old hparams namespace/dict """
+    """ Special hasattr for lightning. Checks for attribute in model namespace,
+    the old hparams namespace/dict, and the datamodule. """
    # Check if attribute in model
    if hasattr(model, attribute):
        attr = True
@@ -186,15 +186,17 @@ def lightning_hasattr(model, attribute):
            attr = attribute in model.hparams
        else:
            attr = hasattr(model.hparams, attribute)
+    elif hasattr(model.datamodule, attribute):

Review comment (Member): this is a bit confusing regarding the line above... mind adding a comment what is this case about...

+        attr = True
    else:
        attr = False

    return attr
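A quick illustration of the extended lookup order (not part of the PR, just a sketch against the helper above; the SimpleNamespace objects stand in for a LightningModule and its attached datamodule):

from types import SimpleNamespace

from pytorch_lightning.utilities.parsing import lightning_hasattr

# A stand-in "model" with no batch_size of its own and no hparams, only a datamodule.
model = SimpleNamespace(datamodule=SimpleNamespace(batch_size=32))

assert lightning_hasattr(model, 'batch_size')          # found via model.datamodule
assert not lightning_hasattr(model, 'learning_rate')   # not found anywhere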


def lightning_getattr(model, attribute):
-    """ Special getattr for lightning. Checks for attribute in model namespace
-    and the old hparams namespace/dict """
+    """ Special getattr for lightning. Checks for attribute in model namespace,
+    the old hparams namespace/dict, and the datamodule. """
    # Check if attribute in model
    if hasattr(model, attribute):
        attr = getattr(model, attribute)
@@ -204,24 +206,36 @@ def lightning_getattr(model, attribute):
            attr = model.hparams[attribute]
        else:
            attr = getattr(model.hparams, attribute)
+    elif hasattr(model.datamodule, attribute):
+        attr = getattr(model.datamodule, attribute)
    else:
-        raise ValueError(f'{attribute} is not stored in the model namespace'
-                         ' or the `hparams` namespace/dict.')
+        raise ValueError(f'{attribute} is neither stored in the model namespace'
+                         ' nor the `hparams` namespace/dict, nor the datamodule.')
    return attr
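The same idea for the getter, again only a sketch with stand-in objects:

from types import SimpleNamespace

from pytorch_lightning.utilities.parsing import lightning_getattr

model = SimpleNamespace(datamodule=SimpleNamespace(batch_size=32))

assert lightning_getattr(model, 'batch_size') == 32   # resolved through the datamodule

try:
    lightning_getattr(model, 'momentum')               # lives nowhere, so this raises
except ValueError as err:
    print(err)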


def lightning_setattr(model, attribute, value):
    """ Special setattr for lightning. Checks for attribute in model namespace
-    and the old hparams namespace/dict """
+    and the old hparams namespace/dict.
+    Will also set the attribute on datamodule, if it exists.
+    """
+    found = False
    # Check if attribute in model
    if hasattr(model, attribute):
        setattr(model, attribute, value)
+        found = True
    # Check if attribute in model.hparams, either namespace or dict
    elif hasattr(model, 'hparams'):
        if isinstance(model.hparams, dict):
            model.hparams[attribute] = value
        else:
            setattr(model.hparams, attribute, value)
-    else:
-        raise ValueError(f'{attribute} is not stored in the model namespace'
-                         ' or the `hparams` namespace/dict.')
+        found = True
+    # Check if attribute in datamodule
+    if hasattr(model.datamodule, attribute):
+        setattr(model.datamodule, attribute, value)
+        found = True
+
+    if not found:
+        raise ValueError(f'{attribute} is neither stored in the model namespace'
+                         ' nor the `hparams` namespace/dict, nor the datamodule.')
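Sketch of the new setter behaviour with stand-in objects; note that one call now updates both the hparams dict and the datamodule, which is exactly what the batch size scaler needs:

from types import SimpleNamespace

from pytorch_lightning.utilities.parsing import lightning_setattr

# A stand-in model whose batch_size lives both in an hparams dict and on a datamodule.
model = SimpleNamespace(
    hparams={'batch_size': 32},
    datamodule=SimpleNamespace(batch_size=32),
)

lightning_setattr(model, 'batch_size', 64)

# A single call keeps both locations in sync, so they can no longer drift apart.
assert model.hparams['batch_size'] == 64
assert model.datamodule.batch_size == 64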
4 changes: 4 additions & 0 deletions tests/trainer/test_trainer_tricks.py
@@ -8,6 +8,7 @@
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
+from tests.base.datamodules import MNISTDataModule


def test_num_training_batches(tmpdir):
@@ -230,11 +231,14 @@ def dataloader(self, *args, **kwargs):

    model_class = HparamsEvalModelTemplate if use_hparams else EvalModelTemplate
    model = model_class(**hparams)
+    model.datamodule = MNISTDataModule(data_dir=tmpdir, batch_size=before_batch_size)
+    model.datamodule.setup()  # TODO: why do I have to call this myself?

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
    trainer.fit(model)
    after_batch_size = model.hparams.batch_size if use_hparams else model.batch_size
    assert before_batch_size != after_batch_size
+    assert model.datamodule.batch_size == after_batch_size


def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
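Taken together, the three helpers let the batch size scaler read and write batch_size wherever it happens to live. A simplified illustration of that read-modify-write cycle (not the actual trainer code, just the pattern these helpers enable):

from pytorch_lightning.utilities.parsing import (
    lightning_getattr,
    lightning_hasattr,
    lightning_setattr,
)


def double_batch_size(model, attribute='batch_size'):
    """Double the batch size wherever it is stored: model, hparams, or datamodule."""
    if not lightning_hasattr(model, attribute):
        raise ValueError(f'{attribute} not found on the model, its hparams, or its datamodule')
    new_size = lightning_getattr(model, attribute) * 2
    lightning_setattr(model, attribute, new_size)
    return new_size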