Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix local variables being collected into module_arguments dict #2048

Merged
merged 13 commits into from
Jun 4, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Allow use of same `WandbLogger` instance for multiple training loops ([#2055](https://github.com/PyTorchLightning/pytorch-lightning/pull/2055))

- Fixed an issue where local variables were being collected into module_arguments ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

- Fixed an issue with `auto_collect_arguments` collecting local variables that are not constructor arguments and not working for signatures that have the instance not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048))

## [0.7.6] - 2020-05-16

### Added
Expand Down
73 changes: 48 additions & 25 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,22 +1698,34 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]:
" and this method will be removed in v1.0.0", DeprecationWarning)
return self.get_progress_bar_dict()

def auto_collect_arguments(self):
"""Collect all arguments module arguments."""
def auto_collect_arguments(self) -> None:
"""
Collect all module arguments in the current constructor and all child constructors.
The child constructors are all the ``__init__`` methods that reach the current class through
(chained) ``super().__init__()`` calls.
"""
frame = inspect.currentframe()

frame_args = _collect_init_args(frame.f_back, [])
child = _get_latest_child(frame)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't there a reason with inheritance for this?

Copy link
Member Author

@awaelchli awaelchli Jun 2, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested it and I found that child is always the specific type, i.e. "self", even if self.auto_collect() is called in a super class.
So in a concrete example:

class A(object):
    def __init__(self, *args, **kwargs):
        print(self)


class B(A):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        print(self)

B()  # prints twice <__main__.B object at 0x0000021540C94288>

so really it does not matter where we collect "self", it will always be the child (leaf node) in the inheritance tree.
The previous inheritance tests still pass.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `child = _get_latest_child(frame)` call does not break anything; it's just redundant as far as I can tell.

self_arguments = frame_args[-1]

# set module_arguments in child
child._module_self_arguments = frame_args[-1]
child._module_parents_arguments = {}
self._module_self_arguments = self_arguments
self._module_parents_arguments = {}

# add all arguments from parents
for args in frame_args[:-1]:
child._module_parents_arguments.update(args)
self._module_parents_arguments.update(args)

@property
def module_arguments(self) -> dict:
"""Aggregate this module and all parents arguments."""
"""
Aggregate of arguments passed to the constructor of this module and all parents.

Return:
a dict in which the keys are the union of all argument names in the constructor and all
parent constructors, excluding `self`, `*args` and `**kwargs`.
"""
try:
args = dict(self._module_parents_arguments)
args.update(self._module_self_arguments)
Expand All @@ -1724,26 +1736,37 @@ def module_arguments(self) -> dict:


def _collect_init_args(frame, path_args: list) -> list:
"""Recursive search for all children."""
if '__class__' in frame.f_locals:
local_args = dict(frame.f_locals)
local_args.update(local_args.get('kwargs', {}))
local_args = {k: v for k, v in local_args.items()
if k not in ('args', 'kwargs', 'self', '__class__', 'frame', 'frame_args')}
# if 'hparams' in local_args:
# # back compatible hparams as single argument
# hparams = local_args.get('hparams')
# local_args.update(vars(hparams) if isinstance(hparams, Namespace) else hparams)
"""
Recursively collects the arguments passed to the child constructors in the inheritance tree.

Args:
frame: the current stack frame
path_args: a list of dictionaries containing the constructor args in all parent classes

Return:
A list of dictionaries where each dictionary contains the arguments passed to the
constructor at that level. The last entry corresponds to the constructor call of the
most specific class in the hierarchy.
"""
_, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' in local_vars:
cls = local_vars['__class__']
spec = inspect.getfullargspec(cls.__init__)
init_parameters = inspect.signature(cls.__init__).parameters
self_identifier = spec.args[0] # "self" unless user renames it (always first arg)
varargs_identifier = spec.varargs # by convention this is named "*args"
kwargs_identifier = spec.varkw # by convention this is named "**kwargs"
Borda marked this conversation as resolved.
Show resolved Hide resolved
exclude_argnames = (
varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why also ignore *args and **kwargs?

Shouldn't we collect them as well and use them for reinstantiation/loading? Otherwise the module might be instantiated with different params which could cause incompatibilities with module and loaded checkpoint.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because if we have this class

class B(object):

    def __init__(self, *args, **kwargs):
        super().__init__()
        something = kwargs.get('something')
        self.auto_collect_arguments()

and we call it like this
B(something=1, other=2)
then we want the module arguments to be
dict(something=1, other=2)
and not
dict(args=[], kwargs={something=1, other=2})

It was like this before; this has not changed. The code that I added there just makes it so that the names of *args and **kwargs are not hardcoded, but can be whatever the user wants.
The Python inspection magic makes it possible to determine which is which.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That explains why we ignore them. As for the second part of your question:
The args inside **kwargs are still collected, see this line
local_args.update(local_args.get('kwargs', {}))

)

# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_identifier, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}

# recursive update
path_args.append(local_args)
return _collect_init_args(frame.f_back, path_args)
else:
return path_args


def _get_latest_child(frame, child: object = None) -> object:
"""Recursive search for lowest child."""
if 'self' in frame.f_locals:
return _get_latest_child(frame.f_back, frame.f_locals['self'])
else:
return child
63 changes: 52 additions & 11 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,8 @@ def test2(self):
A().test()


@pytest.mark.skipif(sys.version_info < (3, 8), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):

# ogc only for 3.8
major = sys.version_info[0]
minor = sys.version_info[1]
if major < 3 and minor < 8:
return

conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
model = OmegaConfModel(conf)

Expand All @@ -73,6 +67,17 @@ def __init__(self, *args, subclass_arg=1200, **kwargs):
self.auto_collect_arguments()


class UnconventionalArgsEvalModel(EvalModelTemplate):
    """ A model that has unconventional names for "self", "*args" and "**kwargs". """

    def __init__(obj, *more_args, other_arg=300, **more_kwargs):
        # intentionally named obj
        # the var-positional and var-keyword parameters are likewise not called
        # ``args`` / ``kwargs``; they are forwarded unchanged to the parent constructor
        super().__init__(*more_args, **more_kwargs)
        # keep the value that was actually passed in on the instance ...
        obj.other_arg = other_arg
        # ... then rebind the local name before collection runs, so the local and the
        # stored attribute differ at the time ``auto_collect_arguments`` is called
        other_arg = 321
        obj.auto_collect_arguments()


class SubSubClassEvalModel(SubClassEvalModel):
pass

Expand All @@ -85,10 +90,13 @@ def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
self.auto_collect_arguments()


@pytest.mark.parametrize("cls", [EvalModelTemplate,
SubClassEvalModel,
SubSubClassEvalModel,
AggSubClassEvalModel])
@pytest.mark.parametrize("cls", [
EvalModelTemplate,
SubClassEvalModel,
SubSubClassEvalModel,
AggSubClassEvalModel,
UnconventionalArgsEvalModel,
])
def test_collect_init_arguments(tmpdir, cls):
""" Test that the model automatically saves the arguments passed into the constructor """
extra_args = dict(my_loss=torch.nn.CosineEmbeddingLoss()) if cls is AggSubClassEvalModel else {}
Expand Down Expand Up @@ -125,3 +133,36 @@ def test_collect_init_arguments(tmpdir, cls):
# verify that we can overwrite whatever we want
model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
assert model.batch_size == 99


class LocalVariableModel1(EvalModelTemplate):
    """ This model has the super().__init__() call at the end. """

    def __init__(self, arg1, arg2, *args, **kwargs):
        self.argument1 = arg1  # arg2 intentionally not set
        # rebind the parameter before the parent constructor runs; the accompanying
        # test expects ``module_arguments['arg1']`` to be this overwritten value
        arg1 = 'overwritten'
        # a plain local that is not a constructor parameter; the accompanying test
        # asserts it does not appear in ``module_arguments``
        local_var = 1234
        # NOTE(review): this class never calls ``auto_collect_arguments`` itself, so
        # collection presumably happens inside the parent constructor - confirm
        super().__init__(*args, **kwargs)  # this is intentionally here at the end


class LocalVariableModel2(EvalModelTemplate):
    """ This model has the auto_collect_arguments() call at the end. """

    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.argument1 = arg1  # arg2 intentionally not set
        # rebind the parameter before collection; the accompanying test expects
        # ``module_arguments['arg1']`` to be this overwritten value
        arg1 = 'overwritten'
        # a plain local that is not a constructor parameter; the accompanying test
        # asserts it does not appear in ``module_arguments``
        local_var = 1234
        self.auto_collect_arguments()  # this is intentionally here at the end


@pytest.mark.parametrize("cls", [
    LocalVariableModel1,
    LocalVariableModel2,
])
def test_collect_init_arguments_with_local_vars(cls):
    """ Checks that constructor parameters are collected while plain local variables are left out. """
    collected = cls(arg1=1, arg2=2).module_arguments

    # a local defined inside ``__init__`` is not a constructor argument
    assert 'local_var' not in collected
    # the parameter was rebound inside ``__init__`` before collection
    assert collected['arg1'] == 'overwritten'
    # an untouched parameter keeps the value passed by the caller
    assert collected['arg2'] == 2