diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8f31000b0c302..1f6ec645dbd78 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))
 
 
+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+
+
 - Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))
 
diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py
index ee130a700ae68..f6acb1c4355f1 100644
--- a/pytorch_lightning/callbacks/pruning.py
+++ b/pytorch_lightning/callbacks/pruning.py
@@ -18,7 +18,7 @@
 import inspect
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.utils.prune as pytorch_prune
@@ -27,7 +27,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 _PYTORCH_PRUNING_FUNCTIONS = {
@@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
     def _wrap_pruning_fn(pruning_fn, **kwargs):
         return partial(pruning_fn, **kwargs)
 
-    def make_pruning_permanent(self):
-        """ Makes ``parameters_to_prune`` current pruning permanent. """
-        for module, param_name in self._parameters_to_prune:
-            try:
-                pytorch_prune.remove(module, param_name)
-            except ValueError:
-                # pruning already made permanent
-                pass
+    def make_pruning_permanent(self, pl_module: LightningModule):
+        """
+        Removes pruning buffers from any pruned modules.
+
+        Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
+        """
+        for _, module in pl_module.named_modules():
+            for k in list(module._forward_pre_hooks):
+                hook = module._forward_pre_hooks[k]
+                if isinstance(hook, pytorch_prune.BasePruningMethod):
+                    hook.remove(module)
+                    del module._forward_pre_hooks[k]
 
     def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
         trained = getattr(module, tensor_name)
@@ -351,7 +355,7 @@ def _log_sparsity_stats(
             f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
         )
 
-    def on_before_accelerator_backend_setup(self, trainer, pl_module):
+    def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
         parameters_to_prune = self.sanitize_parameters_to_prune(
             pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
         )
@@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
             self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
             self._original_layers[id_]["names"].append((i, name))
 
-    def on_train_epoch_end(self, trainer, pl_module, *args):
+    def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
         current_epoch = trainer.current_epoch
         prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
         amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
@@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
         ):
             self.apply_lottery_ticket_hypothesis()
 
-    def on_train_end(self, *args):
+    def on_train_end(self, trainer, pl_module: LightningModule):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for the trained model.")
+            self.make_pruning_permanent(pl_module)
 
-    def on_save_checkpoint(self, *args):
+    def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
         if self._make_pruning_permanent:
-            self.make_pruning_permanent()
+            rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
+            prev_device = pl_module.device
+            # prune a copy so training can continue with the same buffers
+            copy = deepcopy(pl_module.to("cpu"))
+            self.make_pruning_permanent(copy)
+            checkpoint["state_dict"] = copy.state_dict()
+            pl_module.to(prev_device)
 
     @staticmethod
     def sanitize_parameters_to_prune(
diff --git a/tests/callbacks/test_pruning.py b/tests/callbacks/test_pruning.py
index 62b0d3a8f3bb3..4915e193c19ae 100644
--- a/tests/callbacks/test_pruning.py
+++ b/tests/callbacks/test_pruning.py
@@ -15,7 +15,6 @@
 import platform
 from collections import OrderedDict
 from logging import INFO
-from unittest import mock
 
 import pytest
 import torch
@@ -24,7 +23,7 @@
 from torch.nn import Sequential
 
 from pytorch_lightning import seed_everything, Trainer
-from pytorch_lightning.callbacks import ModelPruning
+from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
 
@@ -42,6 +41,10 @@ def __init__(self):
             ])
         )
 
+    def training_step(self, batch, batch_idx):
+        self.log("test", -batch_idx)
+        return super().training_step(batch, batch_idx)
+
 
 class TestPruningMethod(pytorch_prune.BasePruningMethod):
     PRUNING_TYPE = "unstructured"
@@ -219,7 +222,6 @@ def apply_lottery_ticket_hypothesis(self):
 
 
 @pytest.mark.parametrize("make_pruning_permanent", (False, True))
-@mock.patch.dict(os.environ, {}, clear=True)
 def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     seed_everything(0)
     model = TestModel()
@@ -244,8 +246,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     with caplog.at_level(INFO):
         trainer.fit(model)
 
-    actual = [m.strip() for m in caplog.messages[-9:]]
-    expected = [
+    actual = [m.strip() for m in caplog.messages]
+    actual = [m for m in actual if m.startswith("Applied")]
+    assert actual == [
         "Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)",  # noqa: E501
@@ -256,7 +259,6 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)",  # noqa: E501
         "Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)",  # noqa: E501
     ]
-    assert actual == expected
 
     filepath = str(tmpdir / "foo.ckpt")
     trainer.save_checkpoint(filepath)
@@ -264,3 +266,48 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
     model.load_from_checkpoint(filepath, strict=False)
     has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
     assert not has_pruning if make_pruning_permanent else has_pruning
+
+
+def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
+    """
+    When a model is saved multiple times and ``make_pruning_permanent=True``, a copy
+    must be pruned instead of the trained model, so that training can continue
+    with the same pruning buffers.
+ """ + seed_everything(0) + + class TestPruning(ModelPruning): + def on_save_checkpoint(self, trainer, pl_module, checkpoint): + super().on_save_checkpoint(trainer, pl_module, checkpoint) + assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"] + assert hasattr(pl_module.layer.mlp_3, "weight_orig") + + model = TestModel() + pruning_callback = TestPruning( + "random_unstructured", + parameters_to_prune=[(model.layer.mlp_3, "weight")], + verbose=1, + make_pruning_permanent=True + ) + ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True) + trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0) + with caplog.at_level(INFO): + trainer.fit(model) + + actual = [m.strip() for m in caplog.messages] + actual = [m for m in actual if m.startswith("Applied")] + assert actual == [ + "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)", + "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)", + "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)", + ] + + # removed on_train_end + assert not hasattr(model.layer.mlp_3, "weight_orig") + + model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path) + assert not hasattr(model.layer.mlp_3, "weight_orig") + model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path) + assert not hasattr(model.layer.mlp_3, "weight_orig")