Channelwise Quantization Tests (#2283)
* WIP

* WIP2

* fix for channelwise

* set for nightly

* more calibration samples

* run 15m tests on commit

* update config

* mark as integration

* fixes for tests

* cleanup class
Sara Adkins authored May 28, 2024
1 parent c5ac841 commit c530bbf
Showing 14 changed files with 160 additions and 72 deletions.
1 change: 0 additions & 1 deletion src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -176,7 +176,6 @@ def _pruning_arguments(self):
"""
Gather the parameters needed for root module compression in a dict
:param sparsity: target sparsity
:return: dict of params for pruning
"""
return {
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml"
new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml"
@@ -0,0 +1,5 @@
cadence: "commit"
test_type: "regression"
model_stub: "Xenova/llama2.c-stories15M"
old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_channel.yaml"
new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_channel.yaml"
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_full.yaml"
new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml"
@@ -0,0 +1,5 @@
cadence: "commit"
test_type: "regression"
model_stub: "Xenova/llama2.c-stories15M"
old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_full.yaml"
new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_full.yaml"
@@ -0,0 +1,5 @@
cadence: "nightly"
test_type: "regression"
model_stub: "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml"
new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml"
@@ -0,0 +1,5 @@
cadence: "commit"
test_type: "regression"
model_stub: "Xenova/llama2.c-stories15M"
old_recipe: "tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml"
new_recipe: "tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml"
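The test file further down in this commit builds its parameter matrix from these config files via parse_params(CONFIGS_DIRECTORY). As a rough, hedged sketch of that idea (the real tests/testing_utils.parse_params is not shown in this diff, so the helper below is an assumed stand-in that simply loads every YAML config in the directory; assumes PyYAML is available):

# Hypothetical stand-in for tests/testing_utils.parse_params -- assumed behavior only.
import os
from typing import Dict, List

import yaml


def parse_params_sketch(configs_directory: str) -> List[Dict]:
    # Load each YAML config into a dict; parameterized_class turns the list into
    # one test-class variant per config file.
    configs = []
    for file_name in sorted(os.listdir(configs_directory)):
        if file_name.endswith((".yaml", ".yml")):
            with open(os.path.join(configs_directory, file_name)) as f:
                configs.append(yaml.safe_load(f))
    return configs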
@@ -0,0 +1,18 @@
test_stage:
quant_modifiers:
vLLMQuantizationModifier:
ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
config_groups:
group_0:
weights:
num_bits: 4
type: "int"
symmetric: False
strategy: "channel"
input_activations: null
output_activations: null
targets: ["Linear"]
GPTQModifier:
block_size: 128
sequential_update: False
targets: ["model.layers.0", "model.layers.1", "model.layers.2"]
@@ -28,5 +28,4 @@ test_stage:
GPTQModifier:
block_size: 128
sequential_update: False
percdamp: 0.01
targets: ["re:model.layers.\\d+$"]
@@ -15,5 +15,4 @@ test_stage:
GPTQModifier:
block_size: 128
sequential_update: False
percdamp: 0.01
targets: ["re:model.layers.\\d+$"]
@@ -0,0 +1,28 @@
test_stage:
quant_modifiers:
QuantizationModifier:
ignore:
- model.layers.0.mlp.down_proj
- lm_head
- LlamaRotaryEmbedding
- LlamaRMSNorm
- SiLU
- MatMulLeftInput_QK
- MatMulRightInput_QK
- MatMulOutput_QK
- MatMulLeftInput_PV
- MatMulRightInput_PV
- MatMulOutput_PV
- Embedding
scheme_overrides:
Linear:
weights:
num_bits: 4
symmetric: false
strategy: "channel"
input_activations: null
output_activations: null
GPTQModifier:
block_size: 128
sequential_update: False
targets: ["model.layers.0", "model.layers.1", "model.layers.2"]
@@ -34,5 +34,4 @@ test_stage:
GPTQModifier:
block_size: 128
sequential_update: False
percdamp: 0.01
targets: ["re:model.layers.\\d+$"]
@@ -31,5 +31,4 @@ test_stage:
GPTQModifier:
block_size: 128
sequential_update: False
percdamp: 0.01
targets: ["re:model.layers.\\d+$"]
151 changes: 84 additions & 67 deletions tests/sparseml/transformers/compression/test_quantization.py
@@ -12,16 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import shutil
import tempfile
import unittest

import pytest
import torch
from torch.utils.data import DataLoader
from transformers import DefaultDataCollator

from compressed_tensors.quantization import fake_quantize
from compressed_tensors.quantization.utils import is_module_quantized
from parameterized import parameterized_class
from sparseml.pytorch.utils import tensors_to_device
@@ -32,30 +33,31 @@
)
from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from tests.testing_utils import requires_gpu, requires_torch
from tests.testing_utils import parse_params, requires_gpu, requires_torch


CONFIGS_DIRECTORY = "tests/sparseml/transformers/compression/configs"


@requires_torch
@requires_gpu
@parameterized_class(
("old_recipe", "new_recipe"),
[
(
"tests/sparseml/transformers/compression/recipes/old_quant_full.yaml",
"tests/sparseml/transformers/compression/recipes/new_quant_full.yaml",
),
(
"tests/sparseml/transformers/compression/recipes/old_quant_weight.yaml",
"tests/sparseml/transformers/compression/recipes/new_quant_weight.yaml",
),
],
)
@pytest.mark.integration
@parameterized_class(parse_params(CONFIGS_DIRECTORY))
class TestQuantizationMatches(unittest.TestCase):
"""
Tests that the new compressed-tensors quantization format matches the performance of
the old sparseml format. For setup, this class runs a full oneshot pass with an old and
a new quantization recipe that should be equivalent, then verifies the following:
- quantization structure matches after oneshot
- quantized weights match
- decompressing the new model restores the expected weights on reload
- no perplexity regression from the old quantization framework; asserts the new
model is within 2% of the old model's perplexity
"""

old_recipe = None
new_recipe = None
# TODO: use "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" for nightly
# or weekly runs, but this smaller model is better for commit testing
model_stub = "Xenova/llama2.c-stories15M"
model_stub = None
dataset = "open_platypus"
old_output = "tiny_llama_old"
new_output = "tiny_llama_new"
@@ -86,16 +88,9 @@ def setUpClass(cls):
os.path.join(cls.test_dir, cls.new_output),
)

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.test_dir)
del cls.model_new
del cls.model_old
torch.cuda.empty_cache()

@staticmethod
def _run_oneshot(model, recipe, dataset, output_dir):
num_calibration_samples = 512
num_calibration_samples = 256
max_seq_length = 512
pad_to_max_length = False

@@ -116,12 +111,13 @@ def _get_quant_info_old(self, model):
quant_info_inputs = {}
for name, module in model.named_modules():
if hasattr(module, "weight_fake_quant"):
scale = module.weight_fake_quant.scale.item()
zp = module.weight_fake_quant.zero_point.item()
quant_info_weights[name] = (scale, zp)
scale = module.weight_fake_quant.scale
zp = module.weight_fake_quant.zero_point
weight = module.weight_fake_quant(module.weight)
quant_info_weights[name] = (scale, zp, weight)
elif hasattr(module, "quant"):
scale = module.quant.activation_post_process.scale.item()
zp = module.quant.activation_post_process.zero_point.item()
scale = module.quant.activation_post_process.scale
zp = module.quant.activation_post_process.zero_point
quant_info_inputs[name] = (scale, zp)

return quant_info_weights, quant_info_inputs
@@ -133,34 +129,67 @@ def _get_quant_info_new(self, model):
if is_module_quantized(module):
if module.quantization_scheme.weights is not None:
quant_info_weights[name] = (
module.weight_scale.item(),
module.weight_zero_point.item(),
module.weight_scale,
module.weight_zero_point,
fake_quantize(
module.weight,
module.weight_scale,
module.weight_zero_point,
module.quantization_scheme.weights,
),
)
if module.quantization_scheme.input_activations is not None:
quant_info_inputs[name] = (
module.input_scale.item(),
module.input_zero_point.item(),
module.input_scale,
module.input_zero_point,
)

return quant_info_weights, quant_info_inputs

def _get_dataloader(self, data_args, tokenizer):
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split="train",
tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
dataset_manager.get_raw_dataset()
)
data_loader = DataLoader(
calib_dataset,
batch_size=1,
collate_fn=DefaultDataCollator(),
sampler=torch.utils.data.RandomSampler(calib_dataset),
)

return data_loader

def test_quantization_counts(self):
old_quant_weights, old_quant_inputs = self._get_quant_info_old(self.model_old)
new_quant_weights, new_quant_inputs = self._get_quant_info_new(self.model_new)

assert len(old_quant_weights) == len(new_quant_weights)
assert len(old_quant_inputs) == len(new_quant_inputs)

def test_quantization_scale_and_zp(self):
old_quant_weights, old_quant_inputs = self._get_quant_info_old(self.model_old)
new_quant_weights, new_quant_inputs = self._get_quant_info_new(self.model_new)
def test_quantization_matches(self):
old_quant_weights, _ = self._get_quant_info_old(self.model_old)
new_quant_weights, _ = self._get_quant_info_new(self.model_new)

for name, (o_scale, o_zp) in old_quant_weights.items():
for name, (o_scale, o_zp, _) in old_quant_weights.items():
if name.endswith(".module"):
name = name[:-7]
n_scale, n_zp = new_quant_weights[name]
assert math.isclose(o_scale, n_scale, abs_tol=1e-3, rel_tol=1e-3)
assert o_zp == n_zp
n_scale, n_zp, _ = new_quant_weights[name]
if n_scale.ndim == 2: # channelwise
n_scale = n_scale[:, 0]
n_zp = n_zp[:, 0]
elif n_scale.ndim == 0: # tensor
n_scale = torch.unsqueeze(n_scale, 0)
n_zp = torch.unsqueeze(n_zp, 0)

assert torch.all(
torch.isclose(o_scale.cpu(), n_scale.cpu(), atol=1e-3, rtol=1e-3)
)

def test_quantization_reload(self):
model_reloaded = SparseAutoModelForCausalLM.from_pretrained(
@@ -170,34 +199,15 @@ def test_quantization_reload(self):
og_weights, og_inputs = self._get_quant_info_new(self.model_new)
reloaded_weights, reloaded_inputs = self._get_quant_info_new(model_reloaded)

for name, (o_scale, o_zp) in og_weights.items():
n_scale, n_zp = reloaded_weights[name]
assert o_scale == n_scale
assert o_zp == n_zp
for name, (o_scale, o_zp, _) in og_weights.items():
n_scale, n_zp, _ = reloaded_weights[name]
assert torch.equal(o_scale.cpu(), n_scale.cpu())
assert torch.equal(o_zp.cpu(), n_zp.cpu())

for name, (o_scale, o_zp) in og_inputs.items():
n_scale, n_zp = reloaded_inputs[name]
assert o_scale == n_scale
assert o_zp == n_zp

def _get_dataloader(self, data_args, tokenizer):
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
split="train",
tokenizer=tokenizer,
)
calib_dataset = dataset_manager.tokenize_and_process(
dataset_manager.get_raw_dataset()
)
data_loader = DataLoader(
calib_dataset,
batch_size=1,
collate_fn=DefaultDataCollator(),
sampler=torch.utils.data.RandomSampler(calib_dataset),
)

return data_loader
assert torch.equal(o_scale.cpu(), n_scale.cpu())
assert torch.equal(o_zp.cpu(), n_zp.cpu())

@torch.no_grad()
def test_perplexity(self):
@@ -228,3 +238,10 @@ def test_perplexity(self):
total_ppl_old / total_non_nan
)
assert avg_ppl_ratio <= 1.02

@classmethod
def tearDownClass(cls):
shutil.rmtree(cls.test_dir)
del cls.model_new
del cls.model_old
torch.cuda.empty_cache()
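For reference, the perplexity check at the end of the class compares the old and new models over calibration batches and fails if the new format regresses by more than 2%. A simplified, hedged sketch of that comparison (the full accumulation logic is collapsed in this diff; model_old, model_new, and data_loader are assumed to already exist on the GPU):

# Simplified sketch of the perplexity regression check -- not the exact test logic above.
import torch


@torch.no_grad()
def average_perplexity(model, data_loader, device="cuda:0", max_batches=64):
    total_ppl, total_non_nan = 0.0, 0
    for idx, batch in enumerate(data_loader):
        if idx >= max_batches:
            break
        batch = {key: value.to(device) for key, value in batch.items()}
        loss = model(**batch).loss
        if not torch.isnan(loss):
            total_ppl += torch.exp(loss).item()
            total_non_nan += 1
    return total_ppl / total_non_nan


# ppl_old = average_perplexity(model_old, data_loader)
# ppl_new = average_perplexity(model_new, data_loader)
# assert ppl_new / ppl_old <= 1.02  # allow at most a 2% perplexity regression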
