Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom argparser extension with Trainer arguments (argument types added) #1147

Merged
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
ea6b8f1
`add_argparse_args` method fixed (argument types added)
Mar 14, 2020
28e9348
CHANGELOG.md upd
Mar 14, 2020
6c290dd
autopep8 fixes
Mar 14, 2020
4444701
--gpus=0 removed from test (for ci tests)
Mar 14, 2020
0ec101f
typo fixed
Mar 14, 2020
bc7dd6f
reduce on plateau scheduler fixed
Mar 14, 2020
57a22ea
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
2d220d5
refactored: get_init_arguments_and_types is a public classmethod of t…
Mar 16, 2020
5b5042e
test_get_init_arguments_and_types added
Mar 16, 2020
dd81a39
autopep8 fixes
Mar 16, 2020
b4d9605
Merge remote-tracking branch 'upstream/master'
Mar 16, 2020
977b2de
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
7a9ba50
refactored: get_init_arguments_and_types is a public classmethod of t…
Mar 16, 2020
551ff24
test_get_init_arguments_and_types added
Mar 16, 2020
d3c1fdc
autopep8 fixes
Mar 16, 2020
1f92d92
Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…
Mar 16, 2020
2b27e3c
merged
Mar 17, 2020
f501ce6
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
4bccfd7
refactored: get_init_arguments_and_types is a public classmethod of t…
Mar 16, 2020
38dcba7
test_get_init_arguments_and_types added
Mar 16, 2020
e1abf4d
autopep8 fixes
Mar 16, 2020
ea3334a
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
f79dcbf
test_get_init_arguments_and_types added
Mar 16, 2020
299574d
autopep8 fixes
Mar 16, 2020
4782fe6
Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…
Mar 17, 2020
20c4073
Apply suggestions from code review
Borda Mar 17, 2020
201d855
cosmetics
Mar 17, 2020
567d9b8
Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…
Mar 17, 2020
0c292b4
cosmetics
Mar 17, 2020
f45fa6f
Update pytorch_lightning/trainer/trainer.py
alexeykarnachev Mar 17, 2020
eb23637
`Trainer.get_init_arguments_and_types` now returns arg types wrapped …
Mar 17, 2020
5f4a5fe
deprecated args are now ignored in argparser
Mar 17, 2020
f48f17e
get_deprecated_arg_names small refactor
Mar 17, 2020
2207c4e
get_deprecated_arg_names bug fixed
Mar 17, 2020
0a73a34
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
0cd9aaa
refactored: get_init_arguments_and_types is a public classmethod of t…
Mar 16, 2020
25bd99d
test_get_init_arguments_and_types added
Mar 16, 2020
5253a50
autopep8 fixes
Mar 16, 2020
3fd9a99
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
36ca20b
autopep8 fixes
Mar 16, 2020
f7cc4e1
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
ce837fe
autopep8 fixes
Mar 16, 2020
8aa062a
Trainer cli related tests moved to test_trainer_cli.py
Mar 16, 2020
d447aa3
test_get_init_arguments_and_types added
Mar 16, 2020
aaada85
autopep8 fixes
Mar 16, 2020
87b6cbb
Apply suggestions from code review
Borda Mar 17, 2020
a9f996b
cosmetics
Mar 17, 2020
a4702ff
cosmetics
Mar 17, 2020
162f3ab
Update pytorch_lightning/trainer/trainer.py
alexeykarnachev Mar 17, 2020
a8b21d9
`Trainer.get_init_arguments_and_types` now returns arg types wrapped …
Mar 17, 2020
2a5d46d
deprecated args are now ignored in argparser
Mar 17, 2020
3b7c865
get_deprecated_arg_names small refactor
Mar 17, 2020
d07870d
get_deprecated_arg_names bug fixed
Mar 17, 2020
012d425
Merge remote-tracking branch 'origin/trainer-argparser-types' into tr…
Mar 17, 2020
9449b98
Merge branch 'master' into trainer-argparser-types
Borda Mar 18, 2020
94c5e09
Update pytorch_lightning/trainer/trainer.py
alexeykarnachev Mar 20, 2020
4f8314d
Update pytorch_lightning/trainer/trainer.py
alexeykarnachev Mar 20, 2020
77847b7
Merge branch 'master' into trainer-argparser-types
Borda Mar 20, 2020
63d30fd
Merge branch 'master' into trainer-argparser-types
williamFalcon Mar 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).

## [0.7.1] - 2020-03-07

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class UNet(nn.Module):
bilinear (bool) - Whether to use bilinear interpolation or transposed
convolutions for upsampling.
'''

def __init__(self, num_classes=19, bilinear=False):
super().__init__()
self.layer1 = DoubleConv(3, 64)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class DoubleConv(nn.Module):
Double Convolution and BN and ReLU
(3x3 conv -> BN -> ReLU) ** 2
'''

def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
Expand All @@ -27,6 +28,7 @@ class Down(nn.Module):
'''
Combination of MaxPool2d and DoubleConv in series
'''

def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
Expand All @@ -44,6 +46,7 @@ class Up(nn.Module):
followed by concatenation of feature map from contracting path,
followed by double 3x3 convolution.
'''

def __init__(self, in_ch, out_ch, bilinear=False):
super().__init__()
self.upsample = None
Expand Down
2 changes: 2 additions & 0 deletions pl_examples/full_examples/semantic_segmentation/semseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class KITTI(Dataset):
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
'''

