diff --git a/src/sparseml/modifiers/smoothquant/base.py b/src/sparseml/modifiers/smoothquant/base.py
index ee2ed8883ea..f4a3b71b478 100644
--- a/src/sparseml/modifiers/smoothquant/base.py
+++ b/src/sparseml/modifiers/smoothquant/base.py
@@ -76,15 +76,13 @@ class SmoothQuantModifier(Modifier):
     :param mappings: list activation layers to smooth, and the which layers to
         offset the smoothing to for each activation
     :param ignore: list of layers to ignore, even if they match a regex in mappings
-    :param logarithmic_equalization: Whether to use a logarithmic scale for smoothing
     :param num_calibration_steps: number of samples to use for calibration, or None to
         use the whole dataset
     """
 
-    smoothing_strength: float = Field(..., alias="alpha")
+    smoothing_strength: float = Field(..., validation_alias="alpha")
     mappings: List[Tuple]
     ignore: Optional[List[str]] = None
-    logarithmic_equalization: Optional[bool] = False
    num_calibration_steps: Optional[int] = None
 
     resolved_mappings_: Dict = None
diff --git a/src/sparseml/modifiers/smoothquant/pytorch.py b/src/sparseml/modifiers/smoothquant/pytorch.py
index ad953e40d1a..ce87b34498e 100644
--- a/src/sparseml/modifiers/smoothquant/pytorch.py
+++ b/src/sparseml/modifiers/smoothquant/pytorch.py
@@ -16,6 +16,7 @@
 from typing import Callable, List, Optional
 
 import torch
+from torch.nn import Module
 
 from sparseml.core import State
 from sparseml.core.model.pytorch import ModifiableModelPyTorch
@@ -190,19 +191,7 @@ def _apply_smoothing(self):
             smooth_layer = mapping.smooth_layer
             balance_layers = mapping.balance_layers
 
-            # get the channel-wise dynamic range for each layer to be balanced
-            weight_scales = []
-            for layer in balance_layers:
-                scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
-                weight_scales.append(scale)
-            weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
-
-            # calculate the amount of smoothing to apply
-            # s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
-            # where j is the input channel, alpha is smoothing strength
-            scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow(
-                1 - self.smoothing_strength
-            )
+            scales = self._calculate_smoothing_scales(balance_layers, activation_scales)
 
             # invert the smoothing in the following layers
             for layer in balance_layers:
@@ -215,3 +204,29 @@ def _apply_smoothing(self):
             smooth_layer.weight.div_(scales.view(-1, 1))
             if hasattr(smooth_layer, "bias"):
                 smooth_layer.bias.div_(scales)
+
+    def _calculate_smoothing_scales(
+        self, balance_layers: List[Module], activation_scales: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Calculate how much smoothing to apply to each channel based on the dynamic
+        range of the activation and the following weights
+
+        :param balance_layers: layers to offset activation smoothing to
+        :param activation_scales: channel-wise dynamic range of activation to smooth
+        :return: channel-wise scales to use for smoothing activation
+        """
+        # get the channel-wise dynamic range for each layer to be balanced
+        weight_scales = []
+        for layer in balance_layers:
+            scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
+            weight_scales.append(scale)
+        weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]
+
+        # calculate the amount of smoothing to apply
+        # s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
+        # where j is the input channel, alpha is smoothing strength
+        scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow(
+            1 - self.smoothing_strength
+        )
+        return scales
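For reference, the smoothing math that this refactor moves into `_calculate_smoothing_scales`, as a standalone sketch. The tensor shapes and values below are made-up stand-ins for the mapped layers, not sparseml APIs:

```python
import torch

alpha = 0.5  # smoothing_strength

# hypothetical inputs: max(|X_j|) per input channel, plus two balance-layer
# weight matrices of shape (out_features, in_features)
activation_scales = torch.rand(8) + 0.5
balance_weights = [torch.randn(16, 8), torch.randn(16, 8)]

# channel-wise dynamic range across all balance layers, as in the diff above
weight_scales = 2.0 * torch.stack(
    [w.abs().max(dim=0)[0] for w in balance_weights]
).max(dim=0)[0]

# s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
scales = activation_scales.pow(alpha) / weight_scales.pow(1 - alpha)

# balance layers absorb the smoothing that the preceding layer divides out
for w in balance_weights:
    w.mul_(scales.view(1, -1))
```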
diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml
index bb25e27a69a..6594bf39547 100644
--- a/src/sparseml/transformers/sparsification/obcq/example.yaml
+++ b/src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,7 +1,7 @@
 test_stage:
   obcq_modifiers:
     SmoothQuantModifier:
-      migration_strength: 0.5
+      smoothing_strength: 0.5
       mappings: [
         [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
         [["re:.*fc1"], "re:.*final_layer_norm"]
diff --git a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml
index b43cc4512c7..ea3f4ae5cd1 100644
--- a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml
+++ b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml
@@ -1,7 +1,7 @@
 test_stage:
   obcq_modifiers:
     SmoothQuantModifier:
-      migration_strength: 0.5
+      smoothing_strength: 0.5
       mappings: [
         [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
         [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
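One note on the `alias` → `validation_alias` change in base.py: in pydantic v2, `validation_alias` affects only input parsing, so recipes that still pass `alpha` keep validating while the canonical field name stays `smoothing_strength`. A minimal sketch, assuming the modifier base model enables `populate_by_name` (not shown in this diff):

```python
from pydantic import BaseModel, ConfigDict, Field


class SmoothQuantConfig(BaseModel):
    # assumption: populate_by_name lets the field name validate alongside the alias
    model_config = ConfigDict(populate_by_name=True)

    smoothing_strength: float = Field(..., validation_alias="alpha")


# both spellings parse to the same field
assert SmoothQuantConfig(alpha=0.5).smoothing_strength == 0.5
assert SmoothQuantConfig(smoothing_strength=0.5).smoothing_strength == 0.5
```

This matches the YAML rename above: `smoothing_strength: 0.5` is now the documented spelling, while `alpha` survives as a validation-time alias.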