Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Allow hashing of metrics with lists in their state #5939

Merged
merged 20 commits into from
Feb 18, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
fedce1c
Fix: Allow hashing of metrics with lists in their state
peblair Feb 11, 2021
ff74fd2
Add test case and modify semantics of Metric __hash__ in order to be …
peblair Feb 12, 2021
815ff28
Fix pep8 style issue
peblair Feb 12, 2021
e533693
Merge branch 'master' into fix-hashing-bug
Borda Feb 13, 2021
750926f
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
400c654
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
27395f2
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
8126bcc
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
e8fc885
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
63c5f78
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
7dd9635
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
c7969cf
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 15, 2021
a7e5d7e
Move local function to static method
peblair Feb 16, 2021
a42a335
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 16, 2021
13a48ff
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 16, 2021
1777b54
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 16, 2021
626d44f
Remove _hash_tensor method in order to un-break nn.Module.named_child…
peblair Feb 16, 2021
3327132
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 16, 2021
8a29a20
Merge branch 'master' into fix-hashing-bug
mergify[bot] Feb 16, 2021
688da6c
Fix problematic test, and test edge case surrounding hashing semantics
peblair Feb 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,23 @@ def _filter_kwargs(self, **kwargs):
def __hash__(self):
hash_vals = [self.__class__.__name__]

# Torch tensors are hashable, but based on their
# underlying pointer. Since we do a structural
# equality check in __eq__, we circumvent this
# by hashing the _sum_ of tensors' contents
def format_tensor(t):
peblair marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(t, torch.Tensor):
return t
return float(t.detach().cpu().sum())

for key in self._defaults.keys():
hash_vals.append(getattr(self, key))
val = getattr(self, key)
# Special case: allow list values, so long
# as their elements are hashable
if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor):
hash_vals.extend((format_tensor(x) for x in val))
else:
hash_vals.append(format_tensor(val))

return hash(tuple(hash_vals))

Expand Down
21 changes: 21 additions & 0 deletions tests/metrics/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,27 @@ def compute(self):
assert a.compute() == 5


def test_hash():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this test be done directly on Metric objects ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tchaton I apologize, could you elaborate? The hashing is being done on two subclasses of Metric, so I am unsure what precisely you mean.


class A(Dummy):
pass

class B(DummyList):
pass

a1 = A()
a2 = A()
assert hash(a1) == hash(a2)

b1 = B()
b2 = B()
assert hash(b1) == hash(b2)
assert isinstance(b1.x, list) and len(b1.x) == 0
b1.x.append(torch.tensor(5))
assert isinstance(hash(b1), int) # <- check that nothing crashes
assert isinstance(b1.x, list) and len(b1.x) == 1


def test_forward():

class A(Dummy):
Expand Down