-
Notifications
You must be signed in to change notification settings - Fork 3.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix local variables being collected into module_arguments dict #2048
Changes from all commits
d711672
26fcf55
51da5e1
ddeb191
66ca8c4
f9a3e7d
4d719dd
09fa282
790f192
d4747c1
8284dcc
e9492e1
79a6023
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1698,22 +1698,34 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: | |
" and this method will be removed in v1.0.0", DeprecationWarning) | ||
return self.get_progress_bar_dict() | ||
|
||
def auto_collect_arguments(self): | ||
"""Collect all arguments module arguments.""" | ||
def auto_collect_arguments(self) -> None: | ||
""" | ||
Collect all module arguments in the current constructor and all child constructors. | ||
The child constructors are all the ``__init__`` methods that reach the current class through | ||
(chained) ``super().__init__()`` calls. | ||
""" | ||
frame = inspect.currentframe() | ||
|
||
frame_args = _collect_init_args(frame.f_back, []) | ||
child = _get_latest_child(frame) | ||
self_arguments = frame_args[-1] | ||
|
||
# set module_arguments in child | ||
child._module_self_arguments = frame_args[-1] | ||
child._module_parents_arguments = {} | ||
self._module_self_arguments = self_arguments | ||
self._module_parents_arguments = {} | ||
|
||
# add all arguments from parents | ||
for args in frame_args[:-1]: | ||
child._module_parents_arguments.update(args) | ||
self._module_parents_arguments.update(args) | ||
|
||
@property | ||
def module_arguments(self) -> dict: | ||
"""Aggregate this module and all parents arguments.""" | ||
""" | ||
Aggregate of arguments passed to the constructor of this module and all parents. | ||
|
||
Return: | ||
a dict in which the keys are the union of all argument names in the constructor and all | ||
parent constructors, excluding `self`, `*args` and `**kwargs`. | ||
""" | ||
try: | ||
args = dict(self._module_parents_arguments) | ||
args.update(self._module_self_arguments) | ||
|
@@ -1724,26 +1736,37 @@ def module_arguments(self) -> dict: | |
|
||
|
||
def _collect_init_args(frame, path_args: list) -> list: | ||
"""Recursive search for all children.""" | ||
if '__class__' in frame.f_locals: | ||
local_args = dict(frame.f_locals) | ||
local_args.update(local_args.get('kwargs', {})) | ||
local_args = {k: v for k, v in local_args.items() | ||
if k not in ('args', 'kwargs', 'self', '__class__', 'frame', 'frame_args')} | ||
# if 'hparams' in local_args: | ||
# # back compatible hparams as single argument | ||
# hparams = local_args.get('hparams') | ||
# local_args.update(vars(hparams) if isinstance(hparams, Namespace) else hparams) | ||
""" | ||
Recursively collects the arguments passed to the child constructors in the inheritance tree. | ||
|
||
Args: | ||
frame: the current stack frame | ||
path_args: a list of dictionaries containing the constructor args in all parent classes | ||
|
||
Return: | ||
A list of dictionaries where each dictionary contains the arguments passed to the | ||
constructor at that level. The last entry corresponds to the constructor call of the | ||
most specific class in the hierarchy. | ||
""" | ||
_, _, _, local_vars = inspect.getargvalues(frame) | ||
if '__class__' in local_vars: | ||
cls = local_vars['__class__'] | ||
spec = inspect.getfullargspec(cls.__init__) | ||
init_parameters = inspect.signature(cls.__init__).parameters | ||
self_identifier = spec.args[0] # "self" unless user renames it (always first arg) | ||
varargs_identifier = spec.varargs # by convention this is named "*args" | ||
kwargs_identifier = spec.varkw # by convention this is named "**kwargs" | ||
Borda marked this conversation as resolved.
Show resolved
Hide resolved
|
||
exclude_argnames = ( | ||
varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why also ignore Shouldn't we collect them as well and use them for reinstantiation/loading? Otherwise the module might be instantiated with different params which could cause incompatibilities with module and loaded checkpoint. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because if we have this class class B(object):
def __init__(self, *args, **kwargs):
super().__init__()
something = kwargs.get('something')
self.auto_collect_arguments() and we call it like this It was like this before, this has not changed. The code that I added there just makes it so that the name *args and **kwargs is not hardcoded, but can be named whatever the user wants. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's the explanation why the ignore, and the second part of your question: |
||
) | ||
|
||
# only collect variables that appear in the signature | ||
local_args = {k: local_vars[k] for k in init_parameters.keys()} | ||
local_args.update(local_args.get(kwargs_identifier, {})) | ||
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} | ||
|
||
# recursive update | ||
path_args.append(local_args) | ||
return _collect_init_args(frame.f_back, path_args) | ||
else: | ||
return path_args | ||
|
||
|
||
def _get_latest_child(frame, child: object = None) -> object: | ||
"""Recursive search for lowest child.""" | ||
if 'self' in frame.f_locals: | ||
return _get_latest_child(frame.f_back, frame.f_locals['self']) | ||
else: | ||
return child |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wasn't there a reason with inheritance for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested it and I found that child is always the specific type, i.e. "self", even if self.auto_collect() is called in a super class.
So in a concrete example:
so really it does not matter where we collect "self", it will always be the child (leaf node) in the inheritance tree.
The previous inheritance tests still pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
child = _get_latest_child(frame)
does not break anything, it's just redudant as far as I can tell.