
Wanda #1834

Merged: 17 commits, Dec 28, 2023
Changes from 9 commits
13 changes: 13 additions & 0 deletions src/sparseml/modifiers/pruning/wanda/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
101 changes: 101 additions & 0 deletions src/sparseml/modifiers/pruning/wanda/base.py
@@ -0,0 +1,101 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN


__all__ = ["WandaPruningModifier"]


class WandaPruningModifier(Modifier):
"""
Modifier for applying the one-shot WANDA algorithm to a model
from the paper: https://arxiv.org/abs/2306.11695

Life-cycle:
- initialize
- compress
- finalize

:param sparsity: Sparsity to compress model to
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param targets: list of layer names to compress during WANDA, or '__ALL__'
to compress every layer in the model
"""

sparsity: Union[float, List[float]]
mask_structure: str = "0:0"
targets: Union[str, List[str], None] = ALL_TOKEN
compressible_layers_: Optional[List] = None

def on_initialize_structure(self, state: State, **kwargs):
"""
This modifier does not alter the model structure.
This method is a no-op.

:param state: Unused, kept to conform to the parent method signature
:param kwargs: Unused, kept to conform to the parent method signature
"""

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of
compressible layer names

:precondition: self.model is set and is a `ModifiableModel`
:precondition: The `ModifiableModel` implements a `get_layers`
method
:return: list of modules to compress
"""
if not isinstance(self.model, ModifiableModel):
raise ValueError(
"`self.model` must be a ModifiableModel to use "
f"the WANDA modifier but got {type(self.model)} instead"
)

compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
# single sparsity will be applied to all layers
return

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {type(self.targets)}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

for layer_name in self.targets:
if layer_name.startswith("re:"):
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)
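As a quick illustration of how the fields above interact, the sketch below configures layer-wise sparsity. The layer names are hypothetical and the direct keyword construction is for illustration only; in practice the modifier is normally built from a recipe.

from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier

# Illustrative only: one sparsity value per explicitly named target layer.
# mask_structure keeps its default "0:0", i.e. an unstructured mask.
modifier = WandaPruningModifier(
    sparsity=[0.5, 0.7],
    targets=["model.layers.0", "model.layers.1"],  # hypothetical layer names
)

# passes: the two lists line up one-to-one and no target uses a "re:" pattern,
# which _validate_layerwise_sparsity rejects for layer-wise sparsity
modifier._validate_layerwise_sparsity()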
179 changes: 179 additions & 0 deletions src/sparseml/modifiers/pruning/wanda/pytorch.py
@@ -0,0 +1,179 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch

from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.obcq.utils.helpers import cache_attention_inputs
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier
from sparseml.modifiers.pruning.wanda.utils.layer_compressor import WandaLayerCompressor


_LOGGER = logging.getLogger(__name__)


class WandaPruningModifierPyTorch(WandaPruningModifier):
"""
PyTorch implementation of WandaPruningModifier
"""

model: Optional[ModifiableModel] = None
device_: str = "cuda:0"
layer_prefix_: Optional[str] = None
prunen_: Optional[int] = None
prunem_: Optional[int] = None

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run the WANDA algorithm on the current state

:param state: session state storing input model and calibration data
:return: True on a successful run
"""
self._validate_layerwise_sparsity()

self.initialize_wanda(state, **kwargs)

# run wanda on calibration data
self.apply_wanda(dataloader=state.data.calib)
torch.cuda.empty_cache()
return True

def initialize_wanda(self, state: State, **kwargs):
"""
Setup for WANDA: sets the model, device, layer prefix, and mask block
size, and collects the compressible layers of the model

:param state: session state storing input model and calibration data
"""
self.model = state.model
self.compressible_layers_ = self.compressible_layers()
self._set_device(device=state.hardware.device)
self.layer_prefix_ = self.model.layer_prefix
self._infer_mask_block_size()

@torch.no_grad()
def apply_wanda(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
"""
Run WANDA on the loaded model, using dataloader as calibration data

:param dataloader: calibration data for WANDA
"""
accum_kwargs = {"dataloader": dataloader}
pytorch_model = self.model.model

# Step 0: Pass the calibration data through the (compressed) bottom part of the
# network, capturing the outputs which will become the inputs to the first
# decoder layer. Also return attention_mask as part of kwargs
extras = self.compress_bottom(
dev=self.device_,
layer_prefix=self.layer_prefix_,
**accum_kwargs,
)
accum_kwargs.update(extras)

# Step 1: Sequentially prune decoder layers
inputs = None
num_layers = len(self.compressible_layers_)
for idx, layer in enumerate(self.compressible_layers_):
if "outputs" not in accum_kwargs:
raise RuntimeError(
"The 'outputs' key is expected but not found from the "
"return of the bottom compressor"
)

inputs = accum_kwargs["outputs"]
layer_sparsity = (
self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
)
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)
args = {
"sparsity": layer_sparsity,
"prunen": self.prunen_,
"prunem": self.prunem_,
}
# Prune the layer using the WANDA layer compressor
layer_compressor = WandaLayerCompressor(
model=pytorch_model,
layer=layer,
layer_index=idx,
inputs=inputs,
args=args,
)
layer_kwargs = layer_compressor.compress(dev=self.device_, **accum_kwargs)
accum_kwargs.update(layer_kwargs)

def compress_bottom(
self,
dataloader: Optional[List] = None,
nsamples: Optional[int] = None,
dev: str = "cuda:0",
layer_prefix: Optional[str] = None,
) -> Dict:
"""
Runs calibration data through the bottom part of the network (everything up
to the first decoder layer) and returns the captured outputs

:param dataloader: calibration data to pass through the model
:param nsamples: number of samples to use for calibration, or None to use all of the calibration data
:param dev: device to use
:param layer_prefix: name of model attribute that contains the list of layers,
i.e. model.decoder for OPT or just model for Llama
:return: outputs from bottom part of network, attention mask, and kv-cache state
"""
layer_prefix = layer_prefix or self.layer_prefix_
cached_inputs = cache_attention_inputs(
model=self.model.model,
dataloader=dataloader,
device=dev,
nsamples=nsamples,
target_ids=None,
layer_prefix=layer_prefix,
)

outputs = cached_inputs.pop("inputs")
outputs = [o[0] for o in outputs]
cached_inputs.update({"outputs": outputs})
return cached_inputs

def on_finalize(self, state: State, **kwargs):
# nothing to clean up; weights are pruned in place during on_initialize
return True

def _set_device(self, device: str):
# fall back to CPU when CUDA is requested but unavailable
if "cuda" in device and not torch.cuda.is_available():
self.device_ = "cpu"
else:
self.device_ = device

def _infer_mask_block_size(self):
"""
Infer the mask block size from the mask structure.
Parses mask_structure of the form N:M where N, M are integers that
define a custom block shape; and sets prunen_ and prunem_ accordingly.

:post-condition: prunen_ and prunem_ are set
"""
if self.mask_structure is None:
raise ValueError("mask_structure must be defined")

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
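The per-layer pruning itself is delegated to WandaLayerCompressor, which is not part of this hunk. For reviewers, here is a minimal, self-contained sketch of the WANDA criterion from the cited paper (https://arxiv.org/abs/2306.11695): each weight is scored by its magnitude times the L2 norm of its input activation channel, and the lowest-scoring weights in every output row are zeroed. This is an independent illustration, not the code in wanda/utils/layer_compressor.py; the helper name and signature below are invented for the example.

import torch


def wanda_prune_linear(
    weight: torch.Tensor, act_norm: torch.Tensor, sparsity: float
) -> torch.Tensor:
    """
    Sketch of unstructured WANDA pruning for a Linear weight of shape
    (out_features, in_features).

    :param weight: the layer weight matrix
    :param act_norm: per-input-channel L2 norm of the calibration activations,
        shape (in_features,)
    :param sparsity: fraction of weights to zero in each output row
    :return: the weight with its lowest-scoring entries zeroed
    """
    # importance score from the paper: |W_ij| * ||X_j||_2
    scores = weight.abs() * act_norm.unsqueeze(0)

    # zero the lowest-scoring weights independently in each output row;
    # an N:M mask_structure would instead zero prunen_ weights in every
    # group of prunem_ consecutive input channels
    num_prune = int(weight.shape[1] * sparsity)
    pruned_idx = torch.argsort(scores, dim=1)[:, :num_prune]
    mask = torch.ones_like(weight)
    mask.scatter_(1, pruned_idx, 0.0)
    return weight * mask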
15 changes: 15 additions & 0 deletions src/sparseml/modifiers/pruning/wanda/utils/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa