From 77f6aa43e51434f8f510f60e884fc7457320e99a Mon Sep 17 00:00:00 2001 From: Philip E Blair Date: Thu, 18 Feb 2021 10:54:12 +0100 Subject: [PATCH] Fix: Allow hashing of metrics with lists in their state (#5939) * Fix: Allow hashing of metrics with lists in their state * Add test case and modify semantics of Metric __hash__ in order to be compatible with structural equality checks * Fix pep8 style issue Co-authored-by: Jirka Borovec Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- pytorch_lightning/metrics/metric.py | 8 +++++++- tests/metrics/test_metric.py | 26 ++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 2c910edb8e404..ab198356f7279 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -333,7 +333,13 @@ def __hash__(self): hash_vals = [self.__class__.__name__] for key in self._defaults.keys(): - hash_vals.append(getattr(self, key)) + val = getattr(self, key) + # Special case: allow list values, so long + # as their elements are hashable + if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): + hash_vals.extend(val) + else: + hash_vals.append(val) return hash(tuple(hash_vals)) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 03b79633e3eb7..04c7ce64beed0 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -154,6 +154,32 @@ def compute(self): assert a.compute() == 5 +def test_hash(): + + class A(Dummy): + pass + + class B(DummyList): + pass + + a1 = A() + a2 = A() + assert hash(a1) != hash(a2) + + b1 = B() + b2 = B() + assert hash(b1) == hash(b2) + assert isinstance(b1.x, list) and len(b1.x) == 0 + b1.x.append(torch.tensor(5)) + assert isinstance(hash(b1), int) # <- check that nothing crashes + assert isinstance(b1.x, list) and len(b1.x) == 1 + b2.x.append(torch.tensor(5)) + # Sanity: + assert isinstance(b2.x, list) and len(b2.x) == 1 + # Now that they have tensor contents, they should have different hashes: + assert hash(b1) != hash(b2) + + def test_forward(): class A(Dummy):