From 127b36b1ab2f68b4354568487c302153648770df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Sat, 20 Jun 2020 01:25:33 +0200
Subject: [PATCH 1/6] detach hooks after completion

---
 pytorch_lightning/core/memory.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index 738184b9c091a..c090c5a9d85be 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -64,9 +64,11 @@ def _register_hook(self):
         def hook(module, inp, out):
             if len(inp) == 1:
                 inp = inp[0]
-            self._in_size = parse_batch_shape(inp)
-            self._out_size = parse_batch_shape(out)
-            self._hook_handle.remove()  # hook detaches itself from module
+            if self._in_size == UNKNOWN_SIZE:
+                self._in_size = parse_batch_shape(inp)
+            if self._out_size == UNKNOWN_SIZE:
+                self._out_size = parse_batch_shape(out)
+            #self._hook_handle.remove()  # hook detaches itself from module
 
         return self._module.register_forward_hook(hook)
 
@@ -180,6 +182,8 @@ def summarize(self) -> Dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
             self._forward_example_input()
+        for layer in summary.values():
+            layer._hook_handle.remove()
         return summary
 
     def _forward_example_input(self) -> None:
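The motivation for the change above is easier to see outside the diff: a forward hook that only detaches itself from inside the hook body stays attached to every module whose forward is never called (for example, when no example input is available), which is why `summarize()` now removes the handles explicitly. The following is a minimal, self-contained sketch of that behaviour in plain `torch.nn` — illustrative only, not part of the patch series; `_forward_hooks` is an internal PyTorch attribute used here just to show whether a hook is still attached.

```python
import torch
import torch.nn as nn

model = nn.ModuleDict({"used": nn.Linear(4, 4), "unused": nn.Linear(4, 4)})

handles = {}
for name, module in model.items():
    def make_hook(key):
        def hook(module, inp, out):
            # a real implementation records input/output shapes here,
            # then detaches itself (the pre-patch behaviour)
            handles[key].remove()
        return hook
    handles[name] = module.register_forward_hook(make_hook(name))

model["used"](torch.rand(2, 4))  # only this submodule ever runs forward

# the self-removing hook cleaned up "used", but "unused" still carries its hook
print(len(model["used"]._forward_hooks), len(model["unused"]._forward_hooks))  # 0 1

# explicit cleanup after the summary, which is what this patch adds:
for handle in handles.values():
    handle.remove()  # no-op for hooks that already removed themselves
print(len(model["unused"]._forward_hooks))  # 0
```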
""" + if self._hook_handle is not None: + self._hook_handle.remove() + @property def in_size(self): return self._in_size or UNKNOWN_SIZE @@ -183,7 +190,7 @@ def summarize(self) -> Dict[str, LayerSummary]: if self._model.example_input_array is not None: self._forward_example_input() for layer in summary.values(): - layer._hook_handle.remove() + layer.detach_hook() return summary def _forward_example_input(self) -> None: From 4ea93d5d5f05b33eae3b766ff5fe77fdcb01a1e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 10:17:40 +0200 Subject: [PATCH 3/6] update docs --- pytorch_lightning/core/memory.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index fe4ed90e6fb91..d334d3ff1857c 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -60,22 +60,24 @@ def __del__(self): def _register_hook(self): """ Registers a hook on the module that computes the input- and output size(s) on the first forward pass. - Recursive models will only record their input- and output shapes once. + If the hook is called, it will remove itself from the from the module, meaning that + recursive models will only record their input- and output shapes once. """ def hook(module, inp, out): if len(inp) == 1: inp = inp[0] - if self._in_size in (None, UNKNOWN_SIZE): - self._in_size = parse_batch_shape(inp) - if self._out_size in (None, UNKNOWN_SIZE): - self._out_size = parse_batch_shape(out) + self._in_size = parse_batch_shape(inp) + self._out_size = parse_batch_shape(out) self._hook_handle.remove() return self._module.register_forward_hook(hook) def detach_hook(self): - """ Removes the forward hook. Will be called after the summary is created. """ + """ + Removes the forward hook if it was not already removed in the forward pass. + Will be called after the summary is created. + """ if self._hook_handle is not None: self._hook_handle.remove() From 3a9655510506540555062166df0d9bb62ae2dcf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 10:36:55 +0200 Subject: [PATCH 4/6] add test --- tests/core/test_memory.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py index 6ddbb18ecf548..6eab2c31319f6 100644 --- a/tests/core/test_memory.py +++ b/tests/core/test_memory.py @@ -89,6 +89,20 @@ def test_linear_model_summary_shapes(device, dtype, mode): assert model.device == device +@pytest.mark.parametrize(['mode'], [ + pytest.param(ModelSummary.MODE_FULL), + pytest.param(ModelSummary.MODE_TOP), +]) +def test_hooks_removed_after_summarize(mode): + """ Test that all hooks were properly removed after summary, even ones that were not run. 
""" + model = UnorderedModel() + summary = ModelSummary(model, mode=mode) + # hooks should be removed + for _, layer in summary.summarize().items(): + handle = layer._hook_handle + assert handle.id not in handle.hooks_dict_ref() + + @pytest.mark.parametrize(['mode'], [ pytest.param(ModelSummary.MODE_FULL), pytest.param(ModelSummary.MODE_TOP), From 4953f5e241e7b92681d87b62489a871691842fd0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 10:37:05 +0200 Subject: [PATCH 5/6] docs --- pytorch_lightning/core/memory.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py index d334d3ff1857c..aa3956b0d6d7d 100644 --- a/pytorch_lightning/core/memory.py +++ b/pytorch_lightning/core/memory.py @@ -7,6 +7,7 @@ import numpy as np import torch import torch.nn as nn +from torch.utils.hooks import RemovableHandle import pytorch_lightning as pl from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -57,11 +58,14 @@ def __init__(self, module: nn.Module): def __del__(self): self.detach_hook() - def _register_hook(self): + def _register_hook(self) -> RemovableHandle: """ Registers a hook on the module that computes the input- and output size(s) on the first forward pass. If the hook is called, it will remove itself from the from the module, meaning that recursive models will only record their input- and output shapes once. + + Return: + A handle for the installed hook. """ def hook(module, inp, out): @@ -82,11 +86,11 @@ def detach_hook(self): self._hook_handle.remove() @property - def in_size(self): + def in_size(self) -> Union[str, List]: return self._in_size or UNKNOWN_SIZE @property - def out_size(self): + def out_size(self) -> Union[str, List]: return self._out_size or UNKNOWN_SIZE @property From 8d7076a23e745998ef4b8781f65c0c82167d672e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Jun 2020 10:51:30 +0200 Subject: [PATCH 6/6] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2649abc02d54e..aaa6d34250fd4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298)) + ## [0.8.1] - 2020-06-19