[Wanda Refactor] Wanda/OBCQ Modifier Refactor (#1887)
* Define GPT contract

* rename tmp -> batch_size

* Define LayerCompressor Contract

* Rename gpt_helpers to gpts
Fix some docstrings

* add named argument to function call

* Wanda/OBCQ refactor

* propagate target-ids

* Address review comments from
* #1885
* #1886
rahul-tuli authored Dec 18, 2023
1 parent 1d1aaca commit 5f24ff9
Showing 10 changed files with 414 additions and 561 deletions.
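For orientation, the sketch below summarizes the class relationships this refactor introduces. It is a minimal, self-contained approximation using stub stand-ins for the SparseML base classes and illustrative default values, not the library's actual definitions: `SparseGPTModifier` now inherits the shared pruning fields from `WandaPruningModifier`, and `SparseGPTModifierPyTorch` reuses the Wanda PyTorch flow while layering OBCQ-specific compression arguments on top via `_get_compression_args`.

```python
# Hypothetical stand-ins for the real SparseML classes, kept self-contained so
# the sketch runs on its own; the actual bases live under
# sparseml.modifiers.pruning.wanda.{base,pytorch}. Default values are illustrative.
from typing import Any, Dict, List, Optional, Union


class WandaPruningModifier:
    """Holds the pruning fields shared by Wanda and SparseGPT after the refactor."""

    sparsity: Union[float, List[float]] = 0.5
    mask_structure: str = "0:0"
    targets: Union[str, List[str], None] = "__ALL__"


class WandaPruningModifierPyTorch(WandaPruningModifier):
    """Stub of the PyTorch Wanda implementation providing the shared hooks."""

    layer_compressor_class_ = None  # subclasses swap in their own layer compressor

    def _get_compression_args(self, layer_sparsity: float) -> Dict[str, Any]:
        # Base arguments every layer compressor receives.
        return {"sparsity": layer_sparsity}


class SparseGPTModifier(WandaPruningModifier):
    """OBCQ-specific configuration layered on top of the shared Wanda fields."""

    block_size: int = 128
    quantize: Union[bool, Dict] = False
    dampening_frac: Optional[float] = 0.01
    sequential_update: Optional[bool] = True


class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
    """Reuses the Wanda PyTorch flow, extending the per-layer compression args."""

    def _get_compression_args(self, layer_sparsity: float) -> Dict[str, Any]:
        return {
            **super()._get_compression_args(layer_sparsity=layer_sparsity),
            "blocksize": self.block_size,
            "percdamp": self.dampening_frac,
            "sequential_update": self.sequential_update,
            "quantize": self.quantize,
        }


print([cls.__name__ for cls in SparseGPTModifierPyTorch.__mro__])
print(SparseGPTModifierPyTorch()._get_compression_args(layer_sparsity=0.5))
```

The cooperative multiple inheritance means the OBCQ modifier only overrides the pieces that differ (compression arguments and the layer compressor class), while the pruning lifecycle itself is driven by the Wanda base.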
42 changes: 2 additions & 40 deletions src/sparseml/modifiers/obcq/base.py
@@ -15,18 +15,17 @@
import logging
from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(Modifier):
class SparseGPTModifier(WandaPruningModifier):
"""
Modifier for applying the one-shot OBCQ algorithm to a model
@@ -54,18 +53,14 @@ class SparseGPTModifier(Modifier):
has been deprecated and will be removed in a future release
"""

sparsity: Union[float, List[float]]
block_size: int
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
mask_structure: str = "0:0"
prunen_: Optional[int] = None
prunem_: Optional[int] = None
targets: Union[str, List[str], None] = ALL_TOKEN
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None
compressible_layers_: Optional[List] = None
quantization_modifier_: Any = None

def __post_init__(self):
@@ -75,15 +70,6 @@ def __post_init__(self):
"removed in a future release"
)

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of compressible layer names
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
@@ -143,27 +129,3 @@ def _build_quant_modifier_from_dict(self, quant_config, framework):
allow_experimental=True,
**modifier_args,
)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
return # single sparsity will be applied to all layers

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {self.targets}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

for layer_name in self.targets:
if layer_name.startswith("re:"):
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)
166 changes: 21 additions & 145 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -13,30 +13,27 @@
# limitations under the License.


import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from functools import partial
from typing import Any, Optional

import torch

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.obcq.base import SparseGPTModifier
from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs
from sparseml.modifiers.obcq.utils.layer_compressor import LayerCompressor
from sparseml.modifiers.obcq.utils.layer_compressor import OBCQLayerCompressor
from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch


_LOGGER = logging.getLogger(__name__)
__all__ = ["SparseGPTModifierPyTorch"]


class SparseGPTModifierPyTorch(SparseGPTModifier):
class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
"""
Pytorch implementation of SparseGPT
Lifecycle:
- on_initialize
- initialize_obcq
- setup
- compressible_layers
- apply_obcq
- prune
- compress_bottom
- LayerCompressor.compress
- on_finalize
Expand All @@ -47,6 +44,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
model: Any = None
device_: str = "cuda:0"
layer_prefix_: Optional[str] = None
layer_compressor_class_ = OBCQLayerCompressor

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
@@ -60,152 +58,30 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
modifiable_model = state.model
calibration_dataloader = state.data.calib
device = state.hardware.device

self.initialize_obcq(modifiable_model, device)
self.apply_obcq(calibration_dataloader)

return True
# attach target_ids to `compress_bottom` for OBCQ
# this must be done before calling super().on_initialize

def initialize_obcq(
self,
model: "ModifiableModel",
device: Optional[str] = "cuda:0",
):
"""
Setup for SparseGPT, initialize the compressible layers of model, and set
the device
self.compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids)
return super().on_initialize(state=state, **kwargs)

:param model: PyTorch model to sparsify
:param device: device to run sparsification on, preferably a GPU
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.layer_prefix_ = model.layer_prefix
self.model = self.model.model
self._set_device(device)
self._infer_mask_block_size()

@torch.no_grad()
def apply_obcq(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
"""
Run OBCQ on the loaded model, using dataloader as calibration data
:param dataloader: calibration data for OBCQ
"""
accum_kwargs = {"dataloader": dataloader}

# Step 0: Pass the calibration data through the (compressed) bottom part of the
# network, capturing the outputs which will become the inputs to the first
# decoder layer. Also return attention_mask as part of kwargs
extras = self.compress_bottom(
dev=self.device_,
target_ids=self.target_ids,
layer_prefix=self.layer_prefix_,
**accum_kwargs,
)
accum_kwargs.update(extras)

# Step 1: Sequentially prune/quantize decoder layers
inputs = None
num_layers = len(self.compressible_layers_)
for idx, layer in enumerate(self.compressible_layers_):
if "outputs" not in accum_kwargs:
raise RuntimeError(
"The 'outputs' key is expected but not found from the "
"return of the bottom compressor"
)

inputs = accum_kwargs["outputs"]
layer_sparsity = (
self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
)
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)
args = {
"sparsity": layer_sparsity,
"prunen": self.prunen_,
"prunem": self.prunem_,
def _get_compression_args(self, layer_sparsity):
return {
**super()._get_compression_args(layer_sparsity=layer_sparsity),
**{
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
"sequential_update": self.sequential_update,
"quantize": self.quantize,
}
layer_compressor = LayerCompressor(self.model, layer, idx, inputs, args)
},
}

# Prune/quantize using SparseGPT
layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs)
accum_kwargs.update(layer_kwargs)

def on_finalize(self, state: "State", **kwargs) -> bool:
def on_finalize(self, state: State, **kwargs) -> bool:
"""
disable the observers used by the OBCQ algorithm and set kv-cache configuration
:param state: un-used, for matching spec of Modifier base class
"""

if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)

return True

def compress_bottom(
self,
dataloader: List = None,
nsamples: int = None,
dev: str = "cuda:0",
target_ids: List[str] = None,
layer_prefix: Optional[str] = None,
) -> Dict:
"""
Runs calibration data through the bottom part of the network (everything up
to the first decoder layer) and return the captured outputs
:param dataloader: calibration data to pass through the model
:param nsamples: number of samples to use for calibration, or None to use it all
:param dev: device to use
:param target_ids: list of keys in model output to cache, NOTE: this argument
has been deprecated and will be removed in a future release
:param layer_prefix: name of model attribute that contains the list of layers,
i.e. model.decoder for OPT or just model for Llama
:return: outputs from bottom part of network, attention mask, and kv-cache state
"""
layer_prefix = layer_prefix or self.layer_prefix_
cached_inputs = cache_attention_inputs(
model=self.model,
dataloader=dataloader,
device=dev,
nsamples=nsamples,
target_ids=target_ids,
layer_prefix=layer_prefix,
)

outputs = cached_inputs.pop("inputs")
outputs = [o[0] for o in outputs]
cached_inputs.update({"outputs": outputs})
return cached_inputs

def _set_device(self, device: str):
if "cuda" in device and not torch.cuda.is_available():
self.device_ = "cpu"
else:
self.device_ = device

def _infer_mask_block_size(self):
"""
Infer the mask block size from the mask structure.
Parses mask_structure of the form N:M where N, M are integers that
define a custom block shape; and sets prunen_ and prunem_ accordingly.
:post-condition: prunen_ and prunem_ are set
"""
if self.mask_structure is None:
raise ValueError("mask_structure must be defined")

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
return super().on_finalize(state, **kwargs)
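The `target_ids` propagation in `on_initialize` relies on `functools.partial` to pre-bind a keyword argument onto the bound `compress_bottom` method before delegating to the parent's `on_initialize`, so the shared Wanda flow can keep calling `self.compress_bottom(...)` unchanged. Below is a minimal, self-contained sketch of that pattern; the class bodies and the `"attention_mask"` value are hypothetical stand-ins, not the real modifiers.

```python
from functools import partial
from typing import List, Optional


class Parent:
    """Stand-in for WandaPruningModifierPyTorch: drives the generic flow and
    calls compress_bottom without knowing about OBCQ's target_ids argument."""

    def on_initialize(self) -> bool:
        self.compress_bottom(dev="cpu")
        return True

    def compress_bottom(self, dev: str = "cpu", **kwargs) -> dict:
        print(f"compress_bottom(dev={dev}, kwargs={kwargs})")
        return {}


class Child(Parent):
    """Stand-in for SparseGPTModifierPyTorch: pre-binds target_ids so the
    parent's call site stays unchanged."""

    target_ids: Optional[List[str]] = ["attention_mask"]  # hypothetical value

    def on_initialize(self) -> bool:
        # Rebind the bound method with target_ids already filled in; the
        # parent continues to call self.compress_bottom(...) as before.
        self.compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids)
        return super().on_initialize()


Child().on_initialize()
# -> compress_bottom(dev=cpu, kwargs={'target_ids': ['attention_mask']})
```

Pre-binding on the instance keeps the parent class oblivious to the extra keyword argument while still routing it through to the bottom-compression step.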