PR comments on logarithmic equalization
Satrat committed Oct 20, 2023
1 parent 45e3630 commit 6b0393d
Showing 4 changed files with 31 additions and 18 deletions.
4 changes: 1 addition & 3 deletions src/sparseml/modifiers/smoothquant/base.py
@@ -76,15 +76,13 @@ class SmoothQuantModifier(Modifier):
:param mappings: list of activation layers to smooth, and which layers to offset
the smoothing to for each activation
:param ignore: list of layers to ignore, even if they match a regex in mappings
:param logarithmic_equalization: Whether to use a logarithmic scale for smoothing
:param num_calibration_steps: number of samples to use for calibration, or None to
use the whole dataset
"""

smoothing_strength: float = Field(..., alias="alpha")
smoothing_strength: float = Field(..., validation_alias="alpha")
mappings: List[Tuple]
ignore: Optional[List[str]] = None
logarithmic_equalization: Optional[bool] = False
num_calibration_steps: Optional[int] = None

resolved_mappings_: Dict = None
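For reference, here is a minimal sketch (hypothetical usage, not part of this commit) of constructing the modifier in Python with the fields documented above; the mapping, ignore pattern, and calibration step count are illustrative, and additional base-Modifier arguments may be required in practice:

import torch  # noqa: F401 (sparseml's PyTorch backend is assumed installed)
from sparseml.modifiers.smoothquant.base import SmoothQuantModifier

# illustrative values only; each mapping pairs the layers that absorb the
# smoothing with the activation layer whose output is smoothed
modifier = SmoothQuantModifier(
    alpha=0.5,  # populates smoothing_strength via its "alpha" validation alias
    mappings=[
        (["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"),
    ],
    ignore=["re:.*lm_head"],  # hypothetical ignore pattern
    num_calibration_steps=512,
)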
41 changes: 28 additions & 13 deletions src/sparseml/modifiers/smoothquant/pytorch.py
Expand Up @@ -16,6 +16,7 @@
from typing import Callable, List, Optional

import torch
from torch.nn import Module

from sparseml.core import State
from sparseml.core.model.pytorch import ModifiableModelPyTorch
@@ -190,19 +191,7 @@ def _apply_smoothing(self):
smooth_layer = mapping.smooth_layer
balance_layers = mapping.balance_layers

# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)
weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

# calculate the amount of smoothing to apply
# s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
# where j is the input channel, alpha is smoothing strength
scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow(
1 - self.smoothing_strength
)
scales = self._calculate_smoothing_scales(balance_layers, activation_scales)

# invert the smoothing in the following layers
for layer in balance_layers:
@@ -215,3 +204,29 @@ def _apply_smoothing(self):
smooth_layer.weight.div_(scales.view(-1, 1))
if hasattr(smooth_layer, "bias"):
smooth_layer.bias.div_(scales)
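As a self-contained illustration (assumed layer types and shapes, not the repository's exact code), folding per-channel scales into a LayerNorm followed by Linear layers looks roughly like this: the balance layers are multiplied by the scales along their input channels, while the smoothed layer's parameters are divided by them:

import torch
from torch.nn import LayerNorm, Linear

hidden = 8
smooth_layer = LayerNorm(hidden)                            # activation source being smoothed
balance_layers = [Linear(hidden, 16), Linear(hidden, 16)]   # layers absorbing the smoothing
scales = torch.rand(hidden) + 0.5                           # per-input-channel smoothing factors

with torch.no_grad():
    # invert the smoothing in the following layers (Linear weight is [out, in])
    for layer in balance_layers:
        layer.weight.mul_(scales.view(1, -1))
    # fold the division into the smoothed layer so its outputs shrink by 1/scales
    smooth_layer.weight.div_(scales)
    smooth_layer.bias.div_(scales)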

def _calculate_smoothing_scales(
self, balance_layers: List[Module], activation_scales: torch.Tensor
) -> torch.Tensor:
"""
Calculate how much smoothing to apply to each channel, based on the dynamic
range of the activation and the weights of the following layers

:param balance_layers: layers to offset activation smoothing to
:param activation_scales: channel-wise dynamic range of the activation to smooth
:return: channel-wise scales to use for smoothing the activation
"""
# get the channel-wise dynamic range for each layer to be balanced
weight_scales = []
for layer in balance_layers:
scale = layer.weight.abs().max(dim=0, keepdim=True)[0]
weight_scales.append(scale)
weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0]

# calculate the amount of smoothing to apply
# s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
# where j is the input channel, alpha is smoothing strength
scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow(
1 - self.smoothing_strength
)
return scales
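For illustration only, the same formula evaluated on dummy tensors (all shapes and values are made up) shows how channels with a wide activation range and a narrow weight range end up with larger smoothing factors:

import torch

alpha = 0.5                                          # smoothing strength
activation_scales = torch.tensor([6.0, 0.5, 3.0])    # max(|X_j|) per input channel

# channel-wise dynamic range across the weights of the balanced layers
weights = [torch.randn(64, 3), torch.randn(128, 3)]
weight_scales = 2.0 * torch.cat(
    [w.abs().max(dim=0, keepdim=True)[0] for w in weights], dim=0
).max(dim=0)[0]

# s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha)
scales = activation_scales.pow(alpha) / weight_scales.pow(1 - alpha)
print(scales)  # one smoothing factor per input channel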
2 changes: 1 addition & 1 deletion src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,7 +1,7 @@
test_stage:
obcq_modifiers:
SmoothQuantModifier:
migration_strength: 0.5
smoothing_strength: 0.5
mappings: [
[["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
[["re:.*fc1"], "re:.*final_layer_norm"]
@@ -1,7 +1,7 @@
test_stage:
obcq_modifiers:
SmoothQuantModifier:
migration_strength: 0.5
smoothing_strength: 0.5
mappings: [
[["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
