diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 0fba5692b25d1..dd12f34cfe926 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -196,9 +196,11 @@ def __repr__(self): return out -def lightning_get_all_attr_holders(model, attribute): - """ Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute. - Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ +def _lightning_get_all_attr_holders(model, attribute): + """ + Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. + Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. + """ trainer = getattr(model, 'trainer', None) holders = [] @@ -219,13 +221,13 @@ def lightning_get_all_attr_holders(model, attribute): return holders -def lightning_get_first_attr_holder(model, attribute): +def _lightning_get_first_attr_holder(model, attribute): + """ + Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. + Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, + returns the last one that has it. """ - Special attribute finding for lightning. Gets the object or dict that holds attribute, or None. - Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, - returns the last one that has it. - """ - holders = lightning_get_all_attr_holders(model, attribute) + holders = _lightning_get_all_attr_holders(model, attribute) if len(holders) == 0: return None # using the last holder to preserve backwards compatibility @@ -233,17 +235,26 @@ def lightning_get_first_attr_holder(model, attribute): def lightning_hasattr(model, attribute): - """ Special hasattr for lightning. Checks for attribute in model namespace, - the old hparams namespace/dict, and the datamodule. """ - return lightning_get_first_attr_holder(model, attribute) is not None + """ + Special hasattr for Lightning. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule. + """ + return _lightning_get_first_attr_holder(model, attribute) is not None def lightning_getattr(model, attribute): - """ Special getattr for lightning. Checks for attribute in model namespace, - the old hparams namespace/dict, and the datamodule. """ - holder = lightning_get_first_attr_holder(model, attribute) + """ + Special getattr for Lightning. Checks for attribute in model namespace, + the old hparams namespace/dict, and the datamodule. + + Raises: + AttributeError: + If ``model`` doesn't have ``attribute`` in any of + model namespace, the hparams namespace/dict, and the datamodule. + """ + holder = _lightning_get_first_attr_holder(model, attribute) if holder is None: - raise ValueError( + raise AttributeError( f'{attribute} is neither stored in the model namespace' ' nor the `hparams` namespace/dict, nor the datamodule.' ) @@ -254,13 +265,19 @@ def lightning_getattr(model, attribute): def lightning_setattr(model, attribute, value): - """ Special setattr for lightning. Checks for attribute in model namespace - and the old hparams namespace/dict. - Will also set the attribute on datamodule, if it exists. """ - holders = lightning_get_all_attr_holders(model, attribute) + Special setattr for Lightning. Checks for attribute in model namespace + and the old hparams namespace/dict. + Will also set the attribute on datamodule, if it exists. + + Raises: + AttributeError: + If ``model`` doesn't have ``attribute`` in any of + model namespace, the hparams namespace/dict, and the datamodule. + """ + holders = _lightning_get_all_attr_holders(model, attribute) if len(holders) == 0: - raise ValueError( + raise AttributeError( f'{attribute} is neither stored in the model namespace' ' nor the `hparams` namespace/dict, nor the datamodule.' ) diff --git a/tests/utilities/test_parsing.py b/tests/utilities/test_parsing.py index c07a016eda92d..42edb8e48f336 100644 --- a/tests/utilities/test_parsing.py +++ b/tests/utilities/test_parsing.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr @@ -74,8 +75,8 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat def test_lightning_hasattr(tmpdir): - """ Test that the lightning_hasattr works in all cases""" - model1, model2, model3, model4, model5, model6, model7 = _get_test_cases() + """Test that the lightning_hasattr works in all cases""" + model1, model2, model3, model4, model5, model6, model7 = models = _get_test_cases() assert lightning_hasattr(model1, 'learning_rate'), \ 'lightning_hasattr failed to find namespace variable' assert lightning_hasattr(model2, 'learning_rate'), \ @@ -91,9 +92,12 @@ def test_lightning_hasattr(tmpdir): assert lightning_hasattr(model7, 'batch_size'), \ 'lightning_hasattr failed to find batch_size in hparams w/ datamodule present' + for m in models: + assert not lightning_hasattr(m, "this_attr_not_exist") + def test_lightning_getattr(tmpdir): - """ Test that the lightning_getattr works in all cases""" + """Test that the lightning_getattr works in all cases""" models = _get_test_cases() for i, m in enumerate(models[:3]): value = lightning_getattr(m, 'learning_rate') @@ -107,9 +111,16 @@ def test_lightning_getattr(tmpdir): assert lightning_getattr(model7, 'batch_size') == 8, \ 'batch_size not correctly extracted' + for m in models: + with pytest.raises( + AttributeError, + match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule." + ): + lightning_getattr(m, "this_attr_not_exist") + def test_lightning_setattr(tmpdir): - """ Test that the lightning_setattr works in all cases""" + """Test that the lightning_setattr works in all cases""" models = _get_test_cases() for m in models[:3]: lightning_setattr(m, 'learning_rate', 10) @@ -126,3 +137,10 @@ def test_lightning_setattr(tmpdir): 'batch_size not correctly set' assert lightning_getattr(model7, 'batch_size') == 128, \ 'batch_size not correctly set' + + for m in models: + with pytest.raises( + AttributeError, + match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule." + ): + lightning_setattr(m, "this_attr_not_exist", None)