[Feature Branch] Quant modifier UX #2263

Merged · 8 commits · May 22, 2024
@@ -5,7 +5,6 @@ initial_sparsity_stage:
  sparsity: 0.5
  block_size: 128
  sequential_update: False
- quantize: False
  percdamp: 0.01
  mask_structure: "0:0"
  targets: [
@@ -24,7 +23,6 @@ next_sparsity_stage:
  sparsity: 0.7
  block_size: 128
  sequential_update: False
- quantize: False
  percdamp: 0.01
  mask_structure: "0:0"
  targets: [
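
Both sparsity stages drop the standalone quantize flag from the recipe; toggling quantization is no longer part of the SparseGPT stage itself. For orientation, below is a minimal sketch of building an equivalent stage programmatically. The field names follow the SparseGPTModifier definition in the diff below, and the ModifierFactory.create call mirrors the removed _build_quant_modifier_from_dict helper; the Framework import path, the "SparseGPTModifier" registry key, and the target name are assumptions rather than values taken from this PR.

# Hedged sketch (not from this PR): building a SparseGPT stage programmatically.
# Field names follow the SparseGPTModifier definition shown below; the factory
# call mirrors the removed _build_quant_modifier_from_dict helper.
from sparseml.core.factory import ModifierFactory
from sparseml.core.framework import Framework  # assumed import path

sparsity_modifier = ModifierFactory.create(
    "SparseGPTModifier",           # assumed registry key
    framework=Framework.pytorch,   # assumed enum member
    allow_registered=True,
    allow_experimental=True,
    sparsity=0.5,
    block_size=128,
    sequential_update=False,
    mask_structure="0:0",
    dampening_frac=0.01,
    targets=["model.layers.0"],    # hypothetical target layer
)
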
src/sparseml/modifiers/obcq/base.py: 154 changes (79 additions, 75 deletions)
@@ -12,20 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict, List, Optional, Union
from typing import Dict, List, Optional, Union

from sparseml.core.factory import ModifierFactory
from sparseml.core import Modifier
from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(WandaPruningModifier):
class SparseGPTModifier(Modifier):
"""
Modifier for applying the one-shot OBCQ algorithm to a model

@@ -41,84 +38,91 @@ class SparseGPTModifier(WandaPruningModifier):
- on_finalize
- LayerCompressor.revert_layer_wrappers()

:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Whether or not to quantize weights during SparseGPT. Set to
True to quantize using an existing quantization modifier, or pass in the
configuration for a quantization modifier if one does not already exist
in the recipe
:param sparsity: Sparsity to compress model to
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
Layerwise Sparsity (OWL), more information can be found
in the paper https://arxiv.org/pdf/2310.05175
:param owl_m: Number of outliers to use for OWL
:param owl_lmbda: Lambda value to use for OWL
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param sequential_update: Whether or not to update weights sequentially by layer,
True saves on GPU memory
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param preserve_sparsity_mask: Whether or not to preserve the sparsity mask
when applying SparseGPT; this becomes useful when starting from a
previously pruned model, defaults to False.
"""

block_size: int = 128
quantize: Union[bool, Dict] = False
sparsity: Union[float, List[float]] = 0.0
sparsity_profile: Optional[str] = None
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None
mask_structure: str = "0:0"
sequential_update: Optional[bool] = False
targets: Union[str, List[str], None] = None
block_size: int = 128
dampening_frac: Optional[float] = 0.01
quantization_modifier_: Any = None
preserve_sparsity_mask: bool = False
prunen_: Optional[int] = None
prunem_: Optional[int] = None
compressible_layers_: Optional[List] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
Check the model's quantization state matches that expected by this modifier,
adding a default quantization scheme if needed
Initialize the structure of the model for compression.
This modifier does not modify the model structure, so this method
is a no-op.

:param state: session state storing input model and calibration data
"""
return True

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
compressible layer names

:precondition: self.model is set and is a `ModifiableModel`
:precondition: The `ModifiableModel` implements a `get_layers`
method
:return: dictionary of modules to compress
"""
if not isinstance(self.model, ModifiableModel):
raise ValueError(
"`self.model` must be a ModifiableModel to use "
f"the {self.__class__.__qualname__} modifier but got "
f"{type(self.model)} instead"
)

return self.model.get_layers(self.targets)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
# single sparsity will be applied to all layers
return

target_layers = list(self.compressible_layers_.keys())

if len(target_layers) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(target_layers)} layers and "
f"{len(self.sparsity)} sparsities"
)

def on_finalize(self, state: State, **kwargs):
"""
Nothing to do on finalize at this level; the quantization modifier,
if any, will be finalized in the subclass

:param state: session state storing input model and calibration data
:param kwargs: additional arguments
:return: True
"""
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to False, but a "
"quantization modifier is already active on the model "
"resetting quantize to True"
)
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, Dict):
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"type {type(self.quantize)}"
)
if len(self.quantize) != 1:
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"{len(self.quantize)} modifiers"
)
if quantization_already_active:
_LOGGER.warning(
"Attempting to initialize quantization for SparseGPT "
"but a quantization modifier has already been applied. "
"The quantization configuration defined under the "
"SparseGPT modifier will be ignored."
)
self.quantize = True
return
self._build_quant_modifier_from_dict(self.quantize, state.framework)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
modifier_type,
framework=framework,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)
return True
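
The _validate_layerwise_sparsity helper above enforces that a list-valued sparsity supplies exactly one value per resolved target layer. Below is a standalone sketch of that check, assuming only standard Python; the layer dictionary stands in for the resolved compressible_layers_ mapping, and the layer names are purely illustrative.

# Standalone sketch of the check performed by _validate_layerwise_sparsity:
# a list-valued sparsity must supply exactly one value per resolved target layer.
from typing import Dict, List, Union


def check_layerwise_sparsity(
    compressible_layers: Dict[str, object],
    sparsity: Union[float, List[float]],
) -> None:
    if isinstance(sparsity, float):
        # a single sparsity value is applied uniformly to every layer
        return

    target_layers = list(compressible_layers.keys())
    if len(target_layers) != len(sparsity):
        raise ValueError(
            "Number of layer targets must match the number of sparsities. "
            f"Got {len(target_layers)} layers and {len(sparsity)} sparsities"
        )


# Illustrative usage: three resolved layers paired with three sparsities passes,
# a uniform float always passes, and a mismatched list raises ValueError.
layers = {"model.layers.0": None, "model.layers.1": None, "model.layers.2": None}
check_layerwise_sparsity(layers, [0.5, 0.6, 0.7])  # ok
check_layerwise_sparsity(layers, 0.5)              # ok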