From 237b83dada9d46c29a09e5eb580fb07a0f316b45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 4 Jun 2020 14:35:50 +0200 Subject: [PATCH] Fix local variables being collected into module_arguments dict (#2048) * do not include local vars in auto collection * add test * add test for model with "self" renamed to "obj" * skip decorator * changelog * changelog * update docs * remove obsolete child collection * generalize **args, **kwargs names * docs * also update varargs passed in * Revert "also update varargs passed in" This reverts commit 3d7a30dbee07a513ee13e1cc3e08ca5ccdb85734. * update test --- CHANGELOG.md | 4 ++ pytorch_lightning/core/lightning.py | 73 +++++++++++++++++++---------- tests/models/test_hparams.py | 63 ++++++++++++++++++++----- 3 files changed, 104 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6fda50aa5f85a7..3f8487cc742acf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,6 +76,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Allow use of same `WandbLogger` instance for multiple training loops ([#2055](https://github.com/PyTorchLightning/pytorch-lightning/pull/2055)) +- Fixed an issue where local variables were being collected into `module_arguments` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048)) + +- Fixed `auto_collect_arguments` collecting local variables that are not constructor arguments and failing for signatures in which the instance argument is not named `self` ([#2048](https://github.com/PyTorchLightning/pytorch-lightning/pull/2048)) + ## [0.7.6] - 2020-05-16 ### Added diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a419c80c5e3f2d..720f0026e095c2 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1698,22 +1698,34 @@ def get_tqdm_dict(self) -> Dict[str, Union[int, str]]: " and this method will be removed in v1.0.0", DeprecationWarning) return self.get_progress_bar_dict() - def auto_collect_arguments(self): - """Collect all arguments module arguments.""" + def auto_collect_arguments(self) -> None: + """ + Collect all module arguments in the current constructor and all child constructors. + The child constructors are all the ``__init__`` methods that reach the current class through + (chained) ``super().__init__()`` calls. 
+ """ frame = inspect.currentframe() frame_args = _collect_init_args(frame.f_back, []) - child = _get_latest_child(frame) + self_arguments = frame_args[-1] # set module_arguments in child - child._module_self_arguments = frame_args[-1] - child._module_parents_arguments = {} + self._module_self_arguments = self_arguments + self._module_parents_arguments = {} + + # add all arguments from parents for args in frame_args[:-1]: - child._module_parents_arguments.update(args) + self._module_parents_arguments.update(args) @property def module_arguments(self) -> dict: - """Aggregate this module and all parents arguments.""" + """ + Aggregate of arguments passed to the constructor of this module and all parents. + + Return: + a dict in which the keys are the union of all argument names in the constructor and all + parent constructors, excluding `self`, `*args` and `**kwargs`. + """ try: args = dict(self._module_parents_arguments) args.update(self._module_self_arguments) @@ -1724,26 +1736,37 @@ def module_arguments(self) -> dict: def _collect_init_args(frame, path_args: list) -> list: - """Recursive search for all children.""" - if '__class__' in frame.f_locals: - local_args = dict(frame.f_locals) - local_args.update(local_args.get('kwargs', {})) - local_args = {k: v for k, v in local_args.items() - if k not in ('args', 'kwargs', 'self', '__class__', 'frame', 'frame_args')} - # if 'hparams' in local_args: - # # back compatible hparams as single argument - # hparams = local_args.get('hparams') - # local_args.update(vars(hparams) if isinstance(hparams, Namespace) else hparams) + """ + Recursively collects the arguments passed to the child constructors in the inheritance tree. + + Args: + frame: the current stack frame + path_args: a list of dictionaries containing the constructor args in all parent classes + + Return: + A list of dictionaries where each dictionary contains the arguments passed to the + constructor at that level. 
The last entry corresponds to the constructor call of the + most specific class in the hierarchy. + """ + _, _, _, local_vars = inspect.getargvalues(frame) + if '__class__' in local_vars: + cls = local_vars['__class__'] + spec = inspect.getfullargspec(cls.__init__) + init_parameters = inspect.signature(cls.__init__).parameters + self_identifier = spec.args[0] # "self" unless user renames it (always first arg) + varargs_identifier = spec.varargs # by convention this is named "*args" + kwargs_identifier = spec.varkw # by convention this is named "**kwargs" + exclude_argnames = ( + varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args' + ) + + # only collect variables that appear in the signature + local_args = {k: local_vars[k] for k in init_parameters.keys()} + local_args.update(local_args.get(kwargs_identifier, {})) + local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} + # recursive update path_args.append(local_args) return _collect_init_args(frame.f_back, path_args) else: return path_args - - -def _get_latest_child(frame, child: object = None) -> object: - """Recursive search for lowest child.""" - if 'self' in frame.f_locals: - return _get_latest_child(frame.f_back, frame.f_locals['self']) - else: - return child diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index fa87ceeeb91c17..20999431f1bdcd 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -44,14 +44,8 @@ def test2(self): A().test() +@pytest.mark.skipif(sys.version_info < (3, 8), reason='OmegaConf only for Python >= 3.8') def test_omegaconf(tmpdir): - - # ogc only for 3.8 - major = sys.version_info[0] - minor = sys.version_info[1] - if major < 3 and minor < 8: - return - conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]}) model = OmegaConfModel(conf) @@ -73,6 +67,17 @@ def __init__(self, *args, subclass_arg=1200, **kwargs): self.auto_collect_arguments() +class 
UnconventionalArgsEvalModel(EvalModelTemplate): + """ A model that has unconventional names for "self", "*args" and "**kwargs". """ + + def __init__(obj, *more_args, other_arg=300, **more_kwargs): + # intentionally named obj + super().__init__(*more_args, **more_kwargs) + obj.other_arg = other_arg + other_arg = 321 + obj.auto_collect_arguments() + + class SubSubClassEvalModel(SubClassEvalModel): pass @@ -85,10 +90,13 @@ def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs): self.auto_collect_arguments() -@pytest.mark.parametrize("cls", [EvalModelTemplate, - SubClassEvalModel, - SubSubClassEvalModel, - AggSubClassEvalModel]) +@pytest.mark.parametrize("cls", [ + EvalModelTemplate, + SubClassEvalModel, + SubSubClassEvalModel, + AggSubClassEvalModel, + UnconventionalArgsEvalModel, +]) def test_collect_init_arguments(tmpdir, cls): """ Test that the model automatically saves the arguments passed into the constructor """ extra_args = dict(my_loss=torch.nn.CosineEmbeddingLoss()) if cls is AggSubClassEvalModel else {} @@ -125,3 +133,36 @@ def test_collect_init_arguments(tmpdir, cls): # verify that we can overwrite whatever we want model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99) assert model.batch_size == 99 + + +class LocalVariableModel1(EvalModelTemplate): + """ This model has the super().__init__() call at the end. """ + + def __init__(self, arg1, arg2, *args, **kwargs): + self.argument1 = arg1 # arg2 intentionally not set + arg1 = 'overwritten' + local_var = 1234 + super().__init__(*args, **kwargs) # this is intentionally here at the end + + +class LocalVariableModel2(EvalModelTemplate): + """ This model has the auto_collect_arguments() call at the end. 
""" + + def __init__(self, arg1, arg2, *args, **kwargs): + super().__init__(*args, **kwargs) + self.argument1 = arg1 # arg2 intentionally not set + arg1 = 'overwritten' + local_var = 1234 + self.auto_collect_arguments() # this is intentionally here at the end + + +@pytest.mark.parametrize("cls", [ + LocalVariableModel1, + LocalVariableModel2, +]) +def test_collect_init_arguments_with_local_vars(cls): + """ Tests that only the arguments are collected and not local variables. """ + model = cls(arg1=1, arg2=2) + assert 'local_var' not in model.module_arguments + assert model.module_arguments['arg1'] == 'overwritten' + assert model.module_arguments['arg2'] == 2