Skip to content

Commit

Permalink
Misleading exception raised during batch scaling (#1973)
Browse files Browse the repository at this point in the history
* Misleading exception raised during batch scaling

Use batch_size from `model.hparams.batch_size` instead of `model.batch_size`

* Improvements considering #1896

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
4 people committed Jun 17, 2020
1 parent e1f238a commit f8103f9
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions pytorch_lightning/trainer/training_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ def scale_batch_size(self,
"""
if not hasattr(model, batch_arg_name):
raise MisconfigurationException(f'Field {batch_arg_name} not found in `model.hparams`')
if not hasattr(model.hparams, batch_arg_name):
raise MisconfigurationException(
'Neither of `model.batch_size` and `model.hparams.batch_size` found.'
)

if hasattr(model.train_dataloader, 'patch_loader_code'):
raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
Expand Down Expand Up @@ -245,17 +248,23 @@ def _adjust_batch_size(trainer,
"""
model = trainer.get_model()
batch_size = getattr(model, batch_arg_name)
if hasattr(model, batch_arg_name):
batch_size = getattr(model, batch_arg_name)
else:
batch_size = getattr(model.hparams, batch_arg_name)
if value:
setattr(model, batch_arg_name, value)
if hasattr(model, batch_arg_name):
setattr(model, batch_arg_name, value)
else:
setattr(model.hparams, batch_arg_name, value)
new_size = value
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
else:
new_size = int(batch_size * factor)
if desc:
log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
setattr(model, batch_arg_name, new_size)
setattr(model.hparams, batch_arg_name, new_size)
return new_size


Expand Down

0 comments on commit f8103f9

Please sign in to comment.