Layerwise Sparsity Support for SparseGPT (#1777)
Satrat committed Oct 26, 2023
1 parent 916657c commit e20927f
Showing 5 changed files with 97 additions and 48 deletions.
31 changes: 27 additions & 4 deletions src/sparseml/modifiers/obcq/base.py
@@ -53,7 +53,7 @@ class SparseGPTModifier(Modifier):
model.decoder for OPT or just model for Llama
"""

sparsity: float
sparsity: Union[float, List[float]]
block_size: int
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
@@ -63,7 +63,6 @@ class SparseGPTModifier(Modifier):
targets: Union[str, List[str], None] = ALL_TOKEN
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None

compressible_layers_: List = None
quantization_modifier_: Any = None

@@ -76,7 +75,7 @@ def compressible_layers(self) -> List:
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def pre_initialize_structure(self, state: State, **kwargs):
def on_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
@@ -123,7 +122,7 @@ def pre_initialize_structure(self, state: State, **kwargs):
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.pre_initialize_structure(state, **kwargs)
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
@@ -135,3 +134,27 @@ def _build_quant_modifier_from_dict(self, quant_config, framework):
allow_experimental=True,
**modifier_args,
)

def _validate_layerwise_sparisity(self):
if isinstance(self.sparsity, float):
return # single sparsity will be applied to all layers

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {self.targets}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

for layer_name in self.targets:
if layer_name.startswith("re:"):
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)
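
For orientation, a minimal sketch (not part of this commit) of how the list form of sparsity is expected to line up with explicit layer targets; the layer names below are hypothetical and the real checks live in _validate_layerwise_sparisity above:

targets = ["model.layers.0", "model.layers.1", "model.layers.2"]  # hypothetical layer names
sparsity = [0.5, 0.6, 0.7]  # one value per target, index-aligned

assert isinstance(targets, list), "layer-wise sparsity needs explicit layer targets"
assert len(targets) == len(sparsity), "one sparsity value per targeted layer"
assert not any(name.startswith("re:") for name in targets), "regex targets are rejected"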
15 changes: 12 additions & 3 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -54,8 +54,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
:param state: session state storing input model and calibration data
"""
self._validate_layerwise_sparisity()

if not self.initialized_structure_:
self.pre_initialize_structure(state, **kwargs)
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
self.finalization_kwargs_ = {}
@@ -118,10 +120,17 @@ def apply_obcq(
"The 'outputs' key is expected but not found from the "
"return of the bottom compressor"
)

inputs = accum_kwargs["outputs"]
_LOGGER.info(f"\n===== Compressing layer {idx}/{num_layers-1} =====")
layer_sparsity = (
self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
)
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)
args = {
"sparsity": self.sparsity,
"sparsity": layer_sparsity,
"prunen": self.prunen,
"prunem": self.prunem,
"blocksize": self.block_size,
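
A standalone sketch (illustrative names, not the commit's code) of the per-layer selection used in apply_obcq above: a list sparsity is indexed by layer position, while a single float is reused for every layer.

sparsity = [0.5, 0.6, 0.7]  # or a single float such as 0.5
layer_names = ["model.layers.0", "model.layers.1", "model.layers.2"]  # hypothetical
for idx, name in enumerate(layer_names):
    layer_sparsity = sparsity[idx] if isinstance(sparsity, list) else sparsity
    print(f"===== Compressing layer {idx + 1}/{len(layer_names)} to sparsity {layer_sparsity} =====")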
39 changes: 3 additions & 36 deletions src/sparseml/modifiers/obcq/utils/layer_compressor.py
@@ -17,11 +17,11 @@
from typing import Dict, List

import torch
import torch.nn as nn
from torch.nn import Module

from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT
from sparseml.pytorch.utils.helpers import get_dependency_order
from sparseml.utils.pytorch.module import get_prunable_layers


_LOGGER = logging.getLogger(__name__)
@@ -60,14 +60,8 @@ def compressible_modules(self) -> Dict:
:return: dictionary of compressible modules
"""
quantize = self.args.get("quantize", False)
if quantize:
# The layer names are changed due to quantization modifiers, therefore
# we need a slightly different func to retrieve layers
modules = _find_quant_layers(self.layer)
else:
modules = _find_layers(self.layer)
return modules
compressible_layers = get_prunable_layers(self.layer)
return compressible_layers

def pre_compress_parallel(self, **kwargs) -> Dict:
"""
@@ -217,30 +211,3 @@ def tmp(_, inp, out):
blocksize=self.args["blocksize"],
)
gpts.free()


def _find_quant_layers(
module, layers=[torch.nn.qat.Conv2d, torch.nn.qat.Linear], name=""
):
res = {}
# search for QAT versions of layers
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res


def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res
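
The removed local helpers are replaced by sparseml's shared get_prunable_layers utility. As a rough mental model only (an assumption, not the library's actual implementation), such a lookup walks the module tree and maps dotted names to weight-bearing layers:

import torch.nn as nn

def find_prunable_layers(module: nn.Module) -> dict:
    # Illustrative stand-in: collect dotted names of Linear/Conv submodules.
    prunable_types = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)
    return {
        name: sub
        for name, sub in module.named_modules()
        if isinstance(sub, prunable_types)
    }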
8 changes: 6 additions & 2 deletions tests/sparseml/modifiers/conf.py
@@ -25,9 +25,13 @@ def setup_modifier_factory():


class LifecyleTestingHarness:
def __init__(self, model=None, optimizer=None, framework=Framework.pytorch):
def __init__(
self, model=None, optimizer=None, framework=Framework.pytorch, device="cpu"
):
self.state = State(framework=framework)
self.state.update(model=model, optimizer=optimizer, start=0, steps_per_epoch=1)
self.state.update(
model=model, device=device, optimizer=optimizer, start=0, steps_per_epoch=1
)

self.event_lifecycle = CallbacksEventLifecycle(
type_first=EventType.BATCH_START, start=self.state.start_event
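
Minimal usage sketch of the updated harness signature (the device string is just an example value):

from tests.sparseml.modifiers.conf import LifecyleTestingHarness
from tests.sparseml.pytorch.helpers import LinearNet

harness = LifecyleTestingHarness(model=LinearNet(), device="cpu")
state = harness.get_state()  # the state update above now records the device alongside the model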
52 changes: 49 additions & 3 deletions tests/sparseml/pytorch/modifiers/obcq/test_pytorch.py
@@ -12,13 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

from sparseml.core.framework import Framework
from sparseml.core.model import ModifiableModel
from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
from sparseml.modifiers.quantization import QuantizationModifier
from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
from tests.sparseml.pytorch.helpers import LinearNet


@pytest.mark.parametrize(
"sparsity,targets",
[
([0.5, 0.2], "__ALL__"), # type mismatch
([0.2, 0.1, 0.3], ["seq.fc1", "seq.fc2"]), # length mismatch
([0.3, 0.4], ["re:.*fc1", "re:.*fc2"]), # regex not supported
],
)
def test_invalid_layerwise_recipes_raise_exceptions(sparsity, targets):
setup_modifier_factory()
model = LinearNet()

kwargs = dict(
sparsity=sparsity,
block_size=128,
quantize=False,
targets=targets,
)
modifier = SparseGPTModifierPyTorch(**kwargs)
testing_harness = LifecyleTestingHarness(model=model)

# confirm invalid layerwise recipes fail at initialization
with pytest.raises(ValueError):
modifier.initialize(testing_harness.get_state())


def test_successful_layerwise_recipe():
setup_modifier_factory()
model = LinearNet()

sparsities = [0.5, 0.2]
targets = ["seq.fc1", "seq.fc2"]
kwargs = dict(sparsity=sparsities, block_size=128, quantize=False, targets=targets)
modifier = SparseGPTModifierPyTorch(**kwargs)
modifier._validate_layerwise_sparisity()
modifier.model = ModifiableModel(framework=Framework.pytorch, model=model)
found_compressible_layers = modifier.compressible_layers()

    # ensure layer names successfully match up with the model
assert len(found_compressible_layers) == len(targets)
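
The test above targets submodules named "seq.fc1" and "seq.fc2"; for context, a toy module with that naming layout could look like the following (an assumption about LinearNet's shape, not its actual definition in tests.sparseml.pytorch.helpers):

import torch.nn as nn

class TinyLinearNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Named children under "seq" give the dotted names "seq.fc1" and "seq.fc2".
        self.seq = nn.Sequential()
        self.seq.add_module("fc1", nn.Linear(8, 16))
        self.seq.add_module("fc2", nn.Linear(16, 8))

    def forward(self, x):
        return self.seq(x)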


def test_create_default_quant_modifier():
setup_modifier_factory()
kwargs = dict(sparsity=0.5, block_size=128, quantize=True)
@@ -27,7 +73,7 @@ def test_create_default_quant_modifier():
assert modifier.quantization_modifier_ is None

testing_harness = LifecyleTestingHarness(model=LinearNet())
modifier.pre_initialize_structure(testing_harness.get_state())
modifier.on_initialize_structure(testing_harness.get_state())
assert modifier.quantize
assert isinstance(modifier.quantization_modifier_, QuantizationModifier)

@@ -59,7 +105,7 @@ def test_set_quant_if_modifer_already_exists():
kwargs = dict(sparsity=0.5, block_size=128, quantize=False)
modifier = SparseGPTModifierPyTorch(**kwargs)
assert not modifier.quantize
modifier.pre_initialize_structure(testing_harness.get_state())
modifier.on_initialize_structure(testing_harness.get_state())

# quantization modifier not owned by SparseGPT
assert modifier.quantization_modifier_ is None
@@ -95,7 +141,7 @@ def test_set_quant_in_sparsegpt():
assert modifier.quantization_modifier_ is None

testing_harness = LifecyleTestingHarness(model=LinearNet())
modifier.pre_initialize_structure(testing_harness.get_state())
modifier.on_initialize_structure(testing_harness.get_state())
assert modifier.quantize
assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
