-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Custom argparser extension with Trainer arguments (argument types add…
…ed) (#1147) * `add_argparse_args` method fixed (argument types added) * CHANGELOG.md upd * autopep8 fixes * --gpus=0 removed from test (for ci tests) * typo fixed * reduce on plateau scheduler fixed * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * test_get_init_arguments_and_types added * autopep8 fixes * Apply suggestions from code review * cosmetics * cosmetics * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * `Trainer.get_init_arguments_and_types` now returns arg types wrapped in tuples (not in sets) * deprecated args are now ignored in argparser * get_deprecated_arg_names small refactor * get_deprecated_arg_names bug fixed * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * Trainer cli related tests moved to test_trainer_cli.py * test_get_init_arguments_and_types added * autopep8 fixes * autopep8 fixes * Apply suggestions from code review * cosmetics * cosmetics * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * `Trainer.get_init_arguments_and_types` now returns arg types wrapped in tuples (not in sets) * deprecated args are now ignored in argparser * get_deprecated_arg_names small refactor * get_deprecated_arg_names bug fixed * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <joe@huggingface.co> * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <joe@huggingface.co> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Joe Davison <joe@huggingface.co> Co-authored-by: William Falcon <waf2107@columbia.edu>
- Loading branch information
1 parent
f6dabc2
commit ced662f
Showing
7 changed files
with
188 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import inspect | ||
from argparse import ArgumentParser, Namespace | ||
from unittest import mock | ||
|
||
import pytest | ||
|
||
import tests.models.utils as tutils | ||
from pytorch_lightning import Trainer | ||
|
||
|
||
@mock.patch('argparse.ArgumentParser.parse_args', | ||
return_value=Namespace(**Trainer.default_attributes())) | ||
def test_default_args(tmpdir): | ||
"""Tests default argument parser for Trainer""" | ||
tutils.reset_seed() | ||
|
||
# logger file to get meta | ||
logger = tutils.get_test_tube_logger(tmpdir, False) | ||
|
||
parser = ArgumentParser(add_help=False) | ||
args = parser.parse_args() | ||
args.logger = logger | ||
|
||
args.max_epochs = 5 | ||
trainer = Trainer.from_argparse_args(args) | ||
|
||
assert isinstance(trainer, Trainer) | ||
assert trainer.max_epochs == 5 | ||
|
||
|
||
@pytest.mark.parametrize('cli_args', [ | ||
['--accumulate_grad_batches=22'], | ||
['--print_nan_grads=1', '--weights_save_path=./'], | ||
[] | ||
]) | ||
def test_add_argparse_args_redefined(cli_args): | ||
"""Redefines some default Trainer arguments via the cli and | ||
tests the Trainer initialization correctness. | ||
""" | ||
parser = ArgumentParser(add_help=False) | ||
parser = Trainer.add_argparse_args(parent_parser=parser) | ||
|
||
args = parser.parse_args(cli_args) | ||
|
||
# Check few deprecated args are not in namespace: | ||
for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'): | ||
assert depr_name not in args | ||
|
||
trainer = Trainer.from_argparse_args(args=args) | ||
assert isinstance(trainer, Trainer) | ||
|
||
|
||
def test_get_init_arguments_and_types(): | ||
"""Asserts a correctness of the `get_init_arguments_and_types` Trainer classmethod.""" | ||
args = Trainer.get_init_arguments_and_types() | ||
parameters = inspect.signature(Trainer).parameters | ||
assert len(parameters) == len(args) | ||
for arg in args: | ||
assert parameters[arg[0]].default == arg[2] | ||
|
||
kwargs = {arg[0]: arg[2] for arg in args} | ||
trainer = Trainer(**kwargs) | ||
assert isinstance(trainer, Trainer) | ||
|
||
|
||
@pytest.mark.parametrize('cli_args', [ | ||
['--callbacks=1', '--logger'], | ||
['--foo', '--bar=1'] | ||
]) | ||
def test_add_argparse_args_redefined_error(cli_args, monkeypatch): | ||
"""Asserts thar an error raised in case of passing not default cli arguments.""" | ||
|
||
class _UnkArgError(Exception): | ||
pass | ||
|
||
def _raise(): | ||
raise _UnkArgError | ||
|
||
parser = ArgumentParser(add_help=False) | ||
parser = Trainer.add_argparse_args(parent_parser=parser) | ||
|
||
monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True) | ||
|
||
with pytest.raises(_UnkArgError): | ||
parser.parse_args(cli_args) |