GPTQ UX config groups support #2273

Merged · 7 commits · May 20, 2024
56 changes: 49 additions & 7 deletions src/sparseml/modifiers/quantization/gptq/base.py
@@ -15,6 +15,9 @@
import logging
from typing import Any, Dict, List, Optional, Union

from pydantic import Field

from compressed_tensors.quantization import QuantizationScheme
from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.model.base import ModifiableModel
@@ -53,13 +56,29 @@ class GPTQModifier(Modifier):
already exist in the recipe
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
    :param config_groups: [Used if a quantization modifier is not specified],
        dictionary specifying quantization schemes to apply to target
        modules. Modules not matching a scheme target will NOT be quantized.
    :param ignore: [Used if a quantization modifier is not specified]
        optional list of module class names or submodule names to not
        quantize even if they match a target in config_groups. Defaults to empty list.
    :param disable_quantization_observer_epoch: [Used if a quantization modifier is
        not specified] Epoch to disable updates to the module
        quantization observers. At this point, quantized weights and zero points will
        not be updated. Leave None to not disable observers during QAT. Default is None.
    :param num_calibration_steps: Number of steps to run post-training calibration for.
        When None, the entire calibration_dataloader is used
"""

sequential_update: Optional[bool] = False
targets: Union[str, List[str], None] = None
block_size: int = 128
quantize: Union[bool, Dict] = True
dampening_frac: Optional[float] = 0.01
config_groups: Optional[Dict[str, QuantizationScheme]] = None
ignore: List[str] = Field(default_factory=list)
disable_quantization_observer_epoch: Optional[float] = None
num_calibration_steps: Optional[int] = None
compressible_layers_: Optional[List] = None
quantization_modifier_: Any = None
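For illustration, a minimal sketch of populating the new config_groups field directly in Python, mirroring the sample yaml shown later in this PR. It assumes the pydantic field coerces nested dicts into QuantizationScheme/QuantizationArgs objects and that the modifier can be constructed outside of a recipe; treat it as a sketch, not the canonical API, and the ignore entries are only examples.

from sparseml.modifiers.quantization.gptq.base import GPTQModifier

# One scheme group targeting all Linear layers: 8-bit symmetric int weights,
# activations left unquantized. Values mirror the sample yaml in this PR.
gptq = GPTQModifier(
    sequential_update=True,
    dampening_frac=0.001,
    block_size=128,
    config_groups={
        "group_0": {
            "targets": ["Linear"],
            "input_activations": None,
            "output_activations": None,
            "weights": {
                "num_bits": 8,
                "type": "int",
                "symmetric": True,
                "strategy": "tensor",
                "group_size": 128,
            },
        }
    },
    ignore=["LlamaRotaryEmbedding", "LlamaRMSNorm"],  # example modules to skip
)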

@@ -81,14 +100,10 @@ def on_initialize_structure(self, state: State, **kwargs):
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
"GPTQ quantization is set to True without an "
"active quantization modifier."
)
self._build_quant_modifier(state.framework)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, Dict):
@@ -118,6 +133,33 @@ def on_initialize_structure(self, state: State, **kwargs):
if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def _build_quant_modifier(self, framework):
"""
Build a quantization modifier based on the specified config_groups,
ignore list, and num_calibration_steps.

:postcondition: self.quantization_modifier_ is set to the built
quantization modifier
:param framework: the framework to build the quantization modifier for
"""

quantization_args_names = [
"config_groups",
"num_calibration_steps",
"ignore",
"disable_quantization_observer_epoch",
]

quant_args = {
key: getattr(self, key)
for key in quantization_args_names
if getattr(self, key, False)
}
_LOGGER.info(f"Building quantization modifier with args: {quant_args}")
vllm_quant_config = {"vLLMQuantizationModifier": quant_args}
self._build_quant_modifier_from_dict(vllm_quant_config, framework)
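To make the filtering above concrete, here is a sketch of what a hypothetical GPTQModifier with only config_groups and ignore set would hand to _build_quant_modifier_from_dict; the values are placeholders, and the point is that falsy attributes (None, empty list) are dropped before wrapping:

# Sketch of the wrapped config produced by _build_quant_modifier when only
# config_groups and ignore are set; attributes that are None or empty are
# filtered out by the truthiness check in the dict comprehension above.
quant_args = {
    "config_groups": {"group_0": {"targets": ["Linear"]}},
    "ignore": ["LlamaRMSNorm"],
}
vllm_quant_config = {"vLLMQuantizationModifier": quant_args}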


def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
27 changes: 20 additions & 7 deletions src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -44,12 +44,25 @@ class GPTQModifierPyTorch(GPTQModifier):
- LayerCompressor.post_compress()
- LayerCompressor.revert_layer_wrappers()
| Sample yaml:
| test_stage:
| obcq_modifiers:
| GPTQModifier:
| sequential_update: True
| dampening_frac: 0.001
| block_size: 128
| config_groups:
| group_0:
| targets:
| - "Linear"
| input_activations: null
| output_activations: null
| weights:
| num_bits: 8
| type: "int"
| symmetric: true
| strategy: "tensor"
| group_size: 128


:param model: Pytorch model to perform GPTQ on, in place.
"""
@@ -59,7 +72,7 @@ class GPTQModifierPyTorch(GPTQModifier):

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
-        Initialize and run the OBCQ algorithm on the current state
+        Initialize and run the GPTQ algorithm on the current state

:param state: session state storing input model and calibration data
"""
@@ -0,0 +1,9 @@
cadence: "commit"
test_type: "sanity"
model: "Xenova/llama2.c-stories15M"
dataset: open_platypus
initial_pruning_only_recipe: "tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml"
initial_sparsity: 0.5
recipe_mask_structure: "2:4"
subsequent_prune_and_quant_recipe: "tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml"
final_sparsity: 0.7
@@ -0,0 +1,43 @@
test_stage:
obcq_modifiers:
SmoothQuantModifier:
smoothing_strength: 0.5
mappings: [
[["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
[["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
]
QuantizationModifier:
ignore:
- LlamaRotaryEmbedding
- LlamaRMSNorm
- SiLU
- model.layers.0.mlp.down_proj
- model.layers.1.mlp.down_proj
- model.layers.2.mlp.down_proj
- model.layers.3.mlp.down_proj
- model.layers.4.mlp.down_proj
- model.layers.5.mlp.down_proj
post_oneshot_calibration: True
scheme_overrides:
Embedding:
input_activations: null
weights:
num_bits: 8
symmetric: False
SparseGPTModifier:
sparsity: 0.7
block_size: 128
sequential_update: False
percdamp: 0.01
mask_structure: "0:0"
targets: [
"model.layers.0",
]
preserve_sparsity_mask: True
GPTQModifier:
sequential_update: False
dampening_frac: 0.01
targets: [
"model.layers.0",
]
block_size: 128
@@ -0,0 +1,11 @@
test_stage:
obcq_modifiers:
SparseGPTModifier:
sparsity: 0.5
block_size: 128
sequential_update: False
percdamp: 0.01
mask_structure: "2:4"
targets: [
"model.layers.0",
]
148 changes: 148 additions & 0 deletions tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
@@ -0,0 +1,148 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from pathlib import Path

import pytest

import sparseml
from parameterized import parameterized_class
from tests.testing_utils import parse_params, requires_torch


MASK_STRUCTURE_CONFIGS_DIRECTORY = (
"tests/sparseml/transformers/obcq/obcq_configs/consec_runs/mask_structure"
)


def tensor_follows_mask_structure(tensor, mask: str = "2:4"):
"""
:param tensor: tensor to check
:param mask: mask structure to check for, in the format "n:m"
:return: True if the tensor follows the mask structure, False otherwise.
        Note: some weights can incidentally be zero, so we check for
        at least n zeros in each chunk of size m
"""
import torch

n, m = tuple(map(int, mask.split(":")))
# Reshape the tensor into chunks of size m
tensor = tensor.view(-1, m)

# Count the number of zeros in each chunk
zero_counts = (tensor == 0).sum(dim=1)

    # Check that the number of zeros in each chunk is at least n
    # A greater-than-or-equal comparison is used because some weights
    # can incidentally be zero
return torch.all(zero_counts >= n)
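As a quick illustration of the helper above (outside the diff), a row where every chunk of four has at least two zeros passes the default "2:4" check, while a denser row fails:

import torch

# Both chunks of 4 contain at least 2 zeros -> satisfies the default "2:4" mask
assert tensor_follows_mask_structure(torch.tensor([0.0, 1.5, 0.0, -2.0, 0.0, 0.0, 3.0, 0.0]))
# The first chunk of 4 contains only one zero -> violates "2:4"
assert not tensor_follows_mask_structure(torch.tensor([1.0, 2.0, 0.0, 3.0, 0.0, 0.0, 1.0, 1.0]))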


@requires_torch
@pytest.mark.integration
@parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY))
class TestMaskStructurePreserved(unittest.TestCase):
"""
    Tests that the mask structure is preserved across multiple runs of oneshot:
    the initial model is pruned using a mask_structure, and then the pruned model
is further pruned and quantized.
"""

model = None
initial_pruning_only_recipe = None
initial_sparsity = None
recipe_mask_structure = None
dataset = None
subsequent_prune_and_quant_recipe = None
final_sparsity = None

def setUp(self) -> None:
import torch

self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.output = "./oneshot_output"
self.output_first = Path(self.output) / "test_1"
self.output_second = Path(self.output) / "test_2"

def test_mask_structure_preserved(self):
"""
Checks that the mask structure is preserved across runs of oneshot
between the initial pruning and the subsequent pruning + quantization
"""
import math

import torch

from sparseml.pytorch.model_load.helpers import get_session_model
from sparseml.pytorch.utils.helpers import tensor_sparsity
from sparseml.transformers import oneshot
from sparseml.utils.pytorch import qat_active

tolerance = 1e-3
num_calibration_samples = 16

oneshot(
model=self.model,
dataset=self.dataset,
num_calibration_samples=num_calibration_samples,
recipe=self.initial_pruning_only_recipe,
output_dir=self.output_first,
oneshot_device=self.device,
clear_sparse_session=False,
)
first_tiny_model = get_session_model()
        targeted_layer = first_tiny_model.model.layers[0].self_attn.k_proj
        target_layer_sparsity = tensor_sparsity(targeted_layer.weight)
initial_mask = first_tiny_model.model.layers[0].self_attn.k_proj.weight == 0

        # sparsity is as expected, i.e. close to self.initial_sparsity
assert math.isclose(
target_layer_sparsity.item(), self.initial_sparsity, rel_tol=tolerance
)
        # mask structure is as expected, i.e. the same as self.recipe_mask_structure
assert tensor_follows_mask_structure(initial_mask, self.recipe_mask_structure)

sparseml.reset_session()

oneshot(
model=self.output_first,
dataset=self.dataset,
num_calibration_samples=num_calibration_samples,
recipe=self.subsequent_prune_and_quant_recipe,
output_dir=self.output_second,
oneshot_device=self.device,
clear_sparse_session=False,
)

second_tiny_model = get_session_model()

# model is loaded
assert second_tiny_model is not None

        targeted_layer = second_tiny_model.model.layers[0].self_attn.k_proj.module
        target_layer_sparsity = tensor_sparsity(targeted_layer.weight)

        # sparsity is as expected, i.e. close to self.final_sparsity
assert math.isclose(
target_layer_sparsity.item(), self.final_sparsity, rel_tol=tolerance
)
        # qat should be active since the second recipe includes quantization
assert qat_active(second_tiny_model)

# original mask structure is preserved, additional zeros are
# added on top of the initial mask
        final_mask = targeted_layer.weight == 0
assert torch.all(initial_mask <= final_mask)
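For reference, the final assertion relies on elementwise <= over boolean masks: it holds only when every position zeroed in the initial mask is still zeroed in the final mask, i.e. the final mask is a superset of the initial one. A tiny standalone sketch of that check:

import torch

# a <= b elementwise on bools is True everywhere only when b covers a,
# i.e. every True (zeroed weight) in a is also True in b.
a = torch.tensor([True, False, True])
b = torch.tensor([True, True, True])
assert torch.all(a <= b)       # b preserves (and extends) the mask in a
assert not torch.all(b <= a)   # a does not cover the extra zeros in b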