Commit
Merge branch 'main' into layer-wise-sparsity
Satrat committed Oct 26, 2023
2 parents 3b3fb2f + 972529c commit 1c30bd8
Showing 9 changed files with 236 additions and 77 deletions.
8 changes: 8 additions & 0 deletions src/sparseml/core/model/base.py
@@ -116,3 +116,11 @@ def set_param(self, target: str, param: PT):
:param param: the param instance to set
"""
raise NotImplementedError()

def qat_active(self) -> bool:
"""
Checks if quantization-aware training is set up in the model
:return: True if QAT is active in any layer, False otherwise
"""
raise NotImplementedError()
9 changes: 9 additions & 0 deletions src/sparseml/core/model/pytorch.py
@@ -24,6 +24,7 @@
get_layers_params,
get_param,
get_params,
qat_active,
set_layer,
set_param,
)
@@ -94,3 +95,11 @@ def set_param(self, target: str, param: Parameter):
:param param: the parameter to set
"""
return set_param(target, param, self.model)

def qat_active(self) -> bool:
"""
Checks if quantization-aware training is set up in the model
:return: True if QAT is active in any layer, False otherwise
"""
return qat_active(self.model)
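
For context, downstream modifiers consume this hook through the session state's wrapped model. A minimal sketch (the `state` object is assumed to be an initialized session state, as used in obcq/base.py below):

# Sketch: querying QAT status from within a modifier (names from the diffs below)
if state.model.qat_active():
    ...  # a quantization modifier has already inserted FakeQuantize modules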
85 changes: 79 additions & 6 deletions src/sparseml/modifiers/obcq/base.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Union
import logging
from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(Modifier):
"""
@@ -34,7 +37,9 @@ class SparseGPTModifier(Modifier):
:param sparsity: Sparsity to compress model to
:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Whether or not model is quantized (affects layer names)
:param quantize: Whether or not to quantize weights during SparseGPT. Set to True
to quantize using an existing quantization modifier, or pass in the configuration
for a quantization modifier if one does not already exist in the recipe
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param sequential_update: Whether or not to update weights sequentially by layer,
@@ -50,17 +55,85 @@ class SparseGPTModifier(Modifier):

sparsity: Union[float, List[float]]
block_size: int
quantize: bool
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
prunen: Optional[int] = 0
prunem: Optional[int] = 0
targets: Union[str, List[str], None] = ALL_TOKEN
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None
compressible_layers_: List = None
quantization_modifier_: Any = None

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of compressible layer names
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to False, but a "
"quantization modifier is already active on the model "
"resetting quantize to True"
)
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, dict):
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"type {type(self.quantize)}"
)
if len(self.quantize) != 1:
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"{len(self.quantize)} modifiers"
)
if quantization_already_active:
_LOGGER.warning(
"Attempting to initialize quantization for SparseGPT "
"but a quantization modifier has already been applied. "
"The quantization configuration defined under the "
"SparseGPT modifier will be ignored."
)
self.quantize = True
return
self._build_quant_modifier_from_dict(self.quantize, state.framework)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def on_initialize_structure(self, state: "State", **kwargs):
pass # nothing needed for this modifier
def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
modifier_type,
framework=framework,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
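To summarize the new configuration surface, an illustrative sketch (argument values are examples only) of the three forms `quantize` now accepts:

# Illustrative sketch: the three accepted forms of SparseGPTModifier.quantize
SparseGPTModifier(sparsity=0.5, block_size=128, quantize=False)  # prune only
SparseGPTModifier(sparsity=0.5, block_size=128, quantize=True)  # reuse a quantization modifier already in the recipe
SparseGPTModifier(sparsity=0.5, block_size=128, quantize={"QuantizationModifier": {"post_oneshot_calibration": True}})  # build one from an inline config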
31 changes: 7 additions & 24 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -17,7 +17,6 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import Module

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
@@ -46,31 +45,9 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
"""

model: Any = None
compressible_layers_: List = None
device_: str = "cuda:0"
finalization_kwargs_: Dict = None

def compressible_layers(self) -> List[Module]:
"""
Retrieves the modules corresponding to a list of compressible layer names
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)

# Compare length to sparsities again in case one of the provided layers was
# invalid or not compressible
if isinstance(self.sparsity, List) and len(self.sparsity) != len(
compressible_dict
):
raise ValueError(
"Number of compressible layers must match the number of "
f"sparsities. Got {len(compressible_dict)} layers and "
f"{len(self.sparsity)} sparsities"
)

return [v for _, v in compressible_dict.items()]

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
@@ -79,6 +56,10 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
"""
self._validate_layerwise_sparsity()

if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
self.finalization_kwargs_ = {}
modifiable_model = state.model
calibration_dataloader = state.data.calib
@@ -172,9 +153,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool:
:param state: un-used, for matching spec of Modifier base class
"""
use_cache = self.finalization_kwargs_.get("use_cache", False)
self.model.apply(torch.quantization.disable_observer)
self.model.config.use_cache = use_cache

if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)

return True

def compress_bottom(
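In short, the PyTorch modifier now delegates the quantization lifecycle to the child modifier built (or detected) during structure initialization. A comment-only sketch of the flow, as wired in the hooks above:

# on_initialize -> quantization_modifier_.initialize(state, **kwargs)
# on_finalize   -> model.apply(torch.quantization.disable_observer)
#                  quantization_modifier_.finalize(state, **kwargs)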
39 changes: 3 additions & 36 deletions src/sparseml/modifiers/obcq/utils/layer_compressor.py
@@ -17,11 +17,11 @@
from typing import Dict, List

import torch
import torch.nn as nn
from torch.nn import Module

from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT
from sparseml.pytorch.utils.helpers import get_dependency_order
from sparseml.utils.pytorch.module import get_prunable_layers


_LOGGER = logging.getLogger(__name__)
@@ -60,14 +60,8 @@ def compressible_modules(self) -> Dict:
:return: dictionary of compressible modules
"""
quantize = self.args.get("quantize", False)
if quantize:
# The layer names are changed due to quantization modifiers, therefore
# we need a slightly different func to retrieve layers
modules = _find_quant_layers(self.layer)
else:
modules = _find_layers(self.layer)
return modules
compressible_layers = get_prunable_layers(self.layer)
return compressible_layers

def pre_compress_parallel(self, **kwargs) -> Dict:
"""
@@ -217,30 +211,3 @@ def tmp(_, inp, out):
blocksize=self.args["blocksize"],
)
gpts.free()


def _find_quant_layers(
module, layers=[torch.nn.qat.Conv2d, torch.nn.qat.Linear], name=""
):
res = {}
# search for QAT versions of layers
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res


def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res
22 changes: 11 additions & 11 deletions src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,20 +1,20 @@
test_stage:
obcq_modifiers:
QuantizationModifier:
ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
post_oneshot_calibration: True
scheme_overrides:
ReLU:
input_activations: null
output_activations: null
LayerNorm:
input_activations: null
output_activations: null
SparseGPTModifier:
sparsity: 0.5
block_size: 128
sequential_update: False
quantize: True
quantize:
QuantizationModifier:
ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
post_oneshot_calibration: True
scheme_overrides:
ReLU:
input_activations: null
output_activations: null
LayerNorm:
input_activations: null
output_activations: null
percdamp: 0.01
prunen: 0
prunem: 0
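For recipe authors: the QuantizationModifier stanza that previously sat at the top level of obcq_modifiers now nests under the SparseGPTModifier's quantize key. When parsed, it reaches _build_quant_modifier_from_dict as a plain mapping; a sketch with values abbreviated from the recipe above:

# Illustrative: what the nested quantize: block parses to
quantize = {
    "QuantizationModifier": {
        "ignore": ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding"],  # abbreviated
        "post_oneshot_calibration": True,
        "scheme_overrides": {"LayerNorm": {"input_activations": None, "output_activations": None}},
    }
}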
16 changes: 16 additions & 0 deletions src/sparseml/utils/pytorch/module.py
@@ -65,6 +65,7 @@
"get_terminal_layers",
"get_prunable_layers",
"get_quantizable_layers",
"qat_active",
"get_layers_params",
]

@@ -241,6 +242,21 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]:
return quantizable


def qat_active(module: Module) -> bool:
"""
Determines if any layers in the model have quantization enabled by checking for
FakeQuantize modules
:param module: PyTorch model to check for quantization
:return: True if quantization is active anywhere in the model, False otherwise
"""
for _, layer in module.named_modules():
if isinstance(layer, torch.quantization.FakeQuantize):
return True

return False


def get_layers_params(
targets: Union[str, List[str]], module: Module
) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]:
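A minimal, self-contained sketch of exercising the new helper with standard eager-mode QAT preparation (the toy model is illustrative):

# Sketch: qat_active flips to True once prepare_qat inserts FakeQuantize modules
import torch

from sparseml.utils.pytorch.module import qat_active

model = torch.nn.Sequential(torch.nn.Linear(8, 8)).train()
assert qat_active(model) is False

model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)
assert qat_active(model) is True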
13 changes: 13 additions & 0 deletions tests/sparseml/pytorch/modifiers/obcq/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.