def __init__(
self,
root_path,
Expand Down Expand Up @@ -120,6 +121,7 @@ class SegModel(pl.LightningModule):

Adam optimizer is used along with Cosine Annealing learning rate scheduler.
'''

def __init__(self, hparams):
super(SegModel, self).__init__()
self.root_path = hparams.root
Expand Down
58 changes: 47 additions & 11 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,33 @@
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable
from typing import Union, Optional, List, Dict, Tuple, Iterable, Generator, Any

alexeykarnachev marked this conversation as resolved.
Show resolved Hide resolved
import torch
from torch import optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch import optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
from pytorch_lightning.profiler.profiler import BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (
TrainerDPMixin,
parse_gpu_ids,
determine_root_gpu_device
)
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
Expand Down Expand Up @@ -449,17 +449,52 @@ def default_attributes(cls):
return args

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
"""Extend existing argparse by default `Trainer` attributes."""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
def _get_argparse_args_and_types(cls) -> Generator[Tuple[str, Any, Any], None, None]:
    r"""Generate ``Trainer`` init arguments that can extend an ``ArgumentParser``.

    Inspects the ``__init__`` signature of ``cls`` and yields every argument
    whose type annotation (or one member of its ``Union`` annotation) is a
    simple scalar type that ``argparse`` can handle. Arguments without such
    an annotation are skipped and reported at DEBUG level.

    Yields:
        Tuple with 3 values: argument name, argument type and argument
        default value.
    """
    import inspect

    # Only arguments of these types can extend the `parent_parser`.
    # Order matters: for Union annotations the first matching entry wins.
    allowed_types = (str, float, int, bool)
    init_params = inspect.signature(cls).parameters

    for arg in init_params:
        arg_type = init_params[arg].annotation
        arg_default = init_params[arg].default
        try:
            # Union[...] annotations expose their member types via `__args__`;
            # plain annotations do not have the attribute.
            possible_arg_types = arg_type.__args__
        except AttributeError:
            possible_arg_types = (arg_type,)

        allowed_arg_types = set(possible_arg_types).intersection(allowed_types)

        # Pick the first allowed type in priority order, if any matched.
        arg_type = next((t for t in allowed_types if t in allowed_arg_types), None)

        if arg_type is not None:
            yield arg, arg_type, arg_default
        else:
            # Bug fix: format arguments must be passed individually to the
            # logger (lazy %-formatting), not wrapped in a single tuple —
            # a 1-tuple argument against two `%s` placeholders makes
            # `LogRecord.getMessage()` raise a formatting error.
            log.debug(
                'Argument %s has no allowed default type hint and will not be added '
                'to the arguments parser. Allowed types: %s', arg, allowed_types
            )

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
r"""Extend existing argparse by default `Trainer` attributes."""
parser = ArgumentParser(parents=[parent_parser], add_help=False, )

# TODO: get "help" from docstring :)
for arg, arg_type, arg_default in cls._get_argparse_args_and_types():
parser.add_argument(
f'--{arg}',
default=trainer_default_params[arg],
default=arg_default,
type=arg_type,
dest=arg,
help='autogenerated by pl.Trainer'
)
Expand Down Expand Up @@ -879,6 +914,7 @@ class _PatchDataLoader(object):
Args:
dataloader: Dataloader object to return when called.
'''

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

Expand Down
47 changes: 47 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def test_dp_output_reduce():

def test_model_checkpoint_options(tmpdir):
"""Test ModelCheckpoint options."""

def mock_save_function(filepath):
open(filepath, 'a').close()

Expand Down Expand Up @@ -644,3 +645,49 @@ def test_default_args(tmpdir):

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5


@pytest.mark.parametrize(
    'cli_args',
    [
        ['--accumulate_grad_batches=22'],
        ['--print_nan_grads=1', '--weights_save_path=./'],
        []
    ]
)
def test_add_argparse_args_redefined(cli_args):
    """Redefines some default Trainer arguments via the cli and
    tests the Trainer initialization correctness.
    """
    base_parser = ArgumentParser(add_help=False)
    extended_parser = Trainer.add_argparse_args(parent_parser=base_parser)

    parsed = extended_parser.parse_args(cli_args)

    # Building a Trainer from the parsed namespace must succeed for every
    # combination of overridden defaults above.
    trainer = Trainer.from_argparse_args(args=parsed)
    assert isinstance(trainer, Trainer)


@pytest.mark.parametrize(
    'cli_args',
    [
        ['--callbacks=1', '--logger'],
        ['--foo', '--bar=1']
    ]
)
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
    """Asserts that an error is raised when non-default cli arguments are passed."""

    class _UnkArgError(Exception):
        pass

    def _fail_exit(*args):
        # Replaces `ArgumentParser.exit` so an unknown argument raises a
        # catchable exception instead of terminating the test process.
        raise _UnkArgError

    parser = Trainer.add_argparse_args(parent_parser=ArgumentParser(add_help=False))

    monkeypatch.setattr(parser, 'exit', _fail_exit, raising=True)

    with pytest.raises(_UnkArgError):
        parser.parse_args(cli_args)