
Commit

Custom argparser extension with Trainer arguments (argument types added) (#1147)

* `add_argparse_args` method fixed (argument types added)

* CHANGELOG.md upd

* autopep8 fixes

* --gpus=0 removed from test (for ci tests)

* typo fixed

* reduce on plateau scheduler fixed

* Trainer cli related tests moved to test_trainer_cli.py

* refactored: get_init_arguments_and_types is a public classmethod of the Trainer now

* test_get_init_arguments_and_types added

* Apply suggestions from code review

* cosmetics

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* `Trainer.get_init_arguments_and_types` now returns arg types wrapped in tuples (not in sets)

* deprecated args are now ignored in argparser

* get_deprecated_arg_names small refactor

* get_deprecated_arg_names bug fixed

* Update pytorch_lightning/trainer/trainer.py

Co-Authored-By: Joe Davison <joe@huggingface.co>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Joe Davison <joe@huggingface.co>
Co-authored-by: William Falcon <waf2107@columbia.edu>
4 people committed Mar 24, 2020
1 parent f6dabc2 commit ced662f
Showing 7 changed files with 188 additions and 43 deletions.
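For readers who want the user-facing picture before diving into the diff: after this change, the flags generated by `Trainer.add_argparse_args` carry the types declared in the `Trainer.__init__` signature, so a value such as `--max_epochs=5` is parsed as an `int` rather than a `str`. A minimal usage sketch (the script and flag values are illustrative, not part of the commit; `add_argparse_args` and `from_argparse_args` are the Trainer classmethods exercised by the tests below):

from argparse import ArgumentParser

from pytorch_lightning import Trainer

# start from the application's own parser and let the Trainer extend it
parser = ArgumentParser(description='my training script', add_help=False)
parser = Trainer.add_argparse_args(parser)

# typed flags: '5' and '2' arrive as ints because the Trainer signature annotates them as int
args = parser.parse_args(['--max_epochs=5', '--check_val_every_n_epoch=2'])

trainer = Trainer.from_argparse_args(args)
assert trainer.max_epochs == 5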
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -29,7 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))

- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
- Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
- Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))

@@ -13,6 +13,7 @@ class UNet(nn.Module):
bilinear (bool) - Whether to use bilinear interpolation or transposed
convolutions for upsampling.
'''

def __init__(self, num_classes=19, bilinear=False):
super().__init__()
self.layer1 = DoubleConv(3, 64)
@@ -8,6 +8,7 @@ class DoubleConv(nn.Module):
Double Convolution and BN and ReLU
(3x3 conv -> BN -> ReLU) ** 2
'''

def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
@@ -27,6 +28,7 @@ class Down(nn.Module):
'''
Combination of MaxPool2d and DoubleConv in series
'''

def __init__(self, in_ch, out_ch):
super().__init__()
self.net = nn.Sequential(
@@ -44,6 +46,7 @@ class Up(nn.Module):
followed by concatenation of feature map from contracting path,
followed by double 3x3 convolution.
'''

def __init__(self, in_ch, out_ch, bilinear=False):
super().__init__()
self.upsample = None
2 changes: 2 additions & 0 deletions pl_examples/full_examples/semantic_segmentation/semseg.py
@@ -34,6 +34,7 @@ class KITTI(Dataset):
encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
(mask does not usually require transforms, but they can be implemented in a similar way).
'''

def __init__(
self,
root_path,
@@ -120,6 +121,7 @@ class SegModel(pl.LightningModule):
Adam optimizer is used along with Cosine Annealing learning rate scheduler.
'''

def __init__(self, hparams):
super(SegModel, self).__init__()
self.root_path = hparams.root
112 changes: 92 additions & 20 deletions pytorch_lightning/trainer/trainer.py
@@ -3,33 +3,30 @@
import sys
import warnings
from argparse import ArgumentParser
from typing import Union, Optional, List, Dict, Tuple, Iterable
from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
import distutils

import torch
from torch import optim
import torch.distributed as torch_distrib
import torch.multiprocessing as mp
from torch import optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
from pytorch_lightning.profiler.profiler import BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import (
TrainerDPMixin,
parse_gpu_ids,
determine_root_gpu_device
)
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -70,6 +67,11 @@ class Trainer(
TrainerCallbackHookMixin,
TrainerDeprecatedAPITillVer0_8,
):
DEPRECATED_IN_0_8 = (
'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
'add_row_log_interval', 'nb_sanity_val_steps'
)
DEPRECATED_IN_0_9 = ('use_amp',)

def __init__(
self,
@@ -466,21 +468,91 @@ def default_attributes(cls):

return args

@classmethod
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the Trainer signature and returns argument names, types and default values.
Returns:
List with tuples of 3 values:
(argument name, tuple of argument types, argument default value).
Examples:
>>> args = Trainer.get_init_arguments_and_types()
>>> import pprint
>>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
[('accumulate_grad_batches',
(<class 'int'>, typing.Dict[int, int], typing.List[list]),
1),
...
('callbacks', (<class 'pytorch_lightning.callbacks.base.Callback'>,), []),
('check_val_every_n_epoch', (<class 'int'>,), 1),
...
('max_epochs', (<class 'int'>,), 1000),
...
('precision', (<class 'int'>,), 32),
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
(<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
<class 'NoneType'>),
None),
...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
for arg in trainer_default_params:
arg_type = trainer_default_params[arg].annotation
arg_default = trainer_default_params[arg].default
try:
arg_types = tuple(arg_type.__args__)
except AttributeError:
arg_types = (arg_type,)

name_type_default.append((arg, arg_types, arg_default))

return name_type_default

@classmethod
def get_deprecated_arg_names(cls) -> List:
"""Returns a list with deprecated Trainer arguments."""
depr_arg_names = []
for name, val in cls.__dict__.items():
if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
depr_arg_names.extend(val)
return depr_arg_names

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
"""Extend existing argparse by default `Trainer` attributes."""
parser = ArgumentParser(parents=[parent_parser], add_help=False)
r"""Extends existing argparse by default `Trainer` attributes.
trainer_default_params = Trainer.default_attributes()
Args:
parent_parser:
The custom cli arguments parser, which will be extended by
the Trainer default arguments.
Only arguments of the allowed types (str, float, int, bool) will
extend the `parent_parser`.
"""
parser = ArgumentParser(parents=[parent_parser], add_help=False, )

depr_arg_names = cls.get_deprecated_arg_names()

allowed_types = (str, float, int, bool)
# TODO: get "help" from docstring :)
for arg in trainer_default_params:
parser.add_argument(
f'--{arg}',
default=trainer_default_params[arg],
dest=arg,
help='autogenerated by pl.Trainer'
)
for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
if arg not in depr_arg_names:
for allowed_type in allowed_types:
if allowed_type in arg_types:
if allowed_type is bool:
allowed_type = lambda x: bool(distutils.util.strtobool(x))
parser.add_argument(
f'--{arg}',
default=arg_default,
type=allowed_type,
dest=arg,
help='autogenerated by pl.Trainer'
)
break

return parser

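One detail in the hunk above is worth calling out: `bool` cannot be used directly as an argparse `type`, because `bool('False')` is `True`, which is why boolean arguments are routed through `distutils.util.strtobool`. A self-contained sketch of the same technique (the parser and the `--print_nan_grads` flag are illustrative here, borrowed from the boolean Trainer argument shown in the docstring example above):

import distutils.util
from argparse import ArgumentParser

parser = ArgumentParser()
# a naive `type=bool` would turn the string 'False' into True;
# strtobool maps 'y/yes/true/1' to 1 and 'n/no/false/0' to 0
parser.add_argument('--print_nan_grads',
                    type=lambda x: bool(distutils.util.strtobool(x)),
                    default=False)

args = parser.parse_args(['--print_nan_grads=False'])
assert args.print_nan_grads is False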
24 changes: 2 additions & 22 deletions tests/trainer/test_trainer.py
@@ -1,8 +1,7 @@
import glob
import math
import os
from argparse import ArgumentParser, Namespace
from unittest import mock
from argparse import Namespace

import pytest
import torch
@@ -251,6 +250,7 @@ def test_dp_output_reduce():

def test_model_checkpoint_options(tmpdir):
"""Test ModelCheckpoint options."""

def mock_save_function(filepath):
open(filepath, 'a').close()

@@ -624,23 +624,3 @@ def test_epoch_end(self, outputs):

model = LightningTestModel(hparams)
Trainer().test(model)


@mock.patch('argparse.ArgumentParser.parse_args',
return_value=Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
"""Tests default argument parser for Trainer"""
tutils.reset_seed()

# logger file to get meta
logger = tutils.get_test_tube_logger(tmpdir, False)

parser = ArgumentParser(add_help=False)
args = parser.parse_args()
args.logger = logger

args.max_epochs = 5
trainer = Trainer.from_argparse_args(args)

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5
85 changes: 85 additions & 0 deletions tests/trainer/test_trainer_cli.py
@@ -0,0 +1,85 @@
import inspect
from argparse import ArgumentParser, Namespace
from unittest import mock

import pytest

import tests.models.utils as tutils
from pytorch_lightning import Trainer


@mock.patch('argparse.ArgumentParser.parse_args',
return_value=Namespace(**Trainer.default_attributes()))
def test_default_args(tmpdir):
"""Tests default argument parser for Trainer"""
tutils.reset_seed()

# logger file to get meta
logger = tutils.get_test_tube_logger(tmpdir, False)

parser = ArgumentParser(add_help=False)
args = parser.parse_args()
args.logger = logger

args.max_epochs = 5
trainer = Trainer.from_argparse_args(args)

assert isinstance(trainer, Trainer)
assert trainer.max_epochs == 5


@pytest.mark.parametrize('cli_args', [
['--accumulate_grad_batches=22'],
['--print_nan_grads=1', '--weights_save_path=./'],
[]
])
def test_add_argparse_args_redefined(cli_args):
"""Redefines some default Trainer arguments via the cli and
tests the Trainer initialization correctness.
"""
parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)

args = parser.parse_args(cli_args)

# Check few deprecated args are not in namespace:
for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'):
assert depr_name not in args

trainer = Trainer.from_argparse_args(args=args)
assert isinstance(trainer, Trainer)


def test_get_init_arguments_and_types():
"""Asserts a correctness of the `get_init_arguments_and_types` Trainer classmethod."""
args = Trainer.get_init_arguments_and_types()
parameters = inspect.signature(Trainer).parameters
assert len(parameters) == len(args)
for arg in args:
assert parameters[arg[0]].default == arg[2]

kwargs = {arg[0]: arg[2] for arg in args}
trainer = Trainer(**kwargs)
assert isinstance(trainer, Trainer)


@pytest.mark.parametrize('cli_args', [
['--callbacks=1', '--logger'],
['--foo', '--bar=1']
])
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
"""Asserts thar an error raised in case of passing not default cli arguments."""

class _UnkArgError(Exception):
pass

def _raise():
raise _UnkArgError

parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parent_parser=parser)

monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True)

with pytest.raises(_UnkArgError):
parser.parse_args(cli_args)
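
As background for `test_get_init_arguments_and_types`: the introspection it exercises relies on a standard `typing` detail, namely that `Union`/`Optional` annotations expose their member types through `__args__`, while plain classes do not. A small stand-alone sketch of the same pattern, using a hypothetical function instead of the Trainer:

import inspect
from typing import List, Optional, Union


def fit(batches: Union[int, List[int]] = 1, gpus: Optional[int] = None, fast: bool = False):
    pass


name_type_default = []
for name, param in inspect.signature(fit).parameters.items():
    try:
        # Union/Optional annotations carry their member types in __args__
        arg_types = tuple(param.annotation.__args__)
    except AttributeError:
        # plain annotations (int, bool, ...) have no __args__
        arg_types = (param.annotation,)
    name_type_default.append((name, arg_types, param.default))

# roughly: [('batches', (int, List[int]), 1),
#           ('gpus', (int, NoneType), None),
#           ('fast', (bool,), False)]
print(name_type_default)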
