Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 4, 2020
1 parent 3e03422 commit 5b1595f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
19 changes: 16 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,10 +1740,11 @@ def module_arguments(self) -> dict:
custom object or dict in which the keys are the union of all argument names in the constructor
and all parent constructors, excluding `self`, `*args` and `**kwargs`.
"""
args = copy.deepcopy(self._module_parents_arguments)
if isinstance(args, dict):
if isinstance(self._module_self_arguments, dict):
args = copy.deepcopy(self._module_parents_arguments)
args.update(self._module_self_arguments)
return args
return args
return copy.deepcopy(self._module_self_arguments)

def save_hyperparameters(self, *args, **kwargs) -> None:
"""
Expand Down Expand Up @@ -1771,6 +1772,18 @@ def save_hyperparameters(self, *args, **kwargs) -> None:
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> OrderedDict(model.module_arguments)
OrderedDict([('arg1', 1), ('arg2', 'abc'), ('arg3', 3.14)])
>>> from collections import OrderedDict
>>> class SingleArgModel(LightningModule):
... def __init__(self, hparams):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(hparams)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.module_arguments
Namespace(p1=1, p2='abc', p3=3.14)
"""
if not args and not kwargs:
self._auto_collect_arguments()
Expand Down
1 change: 0 additions & 1 deletion tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def test2(self):
A().test()


@pytest.mark.skipif(sys.version_info < (3, 8), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):
class OmegaConfModel(EvalModelTemplate):
def __init__(self, ogc):
Expand Down

0 comments on commit 5b1595f

Please sign in to comment.