SmoothQuant Modifier for OBCQ (#1758)
Satrat authored and bfineran committed Nov 16, 2023
1 parent cf81b17 commit 312f2f0
Showing 13 changed files with 630 additions and 34 deletions.
12 changes: 11 additions & 1 deletion src/sparseml/core/model/base.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, Generic, List, Optional, TypeVar, Union
from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union

from sparseml.core.framework import Framework
from sparseml.core.framework_object import MultiFrameworkObject
@@ -117,6 +117,16 @@ def set_param(self, target: str, param: PT):
"""
raise NotImplementedError()

def get_matching_layer(
self, target: str, name_to_match: str, model: LT
) -> Optional[Tuple[str, LT]]:
"""
:param target: regex layer name to target when searching model
:param name_to_match: name to match targets to
:param model: model to search for targets
"""
raise NotImplementedError()

def qat_active(self) -> bool:
"""
Checks if quantization aware training is set up in the model
11 changes: 11 additions & 0 deletions src/sparseml/core/model/pytorch.py
@@ -22,6 +22,7 @@
get_layer,
get_layers,
get_layers_params,
get_matching_layer,
get_param,
get_params,
qat_active,
@@ -96,6 +97,16 @@ def set_param(self, target: str, param: Parameter):
"""
return set_param(target, param, self.model)

def get_matching_layer(
self, target: str, name_to_match: str, model: Module
) -> Optional[Tuple[str, Module]]:
"""
:param target: regex layer name to target when searching model
:param name_to_match: name to match targets to
:param model: model to search for targets
"""
return get_matching_layer(target, name_to_match, model)

def qat_active(self) -> bool:
"""
Checks if quantization aware training is set up in the model
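The `get_matching_layer` helper delegated to here lives in the shared PyTorch utilities, which are not among the hunks expanded on this page. The sketch below is illustrative only, not the commit's implementation: it shows the behavior the callers rely on (of the submodules matching the regex target, return the one whose name shares the longest substring with `name_to_match`), assuming the "re:" prefix convention used in the recipe mappings and using `difflib` for the overlap measure.

```python
import difflib
import re
from typing import Optional, Tuple

from torch.nn import Module


def get_matching_layer(
    target: str, name_to_match: str, module: Module
) -> Optional[Tuple[str, Module]]:
    """Illustrative sketch only: pair a regex target with the closest-named submodule."""
    # targets use the "re:<pattern>" convention seen in the recipe mappings
    pattern = target[len("re:"):] if target.startswith("re:") else re.escape(target)

    best: Optional[Tuple[str, Module]] = None
    best_overlap = 0
    for name, submodule in module.named_modules():
        if not re.match(pattern, name):
            continue
        # prefer the candidate sharing the longest substring with name_to_match,
        # which keeps the match inside the same transformer block
        matcher = difflib.SequenceMatcher(None, name, name_to_match)
        overlap = matcher.find_longest_match(0, len(name), 0, len(name_to_match)).size
        if overlap > best_overlap:
            best = (name, submodule)
            best_overlap = overlap
    return best
```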
1 change: 1 addition & 0 deletions src/sparseml/modifiers/__init__.py
@@ -18,3 +18,4 @@
from .obcq import *
from .pruning import *
from .quantization import *
from .smoothquant import *
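With this wildcard export (and `smoothquant/__init__.py` re-exporting `base`, shown further down), the new class becomes importable from the package root under the same identifier a YAML recipe refers to; a trivial illustrative snippet:

```python
# the class name here is the same identifier a recipe lists as "SmoothQuantModifier"
from sparseml.modifiers import SmoothQuantModifier
```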
28 changes: 7 additions & 21 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -13,8 +13,7 @@
# limitations under the License.

import logging
from itertools import cycle
from typing import Any, Callable, Dict, Optional
from typing import Any, Dict, Optional

import torch
from torch.nn import Module
@@ -35,7 +34,7 @@
raise_if_torch_quantization_not_available,
set_quantization_schemes,
)
from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward


_LOGGER = logging.getLogger(__name__)
@@ -191,26 +190,13 @@ def _calibrate(self, module: Module):
module_training = module.training
module.eval()

forward_fn: Callable = (
self.calibration_function_
if self.calibration_function_
else tensors_module_forward
run_calibration_forward(
module,
self.calibration_dataloader_,
self.num_calibration_steps,
self.calibration_function_,
)

model_device = next(module.parameters()).device
_dataloader = (
self.calibration_dataloader_
if self.num_calibration_steps is None
else cycle(self.calibration_dataloader_)
)

for batch_idx, batch in enumerate(_dataloader):
if self.num_calibration_steps and batch_idx >= self.num_calibration_steps:
break
batch = tensors_to_device(batch, model_device)
with torch.no_grad():
forward_fn(batch, module=module)

if module_training:
module.train()
else:
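The inline calibration loop deleted from `_calibrate` above is consolidated into `run_calibration_forward` in `sparseml/modifiers/utils/pytorch_helpers.py`, which is not expanded on this page. The sketch below is reconstructed from the removed lines; the exact signature and defaults are assumptions.

```python
from itertools import cycle
from typing import Callable, Optional

import torch
from torch.nn import Module
from torch.utils.data import DataLoader

from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device


def run_calibration_forward(
    module: Module,
    calibration_dataloader: DataLoader,
    num_calibration_steps: Optional[int] = None,
    calibration_function: Optional[Callable] = None,
):
    """Sketch of the shared calibration loop; mirrors the code removed above."""
    # use the user-supplied forward function when given, otherwise the default
    forward_fn: Callable = calibration_function or tensors_module_forward
    model_device = next(module.parameters()).device

    # cycle the dataloader when a fixed step count is requested so short
    # datasets can still provide num_calibration_steps batches
    _dataloader = (
        calibration_dataloader
        if num_calibration_steps is None
        else cycle(calibration_dataloader)
    )

    for batch_idx, batch in enumerate(_dataloader):
        if num_calibration_steps and batch_idx >= num_calibration_steps:
            break
        batch = tensors_to_device(batch, model_device)
        with torch.no_grad():
            forward_fn(batch, module=module)
```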
17 changes: 17 additions & 0 deletions src/sparseml/modifiers/smoothquant/__init__.py
@@ -0,0 +1,17 @@
# flake8: noqa

# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import *
183 changes: 183 additions & 0 deletions src/sparseml/modifiers/smoothquant/base.py
@@ -0,0 +1,183 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from dataclasses import dataclass
from typing import Dict, Generic, List, Optional, Tuple, TypeVar

from pydantic import Field

from sparseml.core import Modifier
from sparseml.core.model import ModifiableModel
from sparseml.core.model.base import LT
from sparseml.core.state import Event, State


VT = TypeVar("VT") # represents a generic vector

__all__ = ["SmoothQuantScale", "SmoothQuantMapping", "SmoothQuantModifier"]


@dataclass
class SmoothQuantScale(Generic[VT]):
"""
Dataclass for storing the channel-wise minimum and maximum values for a layer.
This is updated on each forward pass during calibration.
:param min_channel_vals: minimum output value seen so far, per channel
:param max_channel_vals: maximum output value seen so far, per channel
"""

min_channel_vals: VT
max_channel_vals: VT


@dataclass
class SmoothQuantMapping(Generic[LT]):
"""
Dataclass for storing the mapping between an activation layer and the following
weights that must be balanced during smoothing
:param smooth_name: name of the activation layer
:param smooth_layer: PyTorch module storing the activation layer
:param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be
balanced to offset the smoothing of smooth_layer
"""

smooth_name: str
smooth_layer: LT
balance_layers: List[LT]


class SmoothQuantModifier(Modifier):
"""
Implements the SmoothQuant algorithm from https://arxiv.org/abs/2211.10438. This
modifier performs a channel-wise smoothing of outliers in activations, making them
easier to quantize by reducing the dynamic range. The smoothing is offset by
applying the inverse operation to the next layer of weights, making the weights
slightly more difficult to quantize.
Because this modifier manipulates the weights of the model, it can only be used
in one-shot and not during training. Activation ranges are determined by running a
small set of calibration data through the model.
example recipe:
```yaml
SmoothQuantModifier:
smoothing_strength: 0.5
mappings: [
[["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
[["re:.*fc1"], "re:.*final_layer_norm"]
]
ignore: ["model.decoder.final_layer_norm"]
```
:param smoothing_strength: alpha, intensity of smoothing to perform (0-1 range)
:param mappings: list of activation layers to smooth, and which layers to offset
the smoothing to for each activation
:param ignore: list of layers to ignore, even if they match a regex in mappings
:param num_calibration_steps: number of samples to use for calibration, or None to
use the whole dataset
"""

smoothing_strength: float = Field(validation_alias="alpha")
mappings: List[Tuple]
ignore: Optional[List[str]] = None
num_calibration_steps: Optional[int] = None

resolved_mappings_: Optional[List] = None
scales_: Optional[Dict] = None

def on_initialize_structure(self, state: State, **kwargs):
pass # nothing needed for this modifier

def on_initialize(self, state: State, **kwargs) -> bool:
"""
Initialize and run SmoothQuant on the given state
:param state: state to run SmoothQuant on
:return: True on a successful run, False otherwise
"""
if self.end and self.end != -1:
raise ValueError(
"SmoothQuantModifier can only be applied during one-shot. Expected end"
" to be None or -1, got {}".format(self.end)
)
if self.start and self.start != -1:
raise ValueError(
"SmoothQuantModifier can only be applied during one-shot. Expected "
"start to be None or -1, got {}".format(self.start)
)

self.ignore = [] if not self.ignore else self.ignore
self.resolved_mappings_ = self._resolve_mappings(state.model)
self.scales_ = {}

def _resolve_mappings(self, model: ModifiableModel) -> List:
"""
Transforms the list of activations to smooth and their corresponding weights
into SmoothQuantMapping objects, resolving regular expressions.
For each activation in the mapping list, we find the corresponding weight to
balance by searching for the longest substring. For instance, if our balance
weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we
would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and
repeat for model.layer.1 and so on
"""
resolved_mappings = []
for to_balance, to_smooth in self.mappings:
to_smooth_layers = model.get_layers(to_smooth)
for layer_name, smooth_layer in to_smooth_layers.items():
if layer_name not in self.ignore:
balance_layers = []
for balance_suffix in to_balance:
# find the submodule that matches the activation layer
_, balance_layer = model.get_matching_layer(
balance_suffix, layer_name, model.model
)
if balance_layer:
balance_layers.append(balance_layer)
# each mapping can contain multiple layers to balance, but only
# one layer to smooth
mapping = SmoothQuantMapping(
layer_name, smooth_layer, balance_layers
)
resolved_mappings.append(mapping)
return resolved_mappings

def on_start(self, state: State, event: Event, **kwargs):
pass

def on_update(self, state: State, event: Event, **kwargs):
pass

def on_end(self, state: State, event: Event, **kwargs):
pass

def on_event(self, state: State, event: Event, **kwargs):
pass

def on_finalize(self, state: State, **kwargs) -> bool:
"""
Clean up by clearing the scale and mapping data
:param state: unused
:return: True
"""
if self.scales_ is not None:
self.scales_.clear()
if self.resolved_mappings_ is not None:
self.resolved_mappings_.clear()

return True
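The PyTorch-specific pass that actually consumes `scales_` and `resolved_mappings_` is in one of the files not expanded on this page. As a rough sketch of the algorithm the docstring cites (SmoothQuant, https://arxiv.org/abs/2211.10438) applied to one resolved mapping; the helper name and the exact folding convention are assumptions, not necessarily how this commit implements it. The `scales_` dictionary populated during calibration would supply `scale` here, keyed by the smooth layer's name, and `alpha` corresponds to `smoothing_strength`.

```python
import torch


@torch.no_grad()
def smooth_one_mapping(mapping, scale, alpha: float = 0.5, eps: float = 1e-8):
    """Illustrative sketch: fold a per-channel scale between a norm and its projections."""
    # dynamic range observed per channel on the smooth layer's output during calibration
    activation_range = scale.max_channel_vals - scale.min_channel_vals

    # largest per-input-channel weight magnitude across all layers being balanced
    weight_range = torch.stack(
        [layer.weight.abs().max(dim=0).values for layer in mapping.balance_layers]
    ).max(dim=0).values

    # SmoothQuant: s_j = |X_j|^alpha / |W_j|^(1 - alpha)
    scales = (
        activation_range.pow(alpha) / weight_range.pow(1 - alpha).clamp(min=eps)
    ).clamp(min=eps)

    # divide the smoothing layer (e.g. a LayerNorm) by s ...
    mapping.smooth_layer.weight.div_(scales)
    if getattr(mapping.smooth_layer, "bias", None) is not None:
        mapping.smooth_layer.bias.div_(scales)

    # ... and multiply s into the input channels of each downstream projection
    for layer in mapping.balance_layers:
        layer.weight.mul_(scales.view(1, -1))
```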
