Fix ModelPruning(make_pruning_permanent=True) buffers getting removed when saved during training (#6073)

Co-authored-by: chaton <thomas@grid.ai>
2 people authored and lexierule committed Mar 9, 2021
1 parent b3b8f95 commit 68dd140
Showing 3 changed files with 93 additions and 22 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.2.3] - 2021-03-09

### Added


### Changed


### Fixed

- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))



## [1.2.2] - 2021-03-02

### Added
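
The fix above concerns the configuration where epoch-wise pruning and checkpointing run together and `make_pruning_permanent=True`. As a rough usage sketch of that setup (the `LitModel` module, its synthetic `train_loader`, and the monitored `"train_loss"` metric are hypothetical stand-ins, not part of this commit):

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning


class LitModel(pl.LightningModule):
    """Hypothetical module, used only to illustrate the configuration."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, = batch
        loss = self(x).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = LitModel()
train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
callbacks = [
    # Prune 50% of `layer.weight` every epoch; make_pruning_permanent=True
    # means the *exported* weights should have the reparametrization stripped.
    ModelPruning(
        "l1_unstructured",
        parameters_to_prune=[(model.layer, "weight")],
        amount=0.5,
        make_pruning_permanent=True,
    ),
    # Saving checkpoints mid-training is what used to strip the pruning
    # buffers from the live model before this fix.
    ModelCheckpoint(monitor="train_loss", save_top_k=2, save_last=True),
]
trainer = Trainer(callbacks=callbacks, max_epochs=3, logger=False)
trainer.fit(model, train_loader)
```

Previously, every checkpoint save called `make_pruning_permanent()` on the live model, deleting its `weight_orig`/`weight_mask` buffers and breaking the pruning applied in later epochs; the diff below instead prunes a CPU copy inside `on_save_checkpoint` and restores the model to its original device.
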
43 changes: 27 additions & 16 deletions pytorch_lightning/callbacks/pruning.py
@@ -18,7 +18,7 @@
import inspect
from copy import deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.utils.prune as pytorch_prune
@@ -27,7 +27,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_debug, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_PYTORCH_PRUNING_FUNCTIONS = {
@@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
def _wrap_pruning_fn(pruning_fn, **kwargs):
return partial(pruning_fn, **kwargs)

def make_pruning_permanent(self):
""" Makes ``parameters_to_prune`` current pruning permanent. """
for module, param_name in self._parameters_to_prune:
try:
pytorch_prune.remove(module, param_name)
except ValueError:
# pruning already made permanent
pass
def make_pruning_permanent(self, pl_module: LightningModule):
"""
Removes pruning buffers from any pruned modules

Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
"""
for _, module in pl_module.named_modules():
for k in list(module._forward_pre_hooks):
hook = module._forward_pre_hooks[k]
if isinstance(hook, pytorch_prune.BasePruningMethod):
hook.remove(module)
del module._forward_pre_hooks[k]

def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
trained = getattr(module, tensor_name)
@@ -351,7 +355,7 @@ def _log_sparsity_stats(
f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
)

def on_before_accelerator_backend_setup(self, trainer, pl_module):
def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
parameters_to_prune = self.sanitize_parameters_to_prune(
pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
)
@@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
self._original_layers[id_]["names"].append((i, name))

def on_train_epoch_end(self, trainer, pl_module, *args):
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
current_epoch = trainer.current_epoch
prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
@@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
):
self.apply_lottery_ticket_hypothesis()

def on_train_end(self, *args):
def on_train_end(self, trainer, pl_module: LightningModule):
if self._make_pruning_permanent:
self.make_pruning_permanent()
rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
self.make_pruning_permanent(pl_module)

def on_save_checkpoint(self, *args):
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
if self._make_pruning_permanent:
self.make_pruning_permanent()
rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
prev_device = pl_module.device
# prune a copy so training can continue with the same buffers
copy = deepcopy(pl_module.to("cpu"))
self.make_pruning_permanent(copy)
checkpoint["state_dict"] = copy.state_dict()
pl_module.to(prev_device)

@staticmethod
def sanitize_parameters_to_prune(
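
For context on what the reworked `make_pruning_permanent(pl_module)` undoes: `torch.nn.utils.prune` reparametrizes a pruned tensor into a `<name>_orig` parameter plus a `<name>_mask` buffer and recomputes `<name>` in a forward pre-hook, and calling that hook's `remove()` folds the mask back into a plain parameter. A minimal plain-PyTorch sketch of the same hook walk the callback performs (toy `nn.Sequential`, no Lightning involved):

```python
import torch.nn.utils.prune as pytorch_prune
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

# Pruning 50% of the first layer's weights adds `weight_orig`, a
# `weight_mask` buffer, and a BasePruningMethod forward pre-hook.
pytorch_prune.l1_unstructured(model[0], name="weight", amount=0.5)
assert hasattr(model[0], "weight_orig")
assert "weight_mask" in dict(model[0].named_buffers())

# Same traversal as the new ModelPruning.make_pruning_permanent: find the
# pruning hooks on every module, fold the masks in, and drop the hooks.
for _, module in model.named_modules():
    for k in list(module._forward_pre_hooks):
        hook = module._forward_pre_hooks[k]
        if isinstance(hook, pytorch_prune.BasePruningMethod):
            hook.remove(module)
            del module._forward_pre_hooks[k]

# The reparametrization is gone; `weight` is an ordinary (masked) parameter.
assert not hasattr(model[0], "weight_orig")
assert "weight_mask" not in dict(model[0].named_buffers())
```

Unlike `pytorch_prune.remove(module, name)`, which raises a `ValueError` for modules that were never pruned, the hook walk simply skips them, which is why the `try/except ValueError` from the old implementation is no longer needed.
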
58 changes: 52 additions & 6 deletions tests/callbacks/test_pruning.py
@@ -15,7 +15,6 @@
import platform
from collections import OrderedDict
from logging import INFO
from unittest import mock

import pytest
import torch
@@ -24,7 +23,7 @@
from torch.nn import Sequential

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelPruning
from pytorch_lightning.callbacks import ModelCheckpoint, ModelPruning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel

@@ -42,6 +41,10 @@ def __init__(self):
])
)

def training_step(self, batch, batch_idx):
self.log("test", -batch_idx)
return super().training_step(batch, batch_idx)


class TestPruningMethod(pytorch_prune.BasePruningMethod):
PRUNING_TYPE = "unstructured"
@@ -219,7 +222,6 @@ def apply_lottery_ticket_hypothesis(self):


@pytest.mark.parametrize("make_pruning_permanent", (False, True))
@mock.patch.dict(os.environ, {}, clear=True)
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
seed_everything(0)
model = TestModel()
@@ -244,8 +246,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
with caplog.at_level(INFO):
trainer.fit(model)

actual = [m.strip() for m in caplog.messages[-9:]]
expected = [
actual = [m.strip() for m in caplog.messages]
actual = [m for m in actual if m.startswith("Applied")]
assert actual == [
"Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501
@@ -256,11 +259,54 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501
]
assert actual == expected

filepath = str(tmpdir / "foo.ckpt")
trainer.save_checkpoint(filepath)

model.load_from_checkpoint(filepath, strict=False)
has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
assert not has_pruning if make_pruning_permanent else has_pruning


def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
"""
When a model is saved multiple times and ``make_pruning_permanent=True``, a copy
must be pruned rather than the trained model so that training can continue
with the same pruning buffers.
"""
seed_everything(0)

class TestPruning(ModelPruning):

def on_save_checkpoint(self, trainer, pl_module, checkpoint):
super().on_save_checkpoint(trainer, pl_module, checkpoint)
assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
assert hasattr(pl_module.layer.mlp_3, "weight_orig")

model = TestModel()
pruning_callback = TestPruning(
"random_unstructured",
parameters_to_prune=[(model.layer.mlp_3, "weight")],
verbose=1,
make_pruning_permanent=True
)
ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
trainer = Trainer(callbacks=[pruning_callback, ckpt_callback], max_epochs=3, progress_bar_refresh_rate=0)
with caplog.at_level(INFO):
trainer.fit(model)

actual = [m.strip() for m in caplog.messages]
actual = [m for m in actual if m.startswith("Applied")]
assert actual == [
"Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
"Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
"Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
]

# removed on_train_end
assert not hasattr(model.layer.mlp_3, "weight_orig")

model.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
assert not hasattr(model.layer.mlp_3, "weight_orig")
model.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
assert not hasattr(model.layer.mlp_3, "weight_orig")
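
The assertions above pin down the new contract: the serialized `state_dict` holds plain, already-pruned weights, while the in-memory module keeps its `weight_orig`/`weight_mask` reparametrization until `on_train_end`. A small sketch of how one might inspect a saved checkpoint by hand (the checkpoint path is a placeholder; `layer.mlp_3.weight` matches the `TestModel` used in these tests):

```python
import torch

ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"]

# With the fix, pruning is made permanent on a copy before serialization,
# so no pruning reparametrization keys should end up in the checkpoint.
assert not any(k.endswith(("_orig", "_mask")) for k in state_dict)

# The saved weights still carry the zeros introduced by pruning.
weight = state_dict["layer.mlp_3.weight"]
print(f"sparsity of layer.mlp_3.weight: {(weight == 0).float().mean().item():.2%}")
```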
