From fedce1ce31d27d598cfafb443a65c6f083a72b90 Mon Sep 17 00:00:00 2001 From: Philip Blair Date: Thu, 11 Feb 2021 16:44:29 +0100 Subject: [PATCH 1/6] Fix: Allow hashing of metrics with lists in their state --- pytorch_lightning/metrics/metric.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 2c910edb8e404..a4e64090ca99c 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -333,7 +333,11 @@ def __hash__(self): hash_vals = [self.__class__.__name__] for key in self._defaults.keys(): - hash_vals.append(getattr(self, key)) + val = getattr(self, key) + # Special case: allow list values, so long as their elements are hashable + if isinstance(val, list): + val = tuple(val) + hash_vals.append(val) return hash(tuple(hash_vals)) From ff74fd20adcf1e9ab76d00bd835ce218a089cd20 Mon Sep 17 00:00:00 2001 From: Philip Blair Date: Fri, 12 Feb 2021 18:12:16 +0100 Subject: [PATCH 2/6] Add test case and modify semantics of Metric __hash__ in order to be compatible with structural equality checks --- pytorch_lightning/metrics/metric.py | 19 +++++++++++++++---- tests/metrics/test_metric.py | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index a4e64090ca99c..8091a6305db5e 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -332,12 +332,23 @@ def _filter_kwargs(self, **kwargs): def __hash__(self): hash_vals = [self.__class__.__name__] + # Torch tensors are hashable, but based on their + # underlying pointer. Since we do a structural + # equality check in __eq__, we circumvent this + # by hashing the _sum_ of tensors' contents + def format_tensor(t): + if not isinstance(t, torch.Tensor): + return t + return float(t.detach().cpu().sum()) + for key in self._defaults.keys(): val = getattr(self, key) - # Special case: allow list values, so long as their elements are hashable - if isinstance(val, list): - val = tuple(val) - hash_vals.append(val) + # Special case: allow list values, so long + # as their elements are hashable + if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): + hash_vals.extend((format_tensor(x) for x in val)) + else: + hash_vals.append(format_tensor(val)) return hash(tuple(hash_vals)) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 03b79633e3eb7..e397f982cd3df 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -154,6 +154,27 @@ def compute(self): assert a.compute() == 5 +def test_hash(): + + class A(Dummy): + pass + + class B(DummyList): + pass + + a1 = A() + a2 = A() + assert hash(a1) == hash(a2) + + b1 = B() + b2 = B() + assert hash(b1) == hash(b2) + assert isinstance(b1.x, list) and len(b1.x) == 0 + b1.x.append(torch.tensor(5)) + assert isinstance(hash(b1), int) # <- check that nothing crashes + assert isinstance(b1.x, list) and len(b1.x) == 1 + + def test_forward(): class A(Dummy): From 815ff2879a2ee950b2222afca922cb01f2cfcd3c Mon Sep 17 00:00:00 2001 From: Philip Blair Date: Fri, 12 Feb 2021 18:14:02 +0100 Subject: [PATCH 3/6] Fix pep8 style issue --- tests/metrics/test_metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index e397f982cd3df..0851fc1502772 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -171,7 +171,7 @@ class B(DummyList): assert hash(b1) == hash(b2) assert isinstance(b1.x, list) and len(b1.x) == 0 b1.x.append(torch.tensor(5)) - assert isinstance(hash(b1), int) # <- check that nothing crashes + assert isinstance(hash(b1), int) # <- check that nothing crashes assert isinstance(b1.x, list) and len(b1.x) == 1 From a7e5d7e4c8498d8a39fd87985a336bbebde211a8 Mon Sep 17 00:00:00 2001 From: Philip Blair Date: Tue, 16 Feb 2021 10:06:17 +0100 Subject: [PATCH 4/6] Move local function to static method --- pytorch_lightning/metrics/metric.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 8091a6305db5e..88f13150d6d30 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -329,26 +329,29 @@ def _filter_kwargs(self, **kwargs): filtered_kwargs = kwargs return filtered_kwargs + @staticmethod + def _hash_tensor(t): + """ + Torch tensors are hashable, but based on their + underlying pointer. Since we do a structural + equality check in __eq__, we circumvent this + by hashing the _sum_ of tensors' contents. + """ + if not isinstance(t, torch.Tensor): + return t + return float(t.detach().cpu().sum()) + def __hash__(self): hash_vals = [self.__class__.__name__] - # Torch tensors are hashable, but based on their - # underlying pointer. Since we do a structural - # equality check in __eq__, we circumvent this - # by hashing the _sum_ of tensors' contents - def format_tensor(t): - if not isinstance(t, torch.Tensor): - return t - return float(t.detach().cpu().sum()) - for key in self._defaults.keys(): val = getattr(self, key) # Special case: allow list values, so long # as their elements are hashable if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): - hash_vals.extend((format_tensor(x) for x in val)) + hash_vals.extend((self._hash_tensor(x) for x in val)) else: - hash_vals.append(format_tensor(val)) + hash_vals.append(self._hash_tensor(val)) return hash(tuple(hash_vals)) From 626d44fd45ae8a9d9b5c34fc49231251ba718c3e Mon Sep 17 00:00:00 2001 From: Philip Blair Date: Tue, 16 Feb 2021 16:16:07 +0100 Subject: [PATCH 5/6] Remove _hash_tensor method in order to un-break nn.Module.named_children (for builtin metrics, at least) --- pytorch_lightning/metrics/metric.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index 88f13150d6d30..ab198356f7279 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -329,18 +329,6 @@ def _filter_kwargs(self, **kwargs): filtered_kwargs = kwargs return filtered_kwargs - @staticmethod - def _hash_tensor(t): - """ - Torch tensors are hashable, but based on their - underlying pointer. Since we do a structural - equality check in __eq__, we circumvent this - by hashing the _sum_ of tensors' contents. - """ - if not isinstance(t, torch.Tensor): - return t - return float(t.detach().cpu().sum()) - def __hash__(self): hash_vals = [self.__class__.__name__] @@ -349,9 +337,9 @@ def __hash__(self): # Special case: allow list values, so long # as their elements are hashable if hasattr(val, '__iter__') and not isinstance(val, torch.Tensor): - hash_vals.extend((self._hash_tensor(x) for x in val)) + hash_vals.extend(val) else: - hash_vals.append(self._hash_tensor(val)) + hash_vals.append(val) return hash(tuple(hash_vals)) From 688da6c2b2c57f64e1b9011b79e3347e52ed9bad Mon Sep 17 00:00:00 2001 From: Philip Blair Date: Thu, 18 Feb 2021 10:26:19 +0100 Subject: [PATCH 6/6] Fix problematic test, and test edge case surrounding hashing semantics --- tests/metrics/test_metric.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py index 0851fc1502772..04c7ce64beed0 100644 --- a/tests/metrics/test_metric.py +++ b/tests/metrics/test_metric.py @@ -164,7 +164,7 @@ class B(DummyList): a1 = A() a2 = A() - assert hash(a1) == hash(a2) + assert hash(a1) != hash(a2) b1 = B() b2 = B() @@ -173,6 +173,11 @@ class B(DummyList): b1.x.append(torch.tensor(5)) assert isinstance(hash(b1), int) # <- check that nothing crashes assert isinstance(b1.x, list) and len(b1.x) == 1 + b2.x.append(torch.tensor(5)) + # Sanity: + assert isinstance(b2.x, list) and len(b2.x) == 1 + # Now that they have tensor contents, they should have different hashes: + assert hash(b1) != hash(b2) def test_forward():