Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Layerwise Sparsity Support for SparseGPT #1777

Merged
merged 12 commits into from
Oct 26, 2023
26 changes: 25 additions & 1 deletion src/sparseml/modifiers/obcq/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class SparseGPTModifier(Modifier):
model.decoder for OPT or just model for Llama
"""

sparsity: float
sparsity: Union[float, List[float]]
block_size: int
quantize: bool
dampening_frac: Optional[float] = 0.01
Expand All @@ -61,3 +61,27 @@ class SparseGPTModifier(Modifier):

def on_initialize_structure(self, state: "State", **kwargs):
    """Handle the structure-initialization lifecycle event.

    SparseGPT requires no structural changes to the model before
    compression, so this hook is deliberately a no-op.

    :param state: session state (unused by this modifier)
    """

def _validate_layerwise_sparisity(self):
if isinstance(self.sparsity, float):
return # single sparsity will be applied to all layers

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {self.targets}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

for layer_name in self.targets:
if layer_name.startswith("re:"):
Satrat marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)
25 changes: 23 additions & 2 deletions src/sparseml/modifiers/obcq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ def compressible_layers(self) -> List[Module]:
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)

# Compare length to sparsities again in case one of the provided layers was
# invalid or not compressible
if isinstance(self.sparsity, List) and len(self.sparsity) != len(
compressible_dict
):
raise ValueError(
"Number of compressible layers must match the number of "
f"sparsities. Got {len(compressible_dict)} layers and "
f"{len(self.sparsity)} sparsities"
)

return [v for _, v in compressible_dict.items()]

def on_initialize(self, state: "State", **kwargs) -> bool:
Expand All @@ -65,6 +77,8 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

:param state: session state storing input model and calibration data
"""
self._validate_layerwise_sparisity()

self.finalization_kwargs_ = {}
modifiable_model = state.model
calibration_dataloader = state.data.calib
Expand Down Expand Up @@ -125,10 +139,17 @@ def apply_obcq(
"The 'outputs' key is expected but not found from the "
"return of the bottom compressor"
)

inputs = accum_kwargs["outputs"]
_LOGGER.info(f"\n===== Compressing layer {idx}/{num_layers-1} =====")
layer_sparsity = (
self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
)
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)
args = {
"sparsity": self.sparsity,
"sparsity": layer_sparsity,
"prunen": self.prunen,
"prunem": self.prunem,
"blocksize": self.block_size,
Expand Down
8 changes: 6 additions & 2 deletions tests/sparseml/modifiers/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ def setup_modifier_factory():


class LifecyleTestingHarness:
def __init__(self, model=None, optimizer=None, framework=Framework.pytorch):
def __init__(
self, model=None, optimizer=None, framework=Framework.pytorch, device="cpu"
):
self.state = State(framework=framework)
self.state.update(model=model, optimizer=optimizer, start=0, steps_per_epoch=1)
self.state.update(
model=model, device=device, optimizer=optimizer, start=0, steps_per_epoch=1
)

self.event_lifecycle = CallbacksEventLifecycle(
type_first=EventType.BATCH_START, start=self.state.start_event
Expand Down
63 changes: 63 additions & 0 deletions tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sparseml.core.framework import Framework
from sparseml.core.model import ModifiableModel
from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import LinearNet


@pytest.mark.parametrize(
    "sparsity,targets",
    [
        ([0.5, 0.2], "__ALL__"),  # type mismatch
        ([0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]),  # length mismatch
        ([0.3, 0.4], ["re:.*fc1", "re:.*fc2"]),  # regex not supported
    ],
)
def test_invalid_layerwise_recipes_raise_exceptions(sparsity, targets):
    """Malformed layer-wise sparsity recipes must fail during initialize()."""
    setup_modifier_factory()

    modifier = SparseGPTModifierPyTorch(
        sparsity=sparsity,
        block_size=128,
        quantize=False,
        targets=targets,
    )
    harness = LifecyleTestingHarness(model=LinearNet())

    # confirm invalid layerwise recipes fail at initialization
    with pytest.raises(ValueError):
        modifier.initialize(harness.get_state())


def test_successful_layerwise_recipe():
    """A well-formed layer-wise recipe validates and resolves all model layers."""
    setup_modifier_factory()

    targets = ["seq.fc1", "seq.fc2"]
    modifier = SparseGPTModifierPyTorch(
        sparsity=[0.5, 0.2], block_size=128, quantize=False, targets=targets
    )
    modifier._validate_layerwise_sparisity()

    modifier.model = ModifiableModel(framework=Framework.pytorch, model=LinearNet())
    compressible = modifier.compressible_layers()

    # ensure layers names successfully match up with model
    assert len(compressible) == len(targets)
Loading