diff --git a/CHANGELOG.md b/CHANGELOG.md
index 51e56e7c1b18b..10d857f2a5a51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
+
 ## [0.8.1] - 2020-06-19
 
diff --git a/pytorch_lightning/core/memory.py b/pytorch_lightning/core/memory.py
index 738184b9c091a..aa3956b0d6d7d 100644
--- a/pytorch_lightning/core/memory.py
+++ b/pytorch_lightning/core/memory.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+from torch.utils.hooks import RemovableHandle
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.apply_func import apply_to_collection
@@ -54,11 +55,17 @@ def __init__(self, module: nn.Module):
         self._in_size = None
         self._out_size = None
 
-    def _register_hook(self):
+    def __del__(self):
+        self.detach_hook()
+
+    def _register_hook(self) -> RemovableHandle:
         """
-        Registers a hook on the module that computes the input- and output size(s)
-        on the first forward pass. The hook will remove itself from the module, meaning that
+        Registers a hook on the module that computes the input- and output size(s) on the first forward pass.
+        If the hook is called, it will remove itself from the module, meaning that
         recursive models will only record their input- and output shapes once.
+
+        Return:
+            A handle for the installed hook.
         """
 
         def hook(module, inp, out):
@@ -66,16 +73,24 @@ def hook(module, inp, out):
                 inp = inp[0]
             self._in_size = parse_batch_shape(inp)
             self._out_size = parse_batch_shape(out)
-            self._hook_handle.remove()  # hook detaches itself from module
+            self._hook_handle.remove()
 
         return self._module.register_forward_hook(hook)
 
+    def detach_hook(self):
+        """
+        Removes the forward hook if it was not already removed in the forward pass.
+        Will be called after the summary is created.
+        """
+        if self._hook_handle is not None:
+            self._hook_handle.remove()
+
     @property
-    def in_size(self):
+    def in_size(self) -> Union[str, List]:
         return self._in_size or UNKNOWN_SIZE
 
     @property
-    def out_size(self):
+    def out_size(self) -> Union[str, List]:
         return self._out_size or UNKNOWN_SIZE
 
     @property
@@ -180,6 +195,8 @@ def summarize(self) -> Dict[str, LayerSummary]:
         summary = OrderedDict((name, LayerSummary(module)) for name, module in self.named_modules)
         if self._model.example_input_array is not None:
             self._forward_example_input()
+        for layer in summary.values():
+            layer.detach_hook()
         return summary
 
     def _forward_example_input(self) -> None:
diff --git a/tests/core/test_memory.py b/tests/core/test_memory.py
index 6ddbb18ecf548..6eab2c31319f6 100644
--- a/tests/core/test_memory.py
+++ b/tests/core/test_memory.py
@@ -89,6 +89,20 @@ def test_linear_model_summary_shapes(device, dtype, mode):
     assert model.device == device
 
 
+@pytest.mark.parametrize(['mode'], [
+    pytest.param(ModelSummary.MODE_FULL),
+    pytest.param(ModelSummary.MODE_TOP),
+])
+def test_hooks_removed_after_summarize(mode):
+    """ Test that all hooks were properly removed after summary, even ones that were not run. """
+    model = UnorderedModel()
+    summary = ModelSummary(model, mode=mode)
+    # hooks should be removed
+    for _, layer in summary.summarize().items():
+        handle = layer._hook_handle
+        assert handle.id not in handle.hooks_dict_ref()
+
+
 @pytest.mark.parametrize(['mode'], [
     pytest.param(ModelSummary.MODE_FULL),
     pytest.param(ModelSummary.MODE_TOP),
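
As a quick end-to-end check of the fix, the sketch below builds a small model, runs a summary, and verifies that no forward hooks are left behind on any submodule. This is a minimal sketch against the 0.8.x API shown in this diff; `TwoLayerModel` is a hypothetical module invented for illustration, and inspecting the private `_forward_hooks` dict of each `nn.Module` is a convenient stand-in for the `hooks_dict_ref()` check used in the new test.

```python
# Minimal verification sketch. Assumptions: pytorch-lightning 0.8.x with the
# fix above applied; TwoLayerModel is a made-up module for illustration only.
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.core.memory import ModelSummary


class TwoLayerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)
        self.layer2 = nn.Linear(20, 4)
        # an example input makes summarize() run a forward pass, which is
        # what registers (and then removes) the shape-recording hooks
        self.example_input_array = torch.rand(2, 10)

    def forward(self, x):
        return self.layer2(self.layer1(x))


model = TwoLayerModel()
ModelSummary(model, mode=ModelSummary.MODE_FULL).summarize()

# with detach_hook() in place, every hook is gone after summarize(),
# including hooks on modules the example input never reached
assert all(len(module._forward_hooks) == 0 for module in model.modules())
```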