From 5b1595f0fe64175725baf9e08bb4e43a8fabcd57 Mon Sep 17 00:00:00 2001 From: Jirka Date: Thu, 4 Jun 2020 23:34:40 +0200 Subject: [PATCH] wip --- pytorch_lightning/core/lightning.py | 19 ++++++++++++++++--- tests/models/test_hparams.py | 1 - 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 662f1440ba9591..ce201959eb9075 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1740,10 +1740,11 @@ def module_arguments(self) -> dict: custom object or dict in which the keys are the union of all argument names in the constructor and all parent constructors, excluding `self`, `*args` and `**kwargs`. """ - args = copy.deepcopy(self._module_parents_arguments) - if isinstance(args, dict): + if isinstance(self._module_self_arguments, dict): + args = copy.deepcopy(self._module_parents_arguments) args.update(self._module_self_arguments) - return args + return args + return copy.deepcopy(self._module_self_arguments) def save_hyperparameters(self, *args, **kwargs) -> None: """ @@ -1771,6 +1772,18 @@ def save_hyperparameters(self, *args, **kwargs) -> None: >>> model = AutomaticArgsModel(1, 'abc', 3.14) >>> OrderedDict(model.module_arguments) OrderedDict([('arg1', 1), ('arg2', 'abc'), ('arg3', 3.14)]) + + >>> from collections import OrderedDict + >>> class SingleArgModel(LightningModule): + ... def __init__(self, hparams): + ... super().__init__() + ... # manually assign single argument + ... self.save_hyperparameters(hparams) + ... def forward(self, *args, **kwargs): + ... ... + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) + >>> model.module_arguments + Namespace(p1=1, p2='abc', p3=3.14) """ if not args and not kwargs: self._auto_collect_arguments() diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 35a58e33ff0614..a58184f9075b84 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -38,7 +38,6 @@ def test2(self): A().test() -@pytest.mark.skipif(sys.version_info < (3, 8), reason='OmegaConf only for Python >= 3.8') def test_omegaconf(tmpdir): class OmegaConfModel(EvalModelTemplate): def __init__(self, ogc):