Wanda #1834

Merged: 17 commits, Dec 28, 2023
42 changes: 2 additions & 40 deletions src/sparseml/modifiers/obcq/base.py
@@ -15,18 +15,17 @@
import logging
from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(Modifier):
class SparseGPTModifier(WandaPruningModifier):
"""
Modifier for applying the one-shot OBCQ algorithm to a model

@@ -54,18 +53,14 @@ class SparseGPTModifier(Modifier):
has been deprecated and will be removed in a future release
"""

sparsity: Union[float, List[float]]
block_size: int
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
mask_structure: str = "0:0"
prunen_: Optional[int] = None
prunem_: Optional[int] = None
targets: Union[str, List[str], None] = ALL_TOKEN
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None
compressible_layers_: Optional[List] = None
quantization_modifier_: Any = None

def __post_init__(self):
@@ -75,15 +70,6 @@ def __post_init__(self):
"removed in a future release"
)

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of compressible layer names

:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
@@ -143,27 +129,3 @@ def _build_quant_modifier_from_dict(self, quant_config, framework):
allow_experimental=True,
**modifier_args,
)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
return # single sparsity will be applied to all layers

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {self.targets}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

for layer_name in self.targets:
if layer_name.startswith("re:"):
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)
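
Side note on the class relationships this refactor creates: `SparseGPTModifier` now derives from `WandaPruningModifier`, and (in the file below) `SparseGPTModifierPyTorch` combines `WandaPruningModifierPyTorch` with this base. A minimal sketch of the resulting diamond and how Python's MRO resolves it, using only the class names visible in the diff with placeholder bodies:

```python
# Hedged sketch of the hierarchy after this PR; the real classes live in
# sparseml.modifiers.pruning.wanda and sparseml.modifiers.obcq.
class WandaPruningModifier: ...
class WandaPruningModifierPyTorch(WandaPruningModifier): ...

class SparseGPTModifier(WandaPruningModifier): ...
class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier): ...

# C3 linearization: the Wanda PyTorch implementation is searched before the
# SparseGPT base, and the shared Wanda base comes last before object.
assert [c.__name__ for c in SparseGPTModifierPyTorch.__mro__] == [
    "SparseGPTModifierPyTorch",
    "WandaPruningModifierPyTorch",
    "SparseGPTModifier",
    "WandaPruningModifier",
    "object",
]
```
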
167 changes: 23 additions & 144 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -13,30 +13,27 @@
# limitations under the License.


import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from functools import partial
from typing import Any, Optional

import torch

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.obcq.base import SparseGPTModifier
from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs
from sparseml.modifiers.obcq.utils.layer_compressor import LayerCompressor
from sparseml.modifiers.obcq.utils.layer_compressor import OBCQLayerCompressor
from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch


_LOGGER = logging.getLogger(__name__)
__all__ = ["SparseGPTModifierPyTorch"]


class SparseGPTModifierPyTorch(SparseGPTModifier):
class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
"""
Pytorch implementation of SparseGPT

Lifecycle:
- on_initialize
- initialize_obcq
- setup
- compressible_layers
- apply_obcq
- prune
- compress_bottom
- LayerCompressor.compress
- on_finalize
@@ -47,6 +44,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
model: Any = None
device_: str = "cuda:0"
layer_prefix_: Optional[str] = None
layer_compressor_class_ = OBCQLayerCompressor

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
@@ -60,152 +58,33 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
modifiable_model = state.model
calibration_dataloader = state.data.calib
device = state.hardware.device

self.initialize_obcq(modifiable_model, device)
self.apply_obcq(calibration_dataloader)

return True
# attach target_ids to `compress_bottom` for OBCQ
# this must be done before calling super().on_initialize

def initialize_obcq(
self,
model: "ModifiableModel",
device: Optional[str] = "cuda:0",
):
"""
Setup for SparseGPT, initialize the compressible layers of the model, and set
the device
compress_bottom = partial(self.compress_bottom, target_ids=self.target_ids)

:param model: PyTorch model to sparsify
:param device: device to run sparsification on, preferably a GPU
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.layer_prefix_ = model.layer_prefix
self.model = self.model.model
self._set_device(device)
self._infer_mask_block_size()

@torch.no_grad()
def apply_obcq(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
"""
Run OBCQ on the loaded model, using dataloader as calibration data
# we need setattr here because of Pydantic's internal data model
object.__setattr__(self, "compress_bottom", compress_bottom)
return super().on_initialize(state=state, **kwargs)

:param dataloader: calibration data for OBCQ
"""
accum_kwargs = {"dataloader": dataloader}

# Step 0: Pass the calibration data through the (compressed) bottom part of the
# network, capturing the outputs which will become the inputs to the first
# decoder layer. Also return attention_mask as part of kwargs
extras = self.compress_bottom(
dev=self.device_,
target_ids=self.target_ids,
layer_prefix=self.layer_prefix_,
**accum_kwargs,
)
accum_kwargs.update(extras)

# Step 1: Sequentially prune/quantize decoder layers
inputs = None
num_layers = len(self.compressible_layers_)
for idx, layer in enumerate(self.compressible_layers_):
if "outputs" not in accum_kwargs:
raise RuntimeError(
"The 'outputs' key is expected but not found from the "
"return of the bottom compressor"
)

inputs = accum_kwargs["outputs"]
layer_sparsity = (
self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
)
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)
args = {
"sparsity": layer_sparsity,
"prunen": self.prunen_,
"prunem": self.prunem_,
def _get_compression_args(self, layer_sparsity):
return {
**super()._get_compression_args(layer_sparsity=layer_sparsity),
**{
"blocksize": self.block_size,
"percdamp": self.dampening_frac,
"sequential_update": self.sequential_update,
"quantize": self.quantize,
}
layer_compressor = LayerCompressor(self.model, layer, idx, inputs, args)
},
}

# Prune/quantize using SparseGPT
layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs)
accum_kwargs.update(layer_kwargs)
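
Aside on the `functools.partial` / `object.__setattr__` pattern in the new `on_initialize` above: the deprecated `target_ids` argument is baked into `compress_bottom` before delegating to the Wanda implementation, and the rebinding has to bypass Pydantic's attribute handling. A self-contained sketch of that pattern, with illustrative class and field names rather than the actual SparseML types:

```python
from functools import partial

from pydantic import BaseModel


class Compressor(BaseModel):
    # illustrative field, standing in for the modifier's deprecated target_ids
    target_ids: list = ["attention_mask"]

    def compress_bottom(self, dataloader=None, target_ids=None):
        return {"dataloader": dataloader, "target_ids": target_ids}


c = Compressor()
bound = partial(c.compress_bottom, target_ids=c.target_ids)

# Plain `c.compress_bottom = bound` is typically rejected by Pydantic's
# __setattr__ (no such field), so bypass it the same way the diff does:
object.__setattr__(c, "compress_bottom", bound)

print(c.compress_bottom(dataloader="calibration data"))
# -> {'dataloader': 'calibration data', 'target_ids': ['attention_mask']}
```
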

def on_finalize(self, state: "State", **kwargs) -> bool:
def on_finalize(self, state: State, **kwargs) -> bool:
"""
disable the observers used by the OBCQ algorithm and set kv-cache configuration

:param state: unused, for matching spec of Modifier base class
"""

if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)

return True

def compress_bottom(
self,
dataloader: List = None,
nsamples: int = None,
dev: str = "cuda:0",
target_ids: List[str] = None,
layer_prefix: Optional[str] = None,
) -> Dict:
"""
Runs calibration data through the bottom part of the network (everything up
to the first decoder layer) and return the captured outputs

:param dataloader: calibration data to pass through the model
:param nsamples: number of samples to use for calibration, or None to use it all
:param dev: device to use
:param target_ids: list of keys in model output to cache, NOTE: this argument
has been deprecated and will be removed in a future release
:param layer_prefix: name of model attribute that contains the list of layers,
i.e. model.decoder for OPT or just model for Llama
:return: outputs from bottom part of network, attention mask, and kv-cache state
"""
layer_prefix = layer_prefix or self.layer_prefix_
cached_inputs = cache_attention_inputs(
model=self.model,
dataloader=dataloader,
device=dev,
nsamples=nsamples,
target_ids=target_ids,
layer_prefix=layer_prefix,
)

outputs = cached_inputs.pop("inputs")
outputs = [o[0] for o in outputs]
cached_inputs.update({"outputs": outputs})
return cached_inputs

def _set_device(self, device: str):
if "cuda" in device and not torch.cuda.is_available():
self.device_ = "cpu"
else:
self.device_ = device

def _infer_mask_block_size(self):
"""
Infer the mask block size from the mask structure.
Parses mask_structure of the form N:M where N, M are integers that
define a custom block shape; and sets prunen_ and prunem_ accordingly.

:post-condition: prunen_ and prunem_ are set
"""
if self.mask_structure is None:
raise ValueError("mask_structure must be defined")

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
return super().on_finalize(state, **kwargs)
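
Finally, `_infer_mask_block_size` and the other helpers deleted here are presumably inherited from the Wanda PyTorch modifier rather than dropped. Since the `N:M` convention in `mask_structure` is easy to misread, a minimal standalone illustration of the same parsing and what the values mean:

```python
# Minimal sketch of the "N:M" mask_structure convention used by this modifier.
# "0:0" (the default) means unstructured sparsity; "2:4" prunes 2 weights out
# of every group of 4 (semi-structured 2:4 sparsity).
def infer_mask_block_size(mask_structure: str) -> tuple[int, int]:
    if mask_structure is None:
        raise ValueError("mask_structure must be defined")
    prunen, prunem = map(int, mask_structure.split(":"))
    return prunen, prunem


assert infer_mask_block_size("0:0") == (0, 0)
assert infer_mask_block_size("2:4") == (2, 4)
```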