From bd8b25d1665f411b72795839569196af6956a6fa Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 21:43:12 -0400 Subject: [PATCH 1/6] remove frame inspection on self.hparams --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index df163aaa1097f..9d189d843de54 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1692,4 +1692,4 @@ def hparams(self) -> Union[AttributeDict, str]: @hparams.setter def hparams(self, hp: Union[dict, Namespace, Any]): - self.save_hyperparameters(hp, frame=inspect.currentframe().f_back.f_back) + self._set_hparams(hp) From b8410e9f1cfea97a1e9a2bb250cfdc38177e99d3 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 22:00:24 -0400 Subject: [PATCH 2/6] remove frame inspection on self.hparams --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9d189d843de54..0595f4062c27f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1692,4 +1692,4 @@ def hparams(self) -> Union[AttributeDict, str]: @hparams.setter def hparams(self, hp: Union[dict, Namespace, Any]): - self._set_hparams(hp) + self._hparams = hp From 7a37ea5a9860c65081842687b3183b8c5144948a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 22:00:41 -0400 Subject: [PATCH 3/6] remove frame inspection on self.hparams --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 0595f4062c27f..9d189d843de54 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1692,4 +1692,4 @@ def hparams(self) -> Union[AttributeDict, str]: @hparams.setter def hparams(self, hp: Union[dict, Namespace, Any]): - self._hparams = hp + self._set_hparams(hp) From 910b377577cb8e64daf68b4da6005878a969a289 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 22:34:59 -0400 Subject: [PATCH 4/6] remove frame inspection on self.hparams --- pytorch_lightning/core/lightning.py | 14 ++++++++++++++ pytorch_lightning/core/saving.py | 4 ++++ tests/models/test_hparams.py | 2 +- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 9d189d843de54..308dfcb39a0a4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1,6 +1,7 @@ import collections import inspect import os +import re from abc import ABC, abstractmethod from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence @@ -1692,4 +1693,17 @@ def hparams(self) -> Union[AttributeDict, str]: @hparams.setter def hparams(self, hp: Union[dict, Namespace, Any]): + hparams_assignment_name = self.__get_hparams_assignment_variable() + self._hparams_name = hparams_assignment_name self._set_hparams(hp) + + def __get_hparams_assignment_variable(self): + class_code = inspect.getsource(self.__class__) + lines = class_code.split('\n') + for line in lines: + line = re.sub(r"\s+", "", line, flags=re.UNICODE) + if 'self.hparams=' in line: + return line.split('=')[1] + + return None + diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 453b163a02a35..e6c7b0e222274 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -176,14 +176,18 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs): # pass in the values we saved automatically if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: model_args = {} + # add some back compatibility, the actual one shall be last for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,): if hparam_key in checkpoint: model_args.update(checkpoint[hparam_key]) + if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint: model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args) + args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME) init_args_name = inspect.signature(cls).parameters.keys() + if args_name == 'kwargs': cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name} kwargs.update(**cls_kwargs) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index b802f894f48f9..7c85e734f9484 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -39,7 +39,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False): assert model.hparams.test_arg == 14 # verify we can train - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=0.5) + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2) trainer.fit(model) # make sure the raw checkpoint saved the properties From af8716815ae9b7e29ca030d0785042d0e7bf7c09 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 22:35:26 -0400 Subject: [PATCH 5/6] remove frame inspection on self.hparams --- pytorch_lightning/core/lightning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 308dfcb39a0a4..c175ad7b5b13c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1706,4 +1706,3 @@ def __get_hparams_assignment_variable(self): return line.split('=')[1] return None - From 3325aba4e5fbcad7c4d15256a8f40b9e3f1e790f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 18 Jun 2020 22:36:40 -0400 Subject: [PATCH 6/6] remove frame inspection on self.hparams --- pytorch_lightning/core/lightning.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c175ad7b5b13c..d2190a37f5e9a 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1698,6 +1698,10 @@ def hparams(self, hp: Union[dict, Namespace, Any]): self._set_hparams(hp) def __get_hparams_assignment_variable(self): + """ + looks at the code of the class to figure out what the user named self.hparams + this only happens when the user explicitly sets self.hparams + """ class_code = inspect.getsource(self.__class__) lines = class_code.split('\n') for line in lines: