Skip to content

Commit

Permalink
Raise AttributeError in lightning_getattr and lightning_setattr when …
Browse files Browse the repository at this point in the history
…attribute not found (#6024)

* Empty commit

* Raise AttributeError instead of ValueError

* Make functions private

* Update tests

* Add match string

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* lightning to Lightning

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
akihironitta and awaelchli authored Feb 18, 2021
1 parent b0074a4 commit 8f82823
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 25 deletions.
59 changes: 38 additions & 21 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,11 @@ def __repr__(self):
return out


def lightning_get_all_attr_holders(model, attribute):
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
def _lightning_get_all_attr_holders(model, attribute):
"""
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
"""
trainer = getattr(model, 'trainer', None)

holders = []
Expand All @@ -219,31 +221,40 @@ def lightning_get_all_attr_holders(model, attribute):
return holders


def _lightning_get_first_attr_holder(model, attribute):
    """
    Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.

    Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
    returns the last one that has it.

    Args:
        model: the model to search for ``attribute``.
        attribute: name of the attribute to look up.

    Returns:
        The last holder (object or dict) that has ``attribute``, or ``None`` if no holder has it.
    """
    holders = _lightning_get_all_attr_holders(model, attribute)
    if not holders:
        return None
    # using the last holder to preserve backwards compatibility
    return holders[-1]


def lightning_hasattr(model, attribute):
    """
    Special hasattr for Lightning. Checks for attribute in model namespace,
    the old hparams namespace/dict, and the datamodule.

    Args:
        model: the model to search for ``attribute``.
        attribute: name of the attribute to look up.

    Returns:
        ``True`` if any holder (model namespace, hparams, or datamodule) has ``attribute``.
    """
    return _lightning_get_first_attr_holder(model, attribute) is not None


def lightning_getattr(model, attribute):
""" Special getattr for lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule. """
holder = lightning_get_first_attr_holder(model, attribute)
"""
Special getattr for Lightning. Checks for attribute in model namespace,
the old hparams namespace/dict, and the datamodule.
Raises:
AttributeError:
If ``model`` doesn't have ``attribute`` in any of
model namespace, the hparams namespace/dict, and the datamodule.
"""
holder = _lightning_get_first_attr_holder(model, attribute)
if holder is None:
raise ValueError(
raise AttributeError(
f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.'
)
Expand All @@ -254,13 +265,19 @@ def lightning_getattr(model, attribute):


def lightning_setattr(model, attribute, value):
""" Special setattr for lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
"""
holders = lightning_get_all_attr_holders(model, attribute)
Special setattr for Lightning. Checks for attribute in model namespace
and the old hparams namespace/dict.
Will also set the attribute on datamodule, if it exists.
Raises:
AttributeError:
If ``model`` doesn't have ``attribute`` in any of
model namespace, the hparams namespace/dict, and the datamodule.
"""
holders = _lightning_get_all_attr_holders(model, attribute)
if len(holders) == 0:
raise ValueError(
raise AttributeError(
f'{attribute} is neither stored in the model namespace'
' nor the `hparams` namespace/dict, nor the datamodule.'
)
Expand Down
26 changes: 22 additions & 4 deletions tests/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr

Expand Down Expand Up @@ -74,8 +75,8 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat


def test_lightning_hasattr(tmpdir):
""" Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
"""Test that the lightning_hasattr works in all cases"""
model1, model2, model3, model4, model5, model6, model7 = models = _get_test_cases()
assert lightning_hasattr(model1, 'learning_rate'), \
'lightning_hasattr failed to find namespace variable'
assert lightning_hasattr(model2, 'learning_rate'), \
Expand All @@ -91,9 +92,12 @@ def test_lightning_hasattr(tmpdir):
assert lightning_hasattr(model7, 'batch_size'), \
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'

for m in models:
assert not lightning_hasattr(m, "this_attr_not_exist")


def test_lightning_getattr(tmpdir):
""" Test that the lightning_getattr works in all cases"""
"""Test that the lightning_getattr works in all cases"""
models = _get_test_cases()
for i, m in enumerate(models[:3]):
value = lightning_getattr(m, 'learning_rate')
Expand All @@ -107,9 +111,16 @@ def test_lightning_getattr(tmpdir):
assert lightning_getattr(model7, 'batch_size') == 8, \
'batch_size not correctly extracted'

for m in models:
with pytest.raises(
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
):
lightning_getattr(m, "this_attr_not_exist")


def test_lightning_setattr(tmpdir):
""" Test that the lightning_setattr works in all cases"""
"""Test that the lightning_setattr works in all cases"""
models = _get_test_cases()
for m in models[:3]:
lightning_setattr(m, 'learning_rate', 10)
Expand All @@ -126,3 +137,10 @@ def test_lightning_setattr(tmpdir):
'batch_size not correctly set'
assert lightning_getattr(model7, 'batch_size') == 128, \
'batch_size not correctly set'

for m in models:
with pytest.raises(
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
):
lightning_setattr(m, "this_attr_not_exist", None)

0 comments on commit 8f82823

Please sign in to comment.