update docs
awaelchli committed Sep 6, 2020
1 parent 0c2a5a2 commit 00e73c1
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions pytorch_lightning/trainer/training_tricks.py
@@ -161,6 +161,13 @@ def scale_batch_size(self,
                 algorithm is terminated

             batch_arg_name: name of the attribute that stores the batch size.
+                It is expected that the user has provided a model or datamodule that has a hyperparameter
+                with that name. We will look for this attribute name in the following places
+
+                - `model`
+                - `model.hparams`
+                - `model.datamodule`
+                - `trainer.datamodule` (the datamodule passed to the tune method)

             **fit_kwargs: remaining arguments to be passed to .fit(), e.g., dataloader
                 or datamodule.
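The lookup order documented above means a user only needs to expose the batch size attribute in one of those places. Below is a minimal sketch of the intended usage; the `LitModel` class and toy data are illustrative, while `save_hyperparameters`, the `auto_scale_batch_size` Trainer flag, and `trainer.tune` are the public entry points to this code path in Lightning of this era (the diff itself references the tune method):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Illustrative model exposing `batch_size` via `self.hparams`."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        # save_hyperparameters() places batch_size into self.hparams,
        # one of the locations the batch size finder searches.
        self.save_hyperparameters()
        self.layer = torch.nn.Linear(28 * 28, 10)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def train_dataloader(self):
        # the tuner re-creates this loader for each candidate batch size
        dataset = torch.utils.data.TensorDataset(
            torch.randn(512, 28 * 28), torch.randint(0, 10, (512,))
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=self.hparams.batch_size
        )


# 'power' mode keeps doubling the batch size until an OOM error occurs
trainer = pl.Trainer(auto_scale_batch_size='power', max_epochs=1)
trainer.tune(LitModel())
```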
@@ -263,16 +270,12 @@ def _adjust_batch_size(trainer,
                        factor: float = 1.0,
                        value: Optional[int] = None,
                        desc: str = None) -> Tuple[int, bool]:
-    """ Function for adjusting the batch size. It is expected that the user
-    has provided a model that has a hparam field called `batch_size` i.e.
-    `model.hparams.batch_size` should exist. Additionally there can be a
-    datamodule attached to either Trainer or model, in that case the attribute
-    also gets updated when present.
+    """ Helper function for adjusting the batch size.

     Args:
         trainer: instance of pytorch_lightning.Trainer

-        batch_arg_name: field where batch_size is stored in `model.hparams`
+        batch_arg_name: name of the field where batch_size is stored.

         factor: value which the old batch size is multiplied by to get the
             new batch size
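The trimmed `_adjust_batch_size` docstring leaves the interplay of `factor` and `value` implicit. The sketch below is one reading of the documented semantics, not the library's actual implementation; in particular, treating `value` as overriding the multiplicative update and the returned bool as a "size changed" flag are assumptions:

```python
from typing import Optional, Tuple


def adjust_batch_size_sketch(current: int,
                             factor: float = 1.0,
                             value: Optional[int] = None) -> Tuple[int, bool]:
    """Simplified sketch: an explicit `value` wins; otherwise the new
    size is the old size multiplied by `factor` (assumed semantics)."""
    new_size = value if value is not None else int(current * factor)
    return new_size, new_size != current


# e.g. doubling while the run succeeds vs. backing off after an OOM:
print(adjust_batch_size_sketch(32, factor=2.0))   # (64, True)
print(adjust_batch_size_sketch(64, factor=0.5))   # (32, True)
```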
