diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 2c910edb8e404..ab198356f7279 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -333,7 +333,13 @@ def __hash__(self): hash_vals = [self.__class__.__name__] for key in self._defaults.keys(): - hash_vals.append(getattr(self, key)) + val = getattr(self, key) + # Special case: allow list values, so long + # as their elements are hashable + if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): + hash_vals.extend(val) + else: + hash_vals.append(val) return hash(tuple(hash_vals)) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 03b79633e3eb7..04c7ce64beed0 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -154,6 +154,32 @@ def compute(self): assert a.compute() == 5 +def test_hash(): + + class A(Dummy): + pass + + class B(DummyList): + pass + + a1 = A() + a2 = A() + assert hash(a1) != hash(a2) + + b1 = B() + b2 = B() + assert hash(b1) == hash(b2) + assert isinstance(b1.x, list) and len(b1.x) == 0 + b1.x.append(torch.tensor(5)) + assert isinstance(hash(b1), int) # <- check that nothing crashes + assert isinstance(b1.x, list) and len(b1.x) == 1 + b2.x.append(torch.tensor(5)) + # Sanity: + assert isinstance(b2.x, list) and len(b2.x) == 1 + # Now that they have tensor contents, they should have different hashes: + assert hash(b1) != hash(b2) + + def test_forward(): class A(Dummy):