
Clean up Argparse interface with trainer #1606

Merged
merged 15 commits on Apr 26, 2020
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -87,6 +87,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixes automatic parser bug ([#1585](https://github.com/PyTorchLightning/pytorch-lightning/issues/1585))

- Fixed bool conversion from string ([#1606](https://github.com/PyTorchLightning/pytorch-lightning/issues/1606))

## [0.7.3] - 2020-04-09

### Added
183 changes: 109 additions & 74 deletions docs/source/hyperparameters.rst
@@ -3,36 +3,111 @@ Hyperparameters
Lightning has utilities to interact seamlessly with the command line ArgumentParser
and plays well with the hyperparameter optimization framework of your choice.

LightningModule hparams
ArgumentParser
^^^^^^^^^^^^^^
Lightning is designed to augment a lot of the functionality of the built-in Python ArgumentParser.

.. code-block:: python

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument('--layer_1_dim', type=int, default=128)

args = parser.parse_args()

This allows you to call your program like so:

.. code-block:: bash

python trainer.py --layer_1_dim 64


Argparser Best Practices
^^^^^^^^^^^^^^^^^^^^^^^^
It is best practice to layer your arguments in three sections.

Normally, we don't hard-code the values into a model. We usually use the command line to
modify the network. The `Trainer` can add all the available options to an ArgumentParser.
1. Trainer args (gpus, num_nodes, etc...)
2. Model specific arguments (layer_dim, num_layers, learning_rate, etc...)
3. Program arguments (data_path, cluster_email, etc...)

We can do this as follows. First, in your LightningModule, define the arguments
specific to that module. Remember that data splits or data paths may also be specific to
a module (ie: if your project has a model that trains on Imagenet and another on CIFAR-10).

.. code-block:: python

class LitModel(LightningModule):

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--encoder_layers', type=int, default=12)
parser.add_argument('--data_path', type=str, default='/some/path')
return parser

Now in your main trainer file, add the Trainer args, the program args, and add the model args

.. code-block:: python

# ----------------
# trainer_main.py
# ----------------
from argparse import ArgumentParser

parser = ArgumentParser()

# parametrize the network
parser.add_argument('--layer_1_dim', type=int, default=128)
parser.add_argument('--layer_2_dim', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=64)
# add PROGRAM level args
parser.add_argument('--conda_env', type=str, default='some_name')
parser.add_argument('--notification_email', type=str, default='will@email.com')

# add model specific args
parser = LitModel.add_model_specific_args(parser)

# add all the available options to the trainer
# add all the available trainer options to argparse
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
parser = pl.Trainer.add_argparse_args(parser)

args = parser.parse_args()
hparams = parser.parse_args()

Now we can parametrize the LightningModule.
Now you can run your program like so:

.. code-block:: bash

python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12

Finally, make sure to start the training like so:

.. code-block:: python

hparams = parser.parse_args()

# YES
model = LitModel(hparams)

# NO
# model = LitModel(learning_rate=hparams.learning_rate, ...)

# YES
trainer = Trainer.from_argparse_args(hparams, early_stopping_callback=...)

# NO
trainer = Trainer(gpus=hparams.gpus, ...)


LightningModule hparams
^^^^^^^^^^^^^^^^^^^^^^^^

Normally, we don't hard-code the values into a model. We usually use the command line to
modify the network and read those values in the LightningModule.

.. code-block:: python
:emphasize-lines: 5,6,7,12,14

class LitMNIST(pl.LightningModule):
def __init__(self, hparams):
super().__init__()

# do this to save all arguments in any logger (tensorboard)
self.hparams = hparams

self.layer_1 = torch.nn.Linear(28 * 28, hparams.layer_1_dim)
@@ -49,86 +124,44 @@ Now we can parametrize the LightningModule.
def configure_optimizers(self):
return Adam(self.parameters(), lr=self.hparams.learning_rate)

hparams = parser.parse_args()
model = LitMNIST(hparams)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)

.. note:: Bonus! If ``hparams`` is in your module, Lightning will save it into the checkpoint and restore your
model using those hparams exactly.
parser.add_argument('--layer_1_dim', type=int, default=128)
parser.add_argument('--layer_2_dim', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--learning_rate', type=float, default=0.002)
return parser

And we can also add all the flags available in the Trainer to the Argparser.
Now pass in the params when you init your model

.. code-block:: python

# add all the available Trainer options to the ArgParser
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()

And now you can start your program with
hparams = parser.parse_args()
model = LitMNIST(hparams)

.. code-block:: bash
The line ``self.hparams = hparams`` is very special. It assigns your hparams to the LightningModule.
This does two things:

# now you can use any trainer flag
$ python main.py --num_nodes 2 --gpus 8
1. It adds them automatically to TensorBoard logs under the hparams tab.
2. Lightning will save those hparams to the checkpoint and use them to restore the module correctly.
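
Because the hparams travel with the checkpoint, restoring the module later only needs
the checkpoint path. A minimal sketch (the path here is hypothetical):

.. code-block:: python

# hparams stored through `self.hparams = hparams` are read back from the checkpoint
model = LitMNIST.load_from_checkpoint('/path/to/checkpoint.ckpt')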

Trainer args
^^^^^^^^^^^^

It also gets annoying to map each argument into the Argparser. Luckily we have
a default parser.
To recap, add ALL possible trainer flags to the argparser and init the Trainer this way:

.. code-block:: python

parser = ArgumentParser()

# add all options available in the trainer such as (max_epochs, etc...)
parser = Trainer.add_argparse_args(parser)
hparams = parser.parse_args()

We set up the main training entry point file like this:

.. code-block:: python

def main(args):
model = LitMNIST(hparams=args)
trainer = Trainer(max_epochs=args.max_epochs)
trainer.fit(model)
trainer = Trainer.from_argparse_args(hparams)

if __name__ == '__main__':
parser = ArgumentParser()
# or if you need to pass in callbacks
trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])

# adds all the trainer options as default arguments (like max_epochs)
parser = Trainer.add_argparse_args(parser)

# parametrize the network
parser.add_argument('--layer_1_dim', type=int, default=128)
parser.add_argument('--layer_2_dim', type=int, default=256)
parser.add_argument('--batch_size', type=int, default=64)
args = parser.parse_args()

# train
main(args)

And now we can train like this:

.. code-block:: bash

$ python main.py --layer_1_dim 128 --layer_2_dim 256 --batch_size 64 --max_epochs 64

But it would also be nice to pass in any arbitrary argument to the trainer.
We can do it by changing how we init the trainer.

.. code-block:: python

def main(args):
model = LitMNIST(hparams=args)

# makes all trainer options available from the command line
trainer = Trainer.from_argparse_args(args)

And now we can do this:

.. code-block:: bash

$ python main.py --gpus 1 --min_epochs 12 --max_epochs 64 --arbitrary_trainer_arg some_value

Multiple Lightning Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -173,7 +206,7 @@ Now we can allow each model to inject the arguments it needs in the main.py
model = LitMNIST(hparams=args)

model = LitMNIST(hparams=args)
trainer = Trainer(max_epochs=args.max_epochs)
trainer = Trainer.from_argparse_args(args)
trainer.fit(model)

if __name__ == '__main__':
@@ -182,6 +215,8 @@

# figure out which model to use
parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')

# THIS LINE IS KEY TO PULL THE MODEL NAME
temp_args, _ = parser.parse_known_args()

# let the model add what it wants
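# (hypothetical continuation, not shown in this hunk) dispatch on the parsed model
# name so only the chosen module registers its arguments; GoodGAN is a placeholder class
if temp_args.model_name == 'gan':
    parser = GoodGAN.add_model_specific_args(parser)
elif temp_args.model_name == 'mnist':
    parser = LitMNIST.add_model_specific_args(parser)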
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/logging.py
@@ -5,7 +5,7 @@

from pytorch_lightning.core import memory
from pytorch_lightning.loggers import TensorBoardLogger, LightningLoggerBase, LoggerCollection
from pytorch_lightning.utilities import memory_utils
from pytorch_lightning.utilities.memory import recursive_detach


class TrainerLoggingMixin(ABC):
@@ -174,7 +174,7 @@ def process_output(self, output, train=False):

# detach all metrics for callbacks to prevent memory leaks
# no .item() because it will slow things down
callback_metrics = memory_utils.recursive_detach(callback_metrics)
callback_metrics = recursive_detach(callback_metrics)

return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
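
A standalone sketch (an assumption about the helper's behavior, not part of this diff): recursive_detach
walks the metrics dict and hands back tensors detached from the autograd graph, so callbacks can hold
onto the values without keeping the graph alive.

import torch
from pytorch_lightning.utilities.memory import recursive_detach

metrics = {'loss': torch.tensor(0.25, requires_grad=True)}
detached = recursive_detach(metrics)
# the returned tensor no longer references the autograd graph
assert detached['loss'].requires_grad is False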

12 changes: 8 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -1,4 +1,3 @@
import distutils
import inspect
import os
from argparse import ArgumentParser
@@ -33,6 +32,7 @@
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import parsing


try:
@@ -599,17 +599,19 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, )

depr_arg_names = cls.get_deprecated_arg_names()
blacklist = ['kwargs']
depr_arg_names = cls.get_deprecated_arg_names() + blacklist

allowed_types = (str, float, int, bool)

# TODO: get "help" from docstring :)
for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
if at[0] not in depr_arg_names):

for allowed_type in (at for at in allowed_types if at in arg_types):
if allowed_type is bool:
def allowed_type(x):
return bool(distutils.util.strtobool(x))
return bool(parsing.strtobool(x))

if arg == 'gpus':
def allowed_type(x):
@@ -636,9 +638,11 @@ def arg_default(x):
return parser

@classmethod
def from_argparse_args(cls, args):
def from_argparse_args(cls, args, **kwargs):

params = vars(args)
params.update(**kwargs)
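# note: keyword arguments passed in code take precedence over the values parsed
# from the command line, since update() overwrites matching keys in params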

return cls(**params)

@property
3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -141,7 +141,6 @@ def training_step(self, batch, batch_idx):

"""

import copy
from abc import ABC, abstractmethod
from typing import Callable
from typing import Union, List
@@ -154,11 +153,9 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import memory_utils

try:
from apex import amp
20 changes: 20 additions & 0 deletions pytorch_lightning/utilities/parsing.py
@@ -0,0 +1,20 @@
def strtobool(val):
"""Convert a string representation of truth to true (1) or false (0).
Copied from the Python implementation distutils.util.strtobool

True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if
'val' is anything else.

>>> strtobool('YES')
1
>>> strtobool('FALSE')
0
"""
val = val.lower()
if val in ('y', 'yes', 't', 'true', 'on', '1'):
return 1
elif val in ('n', 'no', 'f', 'false', 'off', '0'):
return 0
else:
raise ValueError(f'invalid truth value {val}')
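
A standalone usage sketch (not part of the diff): wiring the vendored helper into argparse the same
way add_argparse_args does above, so a string like 'false' is parsed into a real boolean.

from argparse import ArgumentParser
from pytorch_lightning.utilities.parsing import strtobool

parser = ArgumentParser()
# mirror the bool handling used in Trainer.add_argparse_args
parser.add_argument('--fast_dev_run', type=lambda x: bool(strtobool(x)), default=False)

args = parser.parse_args(['--fast_dev_run', 'false'])
assert args.fast_dev_run is False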