Refactor to use WandaLayerCompressor
Update WrappedGPT
rahul-tuli committed Nov 16, 2023
1 parent a850ee6 commit 0ad14bc
Showing 6 changed files with 453 additions and 181 deletions.
3 changes: 2 additions & 1 deletion src/sparseml/modifiers/pruning/wanda/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import List, Union
+from typing import List, Optional, Union

 from sparseml.core import Modifier
 from sparseml.core.model.base import ModifiableModel
@@ -45,6 +45,7 @@ class WandaPruningModifier(Modifier):
     sparsity: Union[float, List[float]]
     mask_structure: str = "0:0"
     targets: Union[str, List[str], None] = ALL_TOKEN
+    compressible_layers_: Optional[List] = None

     def on_initialize_structure(self, state: State, **kwargs):
         """
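The new compressible_layers_ field caches the list of prunable decoder layers that the PyTorch implementation below resolves during initialization. For orientation, here is a minimal sketch of constructing the modifier directly with the fields defined above; treat it as hypothetical usage, since in practice these values are populated from a SparseML recipe.

from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier

# Hypothetical direct construction; the fields mirror the class definition above.
modifier = WandaPruningModifier(
    sparsity=0.5,          # a single float, or one float per compressible layer
    mask_structure="2:4",  # "N:M" semi-structured sparsity; "0:0" = unstructured
)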
227 changes: 141 additions & 86 deletions src/sparseml/modifiers/pruning/wanda/pytorch.py
@@ -13,16 +13,15 @@
 # limitations under the License.

 import logging
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 import torch

 from sparseml.core.model.base import ModifiableModel
 from sparseml.core.state import State
+from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs
 from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier
-from sparseml.modifiers.pruning.wanda.utils.helpers import (
-    find_layers,
-    prepare_calibration_input,
-)
-from sparseml.modifiers.pruning.wanda.utils.wrapped_gpt import WrappedGPT
+from sparseml.modifiers.pruning.wanda.utils.layer_compressor import WandaLayerCompressor


 _LOGGER = logging.getLogger(__name__)
@@ -33,92 +32,148 @@ class WandaPruningModifierPyTorch(WandaPruningModifier):
     PyTorch implementation of WandaPruningModifier
     """

+    model: Optional[ModifiableModel] = None
+    device_: str = "cuda:0"
+    layer_prefix_: Optional[str] = None
+    prunen_: Optional[int] = None
+    prunem_: Optional[int] = None

     def on_initialize(self, state: State, **kwargs) -> bool:
-        modifiable_model = state.model
-        pytorch_model = modifiable_model.model
-        use_cache = pytorch_model.config.use_cache
-
-        # set use_cache to False to avoid OOM
-        pytorch_model.config.use_cache = False
-
-        _LOGGER.info("Preparing calibration data")
-        calibration_dataloader = state.data.calib
-        device = state.hardware.device
-        pytorch_model.to(device)
-        with torch.no_grad():
-            inps, outs, attention_mask, position_ids = prepare_calibration_input(
-                pytorch_model, calibration_dataloader, device
-            )
-        layers = pytorch_model.model.layers
-        for i in range(len(layers)):
-            layer = layers[i]
-            subset = find_layers(layer)
-            wrapped_layers = {}
-            for name in subset:
-                wrapped_layers[name] = WrappedGPT(subset[name])
-
-            def add_batch(name):
-                def tmp(_, inp, out):
-                    wrapped_layers[name].add_batch(inp[0].data, out.data)
-
-                return tmp
-
-            handles = []
-            for name in wrapped_layers:
-                handles.append(subset[name].register_forward_hook(add_batch(name)))
-            for j in range(len(calibration_dataloader)):
-                with torch.no_grad():
-                    outs[j] = layer(
-                        inps[j].unsqueeze(0),
-                        attention_mask=attention_mask,
-                        position_ids=position_ids,
-                    )[0]
-            for h in handles:
-                h.remove()
-            if self.mask_structure == "unstructured":
-                prune_n = prune_m = 0
-            else:
-                prune_n, prune_m = tuple(map(int, self.mask_structure.split(":")))
-
-            for name in subset:
-                _LOGGER.info(f"pruning layer {i} name {name}")
-                W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(
-                    wrapped_layers[name].scaler_row.reshape((1, -1))
-                )
-
-                W_mask = (
-                    torch.zeros_like(W_metric) == 1
-                )  # initialize a mask to be all False
-                if prune_n != 0:
-                    # structured n:m sparsity
-                    for ii in range(W_metric.shape[1]):
-                        if ii % prune_m == 0:
-                            tmp = W_metric[:, ii : (ii + prune_m)].float()
-                            W_mask.scatter_(
-                                1,
-                                ii + torch.topk(tmp, prune_n, dim=1, largest=False)[1],
-                                True,
-                            )
-                else:
-                    sort_res = torch.sort(W_metric, dim=-1, stable=True)
-                    indices = sort_res[1][:, : int(W_metric.shape[1] * self.sparsity)]
-                    W_mask.scatter_(1, indices, True)
-
-                subset[name].weight.data[W_mask] = 0  # set weights to zero
-
-            for j in range(len(calibration_dataloader)):
-                with torch.no_grad():
-                    outs[j] = layer(
-                        inps[j].unsqueeze(0),
-                        attention_mask=attention_mask,
-                        position_ids=position_ids,
-                    )[0]
-            inps, outs = outs, inps
-
-        pytorch_model.config.use_cache = use_cache
+        """
+        Initialize and run the WANDA algorithm on the current state
+
+        :param state: session state storing input model and calibration data
+        """
+        self._validate_layerwise_sparsity()
+
+        self.initialize_wanda(state, **kwargs)
+
+        # run wanda on calibration data
+        self.apply_wanda(dataloader=state.data.calib)
         torch.cuda.empty_cache()
         return True

+    def initialize_wanda(self, state: State, **kwargs):
+        """
+        Setup for WANDA: initializes the model, device, and other parameters,
+        initializes the compressible layers of the model, and sets the device
+
+        :param state: session state storing input model and calibration data
+        """
+        self.model = state.model
+        self.compressible_layers_ = self.compressible_layers()
+        self._set_device(device=state.hardware.device)
+        self.layer_prefix_ = self.model.layer_prefix
+        self._infer_mask_block_size()
+
+    @torch.no_grad()
+    def apply_wanda(
+        self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
+    ):
+        """
+        Run WANDA on the loaded model, using dataloader as calibration data
+
+        :param dataloader: calibration data for WANDA
+        """
+        accum_kwargs = {"dataloader": dataloader}
+        pytorch_model = self.model.model
+
+        # Step 0: Pass the calibration data through the (compressed) bottom part
+        # of the network, capturing the outputs which will become the inputs to
+        # the first decoder layer. Also return attention_mask as part of kwargs.
+        extras = self.compress_bottom(
+            dev=self.device_,
+            layer_prefix=self.layer_prefix_,
+            **accum_kwargs,
+        )
+        accum_kwargs.update(extras)
+
+        # Step 1: Sequentially prune the decoder layers
+        num_layers = len(self.compressible_layers_)
+        for idx, layer in enumerate(self.compressible_layers_):
+            if "outputs" not in accum_kwargs:
+                raise RuntimeError(
+                    "The 'outputs' key is expected but not found from the "
+                    "return of the bottom compressor"
+                )
+
+            inputs = accum_kwargs["outputs"]
+            layer_sparsity = (
+                self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
+            )
+            _LOGGER.info(
+                f"\n===== Compressing layer {idx+1}/{num_layers} "
+                f"to sparsity {layer_sparsity} ====="
+            )
+            args = {
+                "sparsity": layer_sparsity,
+                "prunen": self.prunen_,
+                "prunem": self.prunem_,
+            }
+            # Prune using the WandaLayerCompressor
+            layer_compressor = WandaLayerCompressor(
+                model=pytorch_model,
+                layer=layer,
+                layer_index=idx,
+                inputs=inputs,
+                args=args,
+            )
+            layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs)
+            accum_kwargs.update(layer_kwargs)
+
+    def compress_bottom(
+        self,
+        dataloader: Optional[List] = None,
+        nsamples: Optional[int] = None,
+        dev: str = "cuda:0",
+        layer_prefix: Optional[str] = None,
+    ) -> Dict:
+        """
+        Runs calibration data through the bottom part of the network (everything
+        up to the first decoder layer) and returns the captured outputs
+
+        :param dataloader: calibration data to pass through the model
+        :param nsamples: number of samples to use for calibration, or None to
+            use all of them
+        :param dev: device to use
+        :param layer_prefix: name of model attribute that contains the list of
+            layers, i.e. model.decoder for OPT or just model for Llama
+        :return: outputs from bottom part of network, attention mask, and
+            kv-cache state
+        """
+        layer_prefix = layer_prefix or self.layer_prefix_
+        cached_inputs = cache_attention_inputs(
+            model=self.model.model,
+            dataloader=dataloader,
+            device=dev,
+            nsamples=nsamples,
+            target_ids=None,
+            layer_prefix=layer_prefix,
+        )
+
+        outputs = cached_inputs.pop("inputs")
+        outputs = [o[0] for o in outputs]
+        cached_inputs.update({"outputs": outputs})
+        return cached_inputs
+
+    def on_finalize(self, state: State, **kwargs):
+        return True
+
+    def _set_device(self, device: str):
+        if "cuda" in device and not torch.cuda.is_available():
+            self.device_ = "cpu"
+        else:
+            self.device_ = device
+
+    def _infer_mask_block_size(self):
+        """
+        Infer the mask block size from the mask structure.
+        Parses mask_structure of the form N:M, where N and M are integers that
+        define a custom block shape, and sets prunen_ and prunem_ accordingly.
+
+        :post-condition: prunen_ and prunem_ are set
+        """
+        if self.mask_structure is None:
+            raise ValueError("mask_structure must be defined")
+
+        self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
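The removed loop above is the most explicit statement of the Wanda criterion in this diff: each weight is scored as |W_ij| * sqrt(scaler_row_j), where scaler_row holds the per-input-column activation norms accumulated by WrappedGPT during the calibration passes, and the lowest-scoring weights are zeroed, row-wise for unstructured sparsity or within each group of prunem_ columns for N:M sparsity. Below is a self-contained toy sketch of that scoring and masking step; weight and scaler_row are stand-ins for subset[name].weight.data and WrappedGPT's statistics, so nothing here depends on sparseml.

import torch

# Stand-in tensors: `weight` plays subset[name].weight.data, `scaler_row` the
# per-input-column activation norms accumulated during calibration.
weight = torch.randn(8, 16)
scaler_row = torch.rand(16)

# Wanda score: |W_ij| * sqrt(||X_j||^2), broadcast across output rows
W_metric = weight.abs() * torch.sqrt(scaler_row.reshape(1, -1))

# 2:4 semi-structured sparsity, i.e. mask_structure = "2:4"
prune_n, prune_m = 2, 4
W_mask = torch.zeros_like(W_metric, dtype=torch.bool)
for ii in range(0, W_metric.shape[1], prune_m):
    block = W_metric[:, ii : ii + prune_m].float()
    # mark the prune_n lowest-scoring weights in every group of prune_m columns
    W_mask.scatter_(
        1, ii + torch.topk(block, prune_n, dim=1, largest=False)[1], True
    )

weight[W_mask] = 0  # zero the masked weights in place
assert (weight == 0).sum().item() == weight.numel() * prune_n // prune_m

After the refactor this logic lives behind WandaLayerCompressor.compress, and the modifier itself only orchestrates calibration and the layer-by-layer sweep.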
2 changes: 0 additions & 2 deletions src/sparseml/modifiers/pruning/wanda/utils/__init__.py
@@ -13,5 +13,3 @@
 # limitations under the License.

 # flake8: noqa
-from .helpers import *
-from .wrapped_gpt import *
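With the wildcard re-exports removed, callers import the utilities from their concrete modules instead; the refactored pytorch.py above does exactly that:

from sparseml.modifiers.pruning.wanda.utils.layer_compressor import WandaLayerCompressor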
83 changes: 0 additions & 83 deletions src/sparseml/modifiers/pruning/wanda/utils/helpers.py

This file was deleted.
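The deleted find_layers and prepare_calibration_input helpers are superseded by the new WandaLayerCompressor in utils/layer_compressor.py, presumably one of the two changed files not rendered in this view. Since that class is not shown here, the following is only a rough interface sketch inferred from its call sites in apply_wanda, with deliberately simplified internals; it is not the actual implementation.

from typing import Any, Dict, List

import torch


class WandaLayerCompressor:
    """Interface sketch inferred from apply_wanda(); illustrative only."""

    def __init__(
        self, model, layer, layer_index: int, inputs: List, args: Dict[str, Any]
    ):
        self.model = model  # the full pytorch model being pruned
        self.layer = layer  # the single decoder layer to compress
        self.layer_index = layer_index
        self.inputs = inputs  # inputs captured for this layer
        self.args = args  # {"sparsity": ..., "prunen": ..., "prunem": ...}

    @torch.no_grad()
    def compress(self, dev: str = "cuda:0", **kwargs) -> Dict[str, Any]:
        # The real compressor wraps each prunable module to accumulate
        # activation norms, replays self.inputs through the layer, and applies
        # the Wanda mask. The replay below is simplified and drops extras such
        # as attention_mask and position_ids.
        outputs = [self.layer(x.unsqueeze(0).to(dev))[0] for x in self.inputs]
        return {"outputs": outputs}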

