diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 3d37f342a43ba..361faa9715a89 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1,7 +1,7 @@
 import inspect
 import os
 import logging as python_logging
-from argparse import ArgumentParser
+from argparse import ArgumentParser, Namespace
 from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
 
 import torch
@@ -132,7 +132,7 @@ def __init__(
             replace_sampler_ddp: bool = True,
             progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
             terminate_on_nan: bool = False,
-            auto_scale_batch_size: Optional[str] = None,
+            auto_scale_batch_size: Union[str, bool] = False,
             amp_level: str = 'O1',  # backward compatible, todo: remove in v0.8.0
             default_save_path=None,  # backward compatible, todo: remove in v0.8.0
             gradient_clip=None,  # backward compatible, todo: remove in v0.8.0
@@ -663,52 +663,70 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
         # TODO: get "help" from docstring :)
         for arg, arg_types, arg_default in (at for at in cls.get_init_arguments_and_types()
                                             if at[0] not in depr_arg_names):
-
-            for allowed_type in (at for at in allowed_types if at in arg_types):
-                if allowed_type is bool:
-                    def allowed_type(x):
+            arg_types = [at for at in allowed_types if at in arg_types]
+            if not arg_types:
+                # skip argument with unsupported type
+                continue
+            arg_kwargs = {}
+            if bool in arg_types:
+                arg_kwargs.update(nargs="?")
+                # if the only arg type is bool
+                if len(arg_types) == 1:
+                    # redefine the type as needed for the ArgumentParser
+                    def use_type(x):
                         return bool(parsing.strtobool(x))
-
-                # Bool args with default of True parsed as flags not key value pair
-                if arg_types == (bool,) and arg_default is False:
-                    parser.add_argument(
-                        f'--{arg}',
-                        action='store_true',
-                        dest=arg,
-                        help='autogenerated by pl.Trainer'
-                    )
-                    continue
-
-                if arg == 'gpus':
-                    allowed_type = Trainer.allowed_type
-                    arg_default = Trainer.arg_default
-
-                parser.add_argument(
-                    f'--{arg}',
-                    default=arg_default,
-                    type=allowed_type,
-                    dest=arg,
-                    help='autogenerated by pl.Trainer'
-                )
-                break
+                else:
+                    # filter out bool and use the more general remaining type
+                    use_type = [at for at in arg_types if at is not bool][0]
+            else:
+                use_type = arg_types[0]
+
+            if arg == 'gpus':
+                use_type = Trainer._allowed_type
+                arg_default = Trainer._arg_default
+
+            parser.add_argument(
+                f'--{arg}',
+                dest=arg,
+                default=arg_default,
+                type=use_type,
+                help='autogenerated by pl.Trainer',
+                **arg_kwargs,
+            )
 
         return parser
 
-    def allowed_type(x):
+    def _allowed_type(x) -> Union[int, str]:
         if ',' in x:
             return str(x)
         else:
             return int(x)
 
-    def arg_default(x):
+    def _arg_default(x) -> Union[int, str]:
         if ',' in x:
             return str(x)
         else:
             return int(x)
 
+    @staticmethod
+    def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
+        """Parse CLI arguments, required for custom bool types."""
+        args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser
+        args = {k: True if v is None else v for k, v in vars(args).items()}
+        return Namespace(**args)
+
     @classmethod
-    def from_argparse_args(cls, args, **kwargs):
+    def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
+        """Create an instance from CLI arguments.
+
+        Example:
+            >>> parser = ArgumentParser(add_help=False)
+            >>> parser = Trainer.add_argparse_args(parser)
+            >>> args = Trainer.parse_argparser(parser.parse_args(""))
+            >>> trainer = Trainer.from_argparse_args(args)
+        """
+        if isinstance(args, ArgumentParser):
+            args = Trainer.parse_argparser(args)
         params = vars(args)
         params.update(**kwargs)
@@ -797,6 +815,8 @@ def fit(
 
         # Run auto batch size scaling
         if self.auto_scale_batch_size:
+            if isinstance(self.auto_scale_batch_size, bool):
+                self.auto_scale_batch_size = 'power'
             self.scale_batch_size(model, mode=self.auto_scale_batch_size)
 
         # Run learning rate finder:
diff --git a/tests/trainer/test_trainer_cli.py b/tests/trainer/test_trainer_cli.py
index 9b0e64dfefb7a..922acf9352832 100644
--- a/tests/trainer/test_trainer_cli.py
+++ b/tests/trainer/test_trainer_cli.py
@@ -88,3 +88,25 @@ def _raise():
 
     with pytest.raises(_UnkArgError):
         parser.parse_args(cli_args)
+
+
+# todo: also add testing for "gpus"
+@pytest.mark.parametrize(['cli_args', 'expected'], [
+    pytest.param('--auto_lr_find --auto_scale_batch_size power',
+                 {'auto_lr_find': True, 'auto_scale_batch_size': 'power', 'early_stop_callback': False}),
+    pytest.param('--auto_lr_find any_string --auto_scale_batch_size',
+                 {'auto_lr_find': 'any_string', 'auto_scale_batch_size': True}),
+    pytest.param('--early_stop_callback',
+                 {'auto_lr_find': False, 'early_stop_callback': True, 'auto_scale_batch_size': False}),
+])
+def test_argparse_args_parsing(cli_args, expected):
+    """Test multi-type argument with bool."""
+    cli_args = cli_args.split(' ') if cli_args else []
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        parser = ArgumentParser(add_help=False)
+        parser = Trainer.add_argparse_args(parent_parser=parser)
+        args = Trainer.parse_argparser(parser)
+
+    for k, v in expected.items():
+        assert getattr(args, k) == v
+    assert Trainer.from_argparse_args(args)
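
For orientation, a minimal standalone sketch (illustrative only, not part of the patch) of the argparse scheme the diff adopts for Union[str, bool] arguments such as auto_scale_batch_size: nargs="?" makes the value optional, a bare flag parses as None, and the None -> True pass in Trainer.parse_argparser supplies the boolean meaning. Only the argument name and its False default come from the patch; the rest is a hand-rolled reduction.

from argparse import ArgumentParser, Namespace

parser = ArgumentParser()
# For a Union[str, bool] argument the patched add_argparse_args registers the
# non-bool type with nargs='?' and the default from Trainer.__init__ (False here).
parser.add_argument('--auto_scale_batch_size', nargs='?', default=False, type=str)

# Bare flag: with nargs='?' and no `const`, argparse stores None ...
args = parser.parse_args(['--auto_scale_batch_size'])
assert args.auto_scale_batch_size is None

# ... which the dict comprehension from parse_argparser maps to True.
args = Namespace(**{k: True if v is None else v for k, v in vars(args).items()})
assert args.auto_scale_batch_size is True

# An explicit value stays a string, later consumed as scale_batch_size(model, mode=...).
assert parser.parse_args(['--auto_scale_batch_size', 'power']).auto_scale_batch_size == 'power'

# Flag absent: the False default survives, so fit() skips batch size scaling.
assert parser.parse_args([]).auto_scale_batch_size is False

The bool-only case (e.g. --early_stop_callback in the new test) goes through the same None -> True pass; there the patched code installs the strtobool wrapper as the type, so explicit values such as --early_stop_callback false parse as well.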