extend arg parser (#1842)
* extend arg parser

* flake8

* tests

* example

* fix test
Borda committed May 14, 2020
1 parent a6f6edd commit bee0392
Showing 2 changed files with 74 additions and 32 deletions.
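
Taken together, the two files give every bool-annotated `Trainer` argument three-state CLI behavior: omitting the flag keeps the default, a bare flag means `True`, and a value string is passed through when the annotation also allows `str`. A minimal sketch of the resulting round-trip, assuming a pytorch_lightning build that contains this commit:

from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)

# bare flag: argparse stores None, which parse_argparser maps to True
args = Trainer.parse_argparser(parser.parse_args(['--auto_scale_batch_size']))
assert args.auto_scale_batch_size is True

# flag with a value: the more general str type keeps the string
args = Trainer.parse_argparser(parser.parse_args(['--auto_scale_batch_size', 'power']))
assert args.auto_scale_batch_size == 'power'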
84 changes: 52 additions & 32 deletions pytorch_lightning/trainer/trainer.py
@@ -1,7 +1,7 @@
 import inspect
 import os
 import logging as python_logging
-from argparse import ArgumentParser
+from argparse import ArgumentParser, Namespace
 from typing import Union, Optional, List, Dict, Tuple, Iterable, Any

 import torch
@@ -132,7 +132,7 @@ def __init__(
             replace_sampler_ddp: bool = True,
             progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
             terminate_on_nan: bool = False,
-            auto_scale_batch_size: Optional[str] = None,
+            auto_scale_batch_size: Union[str, bool] = False,
             amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
             default_save_path=None,  # backward compatible, todo: remove in v0.8.0
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -663,52 +663,70 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         # TODO: get "help" from docstring :)
         for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
                                             if at[0] not in depr_arg_names):
-
-            for allowed_type in (at for at in allowed_types if at in arg_types):
-                if allowed_type is bool:
-                    def allowed_type(x):
-                        return bool(parsing.strtobool(x))
-
-                    # Bool args with default of True parsed as flags not key value pair
-                    if arg_types == (bool,) and arg_default is False:
-                        parser.add_argument(
-                            f'--{arg}',
-                            action='store_true',
-                            dest=arg,
-                            help='autogenerated by pl.Trainer'
-                        )
-                        continue
-
-                if arg == 'gpus':
-                    allowed_type = Trainer.allowed_type
-                    arg_default = Trainer.arg_default
-
-                parser.add_argument(
-                    f'--{arg}',
-                    default=arg_default,
-                    type=allowed_type,
-                    dest=arg,
-                    help='autogenerated by pl.Trainer'
-                )
-                break
+            arg_types = [at for at in allowed_types if at in arg_types]
+            if not arg_types:
+                # skip argument with not supported type
+                continue
+            arg_kwargs = {}
+            if bool in arg_types:
+                arg_kwargs.update(nargs="?")
+                # if the only arg type is bool
+                if len(arg_types) == 1:
+                    # redefine the type for ArgParser needed
+                    def use_type(x):
+                        return bool(parsing.strtobool(x))
+                else:
+                    # filter out the bool as we need to use more general
+                    use_type = [at for at in arg_types if at is not bool][0]
+            else:
+                use_type = arg_types[0]
+
+            if arg == 'gpus':
+                use_type = Trainer._allowed_type
+                arg_default = Trainer._arg_default
+
+            parser.add_argument(
+                f'--{arg}',
+                dest=arg,
+                default=arg_default,
+                type=use_type,
+                help='autogenerated by pl.Trainer',
+                **arg_kwargs,
+            )

         return parser

-    def allowed_type(x):
+    def _allowed_type(x) -> Union[int, str]:
         if ',' in x:
             return str(x)
         else:
             return int(x)

-    def arg_default(x):
+    def _arg_default(x) -> Union[int, str]:
         if ',' in x:
             return str(x)
         else:
             return int(x)

+    @staticmethod
+    def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+        """Parse CLI arguments, required for custom bool types."""
+        args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
+        args = {k: True if v is None else v for k, v in vars(args).items()}
+        return Namespace(**args)
+
     @classmethod
-    def from_argparse_args(cls, args, **kwargs):
+    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
+        """create an instance from CLI arguments
+
+        Example:
+            >>> parser = ArgumentParser(add_help=False)
+            >>> parser = Trainer.add_argparse_args(parser)
+            >>> args = Trainer.parse_argparser(parser.parse_args(""))
+            >>> trainer = Trainer.from_argparse_args(args)
+        """
+        if isinstance(args, ArgumentParser):
+            args = Trainer.parse_argparser(args)
         params = vars(args)
         params.update(**kwargs)
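
The mechanism behind the rewrite above is plain argparse rather than anything Lightning-specific: `nargs='?'` lets a flag appear bare (argparse then stores `None`), a custom `type` callable coerces value strings, and `parse_argparser` maps the leftover `None` to `True`. A self-contained sketch of the same idea, with a simplified local `str_to_bool` standing in for `parsing.strtobool` and an inline lambda standing in for `Trainer._allowed_type`:

from argparse import ArgumentParser, Namespace


def str_to_bool(x):
    # simplified stand-in: pytorch_lightning's parsing.strtobool also rejects junk values
    return x.lower() in ('y', 'yes', 't', 'true', 'on', '1')


parser = ArgumentParser()
# bool-only annotation: nargs='?' allows a bare flag (stored as None) or a value
parser.add_argument('--auto_lr_find', type=str_to_bool, nargs='?', default=False)
# bool-or-str annotation: the more general str type keeps values like 'power'
parser.add_argument('--auto_scale_batch_size', type=str, nargs='?', default=False)
# int-or-str, as for --gpus: '2' becomes 2 while '0,1' stays a string
parser.add_argument('--gpus', type=lambda x: str(x) if ',' in x else int(x), default=0)

args = parser.parse_args(['--auto_lr_find', '--auto_scale_batch_size', 'power', '--gpus', '0,1'])
# the parse_argparser post-processing step: a bare flag parsed to None means True
args = Namespace(**{k: True if v is None else v for k, v in vars(args).items()})
assert args.auto_lr_find is True
assert args.auto_scale_batch_size == 'power'
assert args.gpus == '0,1'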

@@ -797,6 +815,8 @@ def fit(

         # Run auto batch size scaling
         if self.auto_scale_batch_size:
+            if isinstance(self.auto_scale_batch_size, bool):
+                self.auto_scale_batch_size = 'power'
             self.scale_batch_size(model, mode=self.auto_scale_batch_size)

         # Run learning rate finder:
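
Given this coercion, passing a bool from the constructor (or a bare CLI flag) is just shorthand for the default search mode; a minimal sketch of the equivalence:

from pytorch_lightning import Trainer

# equivalent: fit() rewrites a plain True into the default 'power' scaling mode
trainer_a = Trainer(auto_scale_batch_size=True)
trainer_b = Trainer(auto_scale_batch_size='power')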
22 changes: 22 additions & 0 deletions tests/trainer/test_trainer_cli.py
@@ -88,3 +88,25 @@ def _raise():

     with pytest.raises(_UnkArgError):
         parser.parse_args(cli_args)
+
+
+# todo: add also testing for "gpus"
+@pytest.mark.parametrize(['cli_args', 'expected'], [
+    pytest.param('--auto_lr_find --auto_scale_batch_size power',
+                 {'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}),
+    pytest.param('--auto_lr_find any_string --auto_scale_batch_size',
+                 {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}),
+    pytest.param('--early_stop_callback',
+                 {'auto_lr_find': False, 'early_stop_callback': True, 'auto_scale_batch_size': False}),
+])
+def test_argparse_args_parsing(cli_args, expected):
+    """Test multi type argument with bool."""
+    cli_args = cli_args.split(' ') if cli_args else []
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        parser = ArgumentParser(add_help=False)
+        parser = Trainer.add_argparse_args(parent_parser=parser)
+        args = Trainer.parse_argparser(parser)
+
+    for k, v in expected.items():
+        assert getattr(args, k) == v
+    assert Trainer.from_argparse_args(args)
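
A note on the test's `mock.patch("argparse._sys.argv", ...)`: `Trainer.parse_argparser(parser)` ends up calling `parser.parse_args()` with no explicit argument list, and argparse then falls back to `sys.argv[1:]` through its internal `_sys` alias, so the test injects its CLI strings there. The same pattern in isolation, using only the standard library (`mock` here is `unittest.mock`):

from argparse import ArgumentParser
from unittest import mock

parser = ArgumentParser()
parser.add_argument('--mode')
with mock.patch('argparse._sys.argv', ['prog.py', '--mode', 'fast']):
    args = parser.parse_args()  # no argv given, so the patched sys.argv is read
assert args.mode == 'fast'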
