update hparams, allow OmegaConf #2047

Merged Jun 8, 2020 (51 commits)

Changes shown from 10 commits

Commits (51)
f5e5e45  DictConf (Borda, Jun 1, 2020)
8f9b530  inits (Borda, Jun 1, 2020)
973502f  Apply suggestions from code review (Borda, Jun 3, 2020)
50e4e6a  wip (Borda, Jun 4, 2020)
a2bf308  wip (Borda, Jun 4, 2020)
89a500a  wip (Borda, Jun 4, 2020)
ae288e0  wip (Borda, Jun 4, 2020)
0662707  wip (Borda, Jun 4, 2020)
c7e0b74  wip (Borda, Jun 4, 2020)
7f38ca0  wip (Borda, Jun 4, 2020)
771fd30  wip (Borda, Jun 4, 2020)
dc83ccf  wip (Borda, Jun 5, 2020)
8ebcaf7  wip (Borda, Jun 5, 2020)
3761e98  atrib (Borda, Jun 5, 2020)
f952aae  wip (Borda, Jun 5, 2020)
f9e2ffa  wip (Borda, Jun 5, 2020)
bbd471b  wip (Borda, Jun 6, 2020)
73ab2ee  added hparams test (williamFalcon, Jun 6, 2020)
5d4fd39  wip (Borda, Jun 6, 2020)
607b243  wip (Borda, Jun 6, 2020)
b21b6e6  wip (Borda, Jun 6, 2020)
16b29ed  wip (Borda, Jun 6, 2020)
ac1e6a8  wip (Borda, Jun 6, 2020)
c5fbc25  wip (Borda, Jun 6, 2020)
4cceca6  wip (Borda, Jun 6, 2020)
db4204e  wip (Borda, Jun 6, 2020)
cfe890d  wip (Borda, Jun 6, 2020)
3095b11  wip (Borda, Jun 6, 2020)
f585341  wip (Borda, Jun 6, 2020)
9626ae1  wip (Borda, Jun 7, 2020)
8a60394  wip (Borda, Jun 7, 2020)
74a75be  wip (Borda, Jun 7, 2020)
a356cb0  wip (Borda, Jun 7, 2020)
364f13d  wip (Borda, Jun 7, 2020)
3c55c02  wip (Borda, Jun 7, 2020)
a5924b9  wip (Borda, Jun 7, 2020)
c06555b  wip (Borda, Jun 7, 2020)
688a727  wip (Borda, Jun 7, 2020)
72a93b7  wip (Borda, Jun 7, 2020)
99685ea  Update test_hparams.py (williamFalcon, Jun 7, 2020)
446d5e2  added hparams test (williamFalcon, Jun 7, 2020)
6f996bb  added hparams test (williamFalcon, Jun 7, 2020)
8546b8c  pep8 (williamFalcon, Jun 7, 2020)
974bb53  pep8 (williamFalcon, Jun 7, 2020)
4e59c9a  pep8 (williamFalcon, Jun 7, 2020)
367d182  docs (williamFalcon, Jun 7, 2020)
0a7a164  wip (Borda, Jun 7, 2020)
bc4da6e  wip (Borda, Jun 7, 2020)
d97b3c8  clean (Borda, Jun 7, 2020)
38e9686  review @omry (Borda, Jun 7, 2020)
e532796  Update docs/source/hyperparameters.rst (Borda, Jun 7, 2020)
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -78,7 +78,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an issue where local variables were being collected into module_arguments ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

- Fixed an issue with `auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))
- Fixed an issue with `_auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

## [0.7.6] - 2020-05-16

95 changes: 81 additions & 14 deletions pytorch_lightning/core/lightning.py
@@ -1,4 +1,5 @@
import collections
import copy
import inspect
import os
import warnings
@@ -31,6 +32,15 @@
else:
XLA_AVAILABLE = True

ALLOWED_CONFIG_TYPES = (dict, Namespace)
try:
from omegaconf import DictConfig, OmegaConf
except ImportError:
pass
else:
ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (DictConfig, OmegaConf)


CHECKPOINT_KEY_MODULE_ARGS = 'module_arguments'


@@ -77,6 +87,9 @@ def __init__(self, *args, **kwargs):
#: device reference
self._device = torch.device('cpu')

self._module_self_arguments = {}
self._module_parents_arguments = {}

@property
def on_gpu(self):
"""
@@ -1714,13 +1727,14 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
" and this method will be removed in v1.0.0", DeprecationWarning)
return self.get_progress_bar_dict()

def auto_collect_arguments(self) -> None:
def _auto_collect_arguments(self, frame=None) -> None:
"""
Collect all module arguments in the current constructor and all child constructors.
The child constructors are all the ``__init__`` methods that reach the current class through
(chained) ``super().__init__()`` calls.
"""
frame = inspect.currentframe()
if not frame:
frame = inspect.currentframe()

frame_args = _collect_init_args(frame.f_back, [])
self_arguments = frame_args[-1]
@@ -1739,19 +1753,70 @@ def module_arguments(self) -> dict:
Aggregate of arguments passed to the constructor of this module and all parents.

Return:
a dict in which the keys are the union of all argument names in the constructor and all
parent constructors, excluding `self`, `*args` and `**kwargs`.
custom object or dict in which the keys are the union of all argument names in the constructor
and all parent constructors, excluding `self`, `*args` and `**kwargs`.
"""
try:
args = dict(self._module_parents_arguments)
if isinstance(self._module_self_arguments, dict):
args = copy.deepcopy(self._module_parents_arguments)
args.update(self._module_self_arguments)
return args
except AttributeError as e:
rank_zero_warn('you called `module.module_arguments` without calling self.auto_collect_arguments()')
return {}
return copy.deepcopy(self._module_self_arguments)

def save_hyperparameters(self, *args, **kwargs) -> None:
"""

>>> from collections import OrderedDict
>>> class ManuallyArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
...         # manually assign arguments
... self.save_hyperparameters(arg_name1=arg1, arg_name2=arg2, arg_name3=arg3)
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> OrderedDict(model.module_arguments)
OrderedDict([('arg_name1', 1), ('arg_name2', 'abc'), ('arg_name3', 3.14)])

>>> from collections import OrderedDict
>>> class AutomaticArgsModel(LightningModule):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # equivalent automatic
... self.save_hyperparameters()
... def forward(self, *args, **kwargs):
... ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> OrderedDict(model.module_arguments)
OrderedDict([('arg1', 1), ('arg2', 'abc'), ('arg3', 3.14)])

>>> class SingleArgModel(LightningModule):
... def __init__(self, hparams):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(hparams)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.module_arguments
Namespace(p1=1, p2='abc', p3=3.14)
"""
if not args and not kwargs:
self._auto_collect_arguments()
return

if args:
if len(args) > 1:
raise ValueError('Only one argument can be passed.')
arg = args[0]
if not isinstance(arg, ALLOWED_CONFIG_TYPES):
raise ValueError(f'Unsupported argument type `{type(arg)}`.')
self._module_self_arguments = copy.copy(arg)

elif kwargs:
self._module_self_arguments = copy.deepcopy(kwargs)


def _collect_init_args(frame, path_args: list) -> list:
def _collect_init_args(frame, path_args: list, inside: bool = False) -> list:
"""
Recursively collects the arguments passed to the child constructors in the inheritance tree.

@@ -1769,9 +1834,9 @@ def _collect_init_args(frame, path_args: list) -> list:
cls = local_vars['__class__']
spec = inspect.getfullargspec(cls.__init__)
init_parameters = inspect.signature(cls.__init__).parameters
self_identifier = spec.args[0] # "self" unless user renames it (always first arg)
varargs_identifier = spec.varargs # by convention this is named "*args"
kwargs_identifier = spec.varkw # by convention this is named "**kwargs"
self_identifier = spec.args[0] # "self" unless user renames it (always first arg)
varargs_identifier = spec.varargs # by convention this is named "*args"
kwargs_identifier = spec.varkw # by convention this is named "**kwargs"
exclude_argnames = (
varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args'
)
@@ -1783,6 +1848,8 @@

# recursive update
path_args.append(local_args)
return _collect_init_args(frame.f_back, path_args)
return _collect_init_args(frame.f_back, path_args, inside=True)
elif not inside:
return _collect_init_args(frame.f_back, path_args, inside)
else:
return path_args
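The automatic collection above relies on Python frame introspection: `_collect_init_args` walks `frame.f_back` from the `save_hyperparameters()` call site up through each `__init__` in the inheritance chain and reads the constructor arguments out of `f_locals`. A minimal, self-contained sketch of that idea follows; it is simplified to a single level, and the names `capture_init_args` and `captured` are illustrative, not taken from the PR:

    import inspect

    class Base:
        def capture_init_args(self):
            # one frame up from this call is the caller's __init__ frame;
            # its local variables include the constructor arguments
            init_frame = inspect.currentframe().f_back
            local_vars = init_frame.f_locals
            cls = local_vars.get('__class__', type(self))
            init_parameters = inspect.signature(cls.__init__).parameters
            # keep only locals that are real constructor parameters
            self.captured = {
                name: local_vars[name]
                for name in init_parameters
                if name in local_vars and name not in ('self', 'args', 'kwargs')
            }

    class Child(Base):
        def __init__(self, lr, batch_size=32):
            super().__init__()
            self.capture_init_args()

    model = Child(lr=0.01)
    print(model.captured)  # {'lr': 0.01, 'batch_size': 32}

The real implementation differs mainly in that it keeps walking `f_back` so that arguments from every parent `__init__` reached through `super().__init__()` are merged as well.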
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/training_io.py
@@ -124,6 +124,12 @@
list, tuple, set, dict,
Namespace, # for back compatibility
)
try:
from omegaconf import DictConfig, OmegaConf
except ImportError:
pass
else:
PRIMITIVE_TYPES = PRIMITIVE_TYPES + (DictConfig, OmegaConf)


class TrainerIOMixin(ABC):
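Both files above use the same guard so that OmegaConf stays an optional dependency: the allow-list of hyperparameter types is widened only when the import succeeds. As a standalone sketch of the pattern (the names `SUPPORTED_HPARAM_TYPES` and `is_checkpointable_hparam` are illustrative, not from the PR):

    from argparse import Namespace

    SUPPORTED_HPARAM_TYPES = (bool, int, float, str, list, tuple, set, dict, Namespace)
    try:
        from omegaconf import DictConfig
    except ImportError:
        DictConfig = None  # omegaconf not installed; the allow-list stays unchanged
    else:
        SUPPORTED_HPARAM_TYPES = SUPPORTED_HPARAM_TYPES + (DictConfig,)

    def is_checkpointable_hparam(value) -> bool:
        # True if the value has a type that would be written into the checkpoint
        return isinstance(value, SUPPORTED_HPARAM_TYPES)

The benefit of this shape is that code elsewhere only ever checks against one tuple, regardless of whether omegaconf is installed.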
2 changes: 1 addition & 1 deletion tests/base/model_template.py
@@ -51,7 +51,7 @@ def __init__(self,
**kwargs) -> object:
# init superclass
super().__init__()
self.auto_collect_arguments()
self._auto_collect_arguments()
Contributor review comment: use `self.save_hyperparameters()`


self.drop_prob = drop_prob
self.batch_size = batch_size
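The review comment above points at the intended public entry point: rather than calling the private `_auto_collect_arguments()` helper, a constructor can call `save_hyperparameters()` directly. Applied to a template like this one, the suggestion would amount to roughly the following sketch (illustrative only, not the code merged in this commit range):

    from pytorch_lightning import LightningModule

    class EvalModelTemplate(LightningModule):
        def __init__(self, drop_prob=0.2, batch_size=32, **kwargs):
            super().__init__()
            # public API instead of the private helper; collects drop_prob, batch_size, etc.
            self.save_hyperparameters()
            self.drop_prob = drop_prob
            self.batch_size = batch_size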
127 changes: 101 additions & 26 deletions tests/models/test_hparams.py
@@ -1,24 +1,17 @@
import os
import sys
from argparse import Namespace

import pytest
import torch
from omegaconf import OmegaConf
from packaging import version
from omegaconf import OmegaConf, DictConfig

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.lightning import CHECKPOINT_KEY_MODULE_ARGS
from tests.base import EvalModelTemplate


class OmegaConfModel(EvalModelTemplate):
def __init__(self, ogc):
super().__init__()
self.ogc = ogc
self.size = ogc.list[0]


def test_class_nesting(tmpdir):
def test_class_nesting():

class MyModule(LightningModule):
def forward(self):
@@ -47,6 +40,12 @@ def test2(self):

@pytest.mark.xfail(sys.version_info >= (3, 8), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):
class OmegaConfModel(EvalModelTemplate):
def __init__(self, ogc):
super().__init__()
self.ogc = ogc
self.size = ogc.list[0]

conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
model = OmegaConfModel(conf)

@@ -64,8 +63,11 @@ class SubClassEvalModel(EvalModelTemplate):

def __init__(self, *args, subclass_arg=1200, **kwargs):
super().__init__(*args, **kwargs)
self.subclass_arg = subclass_arg
self.auto_collect_arguments()
self.save_hyperparameters()


class SubSubClassEvalModel(SubClassEvalModel):
pass


class UnconventionalArgsEvalModel(EvalModelTemplate):
@@ -74,21 +76,21 @@ class UnconventionalArgsEvalModel(EvalModelTemplate):
def __init__(obj, *more_args, other_arg=300, **more_kwargs):
# intentionally named obj
super().__init__(*more_args, **more_kwargs)
obj.other_arg = other_arg
obj.save_hyperparameters()
other_arg = 321
obj.auto_collect_arguments()


class SubSubClassEvalModel(SubClassEvalModel):
pass


class AggSubClassEvalModel(SubClassEvalModel):

def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
super().__init__(*args, **kwargs)
self.my_loss = my_loss
self.auto_collect_arguments()
self.save_hyperparameters()


class DictConfSubClassEvalModel(SubClassEvalModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param='something')), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()


@pytest.mark.parametrize("cls", [
@@ -97,10 +99,15 @@ def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
SubSubClassEvalModel,
AggSubClassEvalModel,
UnconventionalArgsEvalModel,
DictConfSubClassEvalModel,
])
def test_collect_init_arguments(tmpdir, cls):
""" Test that the model automatically saves the arguments passed into the constructor """
extra_args = dict(my_loss=torch.nn.CosineEmbeddingLoss()) if cls is AggSubClassEvalModel else {}
extra_args = {}
if cls is AggSubClassEvalModel:
extra_args.update(my_loss=torch.nn.CosineEmbeddingLoss())
elif cls is DictConfSubClassEvalModel:
extra_args.update(dict_conf=OmegaConf.create(dict(my_param='anything')))

model = cls(**extra_args)
assert model.batch_size == 32
@@ -116,9 +123,7 @@ def test_collect_init_arguments(tmpdir, cls):
# verify that the checkpoint saved the correct values
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
trainer.fit(model)
raw_checkpoint_path = os.listdir(trainer.checkpoint_callback.dirpath)
raw_checkpoint_path = [x for x in raw_checkpoint_path if '.ckpt' in x][0]
raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
raw_checkpoint_path = _raw_checkpoint_path(trainer)

raw_checkpoint = torch.load(raw_checkpoint_path)
assert CHECKPOINT_KEY_MODULE_ARGS in raw_checkpoint
@@ -131,11 +136,22 @@ def test_collect_init_arguments(tmpdir, cls):
if isinstance(model, AggSubClassEvalModel):
assert isinstance(model.my_loss, torch.nn.CrossEntropyLoss)

if isinstance(model, DictConfSubClassEvalModel):
assert isinstance(model.dict_conf, DictConfig)
assert model.dict_conf == 'anything'

# verify that we can overwrite whatever we want
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
assert model.batch_size == 99


def _raw_checkpoint_path(trainer) -> str:
raw_checkpoint_paths = os.listdir(trainer.checkpoint_callback.dirpath)
raw_checkpoint_path = [x for x in raw_checkpoint_paths if '.ckpt' in x][0]
raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
return raw_checkpoint_path


class LocalVariableModel1(EvalModelTemplate):
""" This model has the super().__init__() call at the end. """

@@ -147,14 +163,14 @@ def __init__(self, arg1, arg2, *args, **kwargs):


class LocalVariableModel2(EvalModelTemplate):
""" This model has the auto_collect_arguments() call at the end. """
""" This model has the _auto_collect_arguments() call at the end. """

def __init__(self, arg1, arg2, *args, **kwargs):
super().__init__(*args, **kwargs)
self.argument1 = arg1 # arg2 intentionally not set
arg1 = 'overwritten'
local_var = 1234
self.auto_collect_arguments() # this is intentionally here at the end
self._auto_collect_arguments() # this is intentionally here at the end


@pytest.mark.parametrize("cls", [
@@ -167,3 +183,62 @@ def test_collect_init_arguments_with_local_vars(cls):
assert 'local_var' not in model.module_arguments
assert model.module_arguments['arg1'] == 'overwritten'
assert model.module_arguments['arg2'] == 2


class NamespaceArgModel(EvalModelTemplate):
def __init__(self, hparams: Namespace):
super().__init__()
self.save_hyperparameters(hparams)


class DictArgModel(EvalModelTemplate):
def __init__(self, some_dict: dict):
super().__init__()
self.save_hyperparameters(some_dict)


class OmegaConfArgModel(EvalModelTemplate):
def __init__(self, conf: OmegaConf):
super().__init__()
self.save_hyperparameters(conf)


@pytest.mark.parametrize("cls,config", [
(NamespaceArgModel, Namespace(my_arg=42)),
(DictArgModel, dict(my_arg=42)),
(OmegaConfArgModel, OmegaConf.create(dict(my_arg=42))),
])
def test_single_config_models(tmpdir, cls, config):
""" Test that the model automatically saves the arguments passed into the constructor """
model = cls(config)

# verify that the checkpoint saved the correct values
trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
trainer.fit(model)

# verify that model loads correctly
raw_checkpoint_path = _raw_checkpoint_path(trainer)
model = cls.load_from_checkpoint(raw_checkpoint_path)
assert model.module_arguments == config


class AnotherArgModel(EvalModelTemplate):
def __init__(self, arg1):
super().__init__()
self.save_hyperparameters(arg1)


class OtherArgsModel(EvalModelTemplate):
def __init__(self, arg1, arg2):
super().__init__()
self.save_hyperparameters(arg1, arg2)


@pytest.mark.parametrize("cls,config", [
(AnotherArgModel, dict(arg1=42)),
(OtherArgsModel, dict(arg1=42, arg2='abc')),
])
def test_single_config_models_fail(tmpdir, cls, config):
""" Test fail on passing unsupported config type. """
with pytest.raises(ValueError):
_ = cls(**config)
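Taken together, the tests above exercise the new flow end to end: pass a Namespace, dict, or OmegaConf object to `save_hyperparameters()`, train, and read the same object back from the checkpoint. A rough usage sketch against the API as it stands at this point in the PR (`module_arguments` is the accessor name here; `LitModel` and the config keys are illustrative):

    from omegaconf import OmegaConf
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def __init__(self, conf):
            super().__init__()
            # store the whole DictConfig so it is written into the checkpoint
            self.save_hyperparameters(conf)

        def forward(self, *args, **kwargs):
            ...

    conf = OmegaConf.create({'lr': 1e-3, 'hidden_dim': 128})
    model = LitModel(conf)
    assert model.module_arguments.lr == 1e-3  # behaves like the original DictConfig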