Clean up Argparse interface with trainer (#1606)
* fixed distutil parsing

* fixed distutil parsing

* Apply suggestions from code review

* log

* fixed distutil parsing

* fixed distutil parsing

* fixed distutil parsing

* fixed distutil parsing

* doctest

* fixed hparams section

* fixed hparams section

* fixed hparams section

* formatting

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
3 people committed Apr 26, 2020
1 parent 13bf772 commit 4755ded
Showing 7 changed files with 141 additions and 83 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -87,6 +87,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixes automatic parser bug ([#1585](https://github.com/PyTorchLightning/pytorch-lightning/issues/1585))

- Fixed bool conversion from string ([#1606](https://github.com/PyTorchLightning/pytorch-lightning/issues/1606))

## [0.7.3] - 2020-04-09

### Added
183 changes: 109 additions & 74 deletions docs/source/hyperparameters.rst
@@ -3,36 +3,111 @@ Hyperparameters
Lightning has utilities to interact seamlessly with the command line ArgumentParser
and plays well with the hyperparameter optimization framework of your choice.

ArgumentParser
^^^^^^^^^^^^^^
Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser.

.. code-block:: python

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument('--layer_1_dim', type=int, default=128)
    args = parser.parse_args()

This allows you to call your program like so:

.. code-block:: bash

    python trainer.py --layer_1_dim 64

Argparser Best Practices
^^^^^^^^^^^^^^^^^^^^^^^^
It is best practice to layer your arguments in three sections.

1. Trainer args (gpus, num_nodes, etc...)
2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...)
3. Program arguments (data_path, cluster_email, etc...)

We can do this as follows. First, in your LightningModule, define the arguments
specific to that module. Remember that data splits or data paths may also be specific to
a module (ie: if your project has a model that trains on Imagenet and another on CIFAR-10).

.. code-block:: python

    class LitModel(LightningModule):

        @staticmethod
        def add_model_specific_args(parent_parser):
            parser = ArgumentParser(parents=[parent_parser], add_help=False)
            parser.add_argument('--encoder_layers', type=int, default=12)
            parser.add_argument('--data_path', type=str, default='/some/path')
            return parser

Now in your main trainer file, add the Trainer args, the program args, and the model args:

.. code-block:: python

    # ----------------
    # trainer_main.py
    # ----------------
    from argparse import ArgumentParser

    parser = ArgumentParser()

    # parametrize the network
    parser.add_argument('--layer_1_dim', type=int, default=128)
    parser.add_argument('--layer_2_dim', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=64)

    # add PROGRAM level args
    parser.add_argument('--conda_env', type=str, default='some_name')
    parser.add_argument('--notification_email', type=str, default='will@email.com')

    # add model specific args
    parser = LitModel.add_model_specific_args(parser)

    # add all the available trainer options to argparse
    # ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
    parser = pl.Trainer.add_argparse_args(parser)

    hparams = parser.parse_args()

Now you can run your program like so:

.. code-block:: bash

    python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12

Finally, make sure to start the training like so:

.. code-block:: python

    hparams = parser.parse_args()

    # YES
    model = LitModel(hparams)
    # NO
    # model = LitModel(learning_rate=hparams.learning_rate, ...)

    # YES
    trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)
    # NO
    trainer = Trainer(gpus=hparams.gpus, ...)

LightningModule hparams
^^^^^^^^^^^^^^^^^^^^^^^

Normally, we don't hard-code the values to a model. We usually use the command line to
modify the network and read those values in the LightningModule.

.. code-block:: python
    :emphasize-lines: 5,6,7,12,14

    class LitMNIST(pl.LightningModule):

        def __init__(self, hparams):
            super().__init__()

            # do this to save all arguments in any logger (tensorboard)
            self.hparams = hparams

            self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)

@@ -49,86 +124,44 @@ Now we can parametrize the LightningModule.
        def configure_optimizers(self):
            return Adam(self.parameters(), lr=self.hparams.learning_rate)

        @staticmethod
        def add_model_specific_args(parent_parser):
            parser = ArgumentParser(parents=[parent_parser], add_help=False)
            parser.add_argument('--layer_1_dim', type=int, default=128)
            parser.add_argument('--layer_2_dim', type=int, default=256)
            parser.add_argument('--batch_size', type=int, default=64)
            parser.add_argument('--learning_rate', type=float, default=0.002)
            return parser

.. note:: Bonus! If ``hparams`` is in your module, Lightning will save it into the checkpoint and restore your
    model using those hparams exactly.

And we can also add all the flags available in the Trainer to the Argparser.

.. code-block:: python

    # add all the available Trainer options to the ArgParser
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

And now you can start your program with

.. code-block:: bash

    # now you can use any trainer flag
    $ python main.py --num_nodes 2 --gpus 8

Now pass in the params when you init your model

.. code-block:: python

    hparams = parser.parse_args()
    model = LitMNIST(hparams)

The line ``self.hparams = hparams`` is very special. It assigns your hparams to the LightningModule.
This does two things:

1. It adds them automatically to tensorboard logs under the hparams tab.
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly (see the sketch below).
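
As a rough sketch of what point 2 buys you (the checkpoint path here is illustrative):

.. code-block:: python

    # Lightning stores hparams in the checkpoint and feeds them back into
    # LitMNIST.__init__ when restoring
    model = LitMNIST.load_from_checkpoint('path/to/checkpoint.ckpt')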

Trainer args
^^^^^^^^^^^^

It also gets annoying to map each argument into the Argparser. Luckily we have
a default parser. We set up the main training entry point file like this:

.. code-block:: python

    def main(args):
        model = LitMNIST(hparams=args)
        trainer = Trainer(max_epochs=args.max_epochs)
        trainer.fit(model)

    if __name__ == '__main__':
        parser = ArgumentParser()

        # adds all the trainer options as default arguments (like max_epochs)
        parser = Trainer.add_argparse_args(parser)

        # parametrize the network
        parser.add_argument('--layer_1_dim', type=int, default=128)
        parser.add_argument('--layer_2_dim', type=int, default=256)
        parser.add_argument('--batch_size', type=int, default=64)
        args = parser.parse_args()

        # train
        main(args)

And now we can train like this:

.. code-block:: bash

    $ python main.py --layer_1_dim 128 --layer_2_dim 256 --batch_size 64 --max_epochs 64

But it would also be nice to pass in any arbitrary argument to the trainer.
We can do it by changing how we init the trainer.

.. code-block:: python

    def main(args):
        model = LitMNIST(hparams=args)

        # makes all trainer options available from the command line
        trainer = Trainer.from_argparse_args(args)

and now we can do this:

.. code-block:: bash

    $ python main.py --gpus 1 --min_epochs 12 --max_epochs 64 --arbitrary_trainer_arg some_value

To recap, add ALL possible trainer flags to the argparser and init the Trainer this way:

.. code-block:: python

    parser = ArgumentParser()

    # add all options available in the trainer such as (max_epochs, etc...)
    parser = Trainer.add_argparse_args(parser)
    hparams = parser.parse_args()

    trainer = Trainer.from_argparse_args(hparams)

    # or if you need to pass in callbacks
    trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])
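
Since this commit, boolean Trainer flags are converted with ``strtobool`` rather than plain ``bool()``
(which treats any non-empty string, even ``'false'``, as True), so string values work from the CLI.
For example (flag values illustrative):

.. code-block:: bash

    $ python main.py --gpus 1 --fast_dev_run false
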
Multiple Lightning Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -173,7 +206,7 @@ Now we can allow each model to inject the arguments it needs in the main.py
    model = LitMNIST(hparams=args)
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)

if __name__ == '__main__':
@@ -182,6 +215,8 @@ Now we can allow each model to inject the arguments it needs in the main.py
# figure out which model to use
parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')
# THIS LINE IS KEY TO PULL THE MODEL NAME
temp_args, _ = parser.parse_known_args()
# let the model add what it wants
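
A rough sketch of how this pattern usually continues (the model class names here are illustrative, not part of the shown diff):

.. code-block:: python

    if temp_args.model_name == 'gan':
        parser = GoodGAN.add_model_specific_args(parser)
    elif temp_args.model_name == 'mnist':
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()
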
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/logging.py
@@ -5,7 +5,7 @@

from pytorch_lightning.core import memory
from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection
from pytorch_lightning.utilities import memory_utils
from pytorch_lightning.utilities.memory import recursive_detach


class TrainerLoggingMixin(ABC):
@@ -174,7 +174,7 @@ def process_output(self, output, train=False):

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
callback_metrics = memory_utils.recursive_detach(callback_metrics)
callback_metrics = recursive_detach(callback_metrics)

return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens

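For intuition, a rough sketch of what a recursive detach does (an illustration, not the library's exact implementation):

    import torch

    def recursive_detach_sketch(in_dict):
        # walk a metrics dict and detach any tensors from the autograd
        # graph; unlike .item(), this avoids forcing a GPU sync
        out = {}
        for k, v in in_dict.items():
            if isinstance(v, dict):
                out[k] = recursive_detach_sketch(v)
            elif isinstance(v, torch.Tensor):
                out[k] = v.detach()
            else:
                out[k] = v
        return out
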
12 changes: 8 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -1,4 +1,3 @@
import distutils
import inspect
import os
from argparse import ArgumentParser
@@ -33,6 +32,7 @@
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import parsing


try:
@@ -599,17 +599,19 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, )

depr_arg_names = cls.get_deprecated_arg_names()
blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist

allowed_types = (str, float, int, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names):

for allowed_type in (at for at in allowed_types if at in arg_types):
if allowed_type is bool:
def allowed_type(x):
return bool(distutils.util.strtobool(x))
return bool(parsing.strtobool(x))

if arg == 'gpus':
def allowed_type(x):
@@ -636,9 +638,11 @@ def arg_default(x):
return parser

@classmethod
def from_argparse_args(cls, args):
def from_argparse_args(cls, args, **kwargs):

params = vars(args)
params.update(**kwargs)

return cls(**params)

@property
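A usage sketch of the new ``**kwargs`` pass-through combined with the bool fix (flag values are illustrative;
plain ``type=bool`` would have turned the string ``'false'`` into True):

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    args = parser.parse_args(['--max_epochs', '3', '--fast_dev_run', 'false'])

    # **kwargs lets callers add or override constructor arguments at call time
    trainer = Trainer.from_argparse_args(args, logger=False)
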
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -141,7 +141,6 @@ def training_step(self, batch, batch_idx):
"""

import copy
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List
@@ -154,11 +153,9 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import memory_utils

try:
from apex import amp
File renamed without changes.
20 changes: 20 additions & 0 deletions pytorch_lightning/utilities/parsing.py
@@ -0,0 +1,20 @@
def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
Copied from the python implementation distutils.util.strtobool
True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.
>>> strtobool('YES')
1
>>> strtobool('FALSE')
0
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return 1
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return 0
else:
raise ValueError(f'invalid truth value {val}')
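
A quick sketch of how this helper can back an argparse flag (the flag name is illustrative):

    from argparse import ArgumentParser
    from pytorch_lightning.utilities.parsing import strtobool

    parser = ArgumentParser()
    # bool(strtobool(x)) maps 'yes'/'no'/'true'/'false' style strings to real booleans
    parser.add_argument('--fast_dev_run', type=lambda x: bool(strtobool(x)), default=False)

    args = parser.parse_args(['--fast_dev_run', 'no'])
    assert args.fast_dev_run is False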
